diff --git a/include/net/sock.h b/include/net/sock.h
index 81198632ac2a35e508d301936392ca8959e42189..43a470d40d76194f0544ce656e190d2d527c1f81 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -754,8 +754,13 @@ static inline __must_check int sk_add_backlog(struct sock *sk, struct sk_buff *s
 	return 0;
 }
 
+extern int __sk_backlog_rcv(struct sock *sk, struct sk_buff *skb);
+
 static inline int sk_backlog_rcv(struct sock *sk, struct sk_buff *skb)
 {
+	if (sk_memalloc_socks() && skb_pfmemalloc(skb))
+		return __sk_backlog_rcv(sk, skb);
+
 	return sk->sk_backlog_rcv(sk, skb);
 }
 
diff --git a/net/core/dev.c b/net/core/dev.c
index 0ebaea16632fc348f6a48789831b737c7c51b707..ce132443d5d12609e44070bb4bb10de10239c848 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -3155,6 +3155,23 @@ void netdev_rx_handler_unregister(struct net_device *dev)
 }
 EXPORT_SYMBOL_GPL(netdev_rx_handler_unregister);
 
+/*
+ * Limit the use of PFMEMALLOC reserves to those protocols that implement
+ * the special handling of PFMEMALLOC skbs.
+ */
+static bool skb_pfmemalloc_protocol(struct sk_buff *skb)
+{
+	switch (skb->protocol) {
+	case __constant_htons(ETH_P_ARP):
+	case __constant_htons(ETH_P_IP):
+	case __constant_htons(ETH_P_IPV6):
+	case __constant_htons(ETH_P_8021Q):
+		return true;
+	default:
+		return false;
+	}
+}
+
 static int __netif_receive_skb(struct sk_buff *skb)
 {
 	struct packet_type *ptype, *pt_prev;
@@ -3164,14 +3181,27 @@ static int __netif_receive_skb(struct sk_buff *skb)
 	bool deliver_exact = false;
 	int ret = NET_RX_DROP;
 	__be16 type;
+	unsigned long pflags = current->flags;
 
 	net_timestamp_check(!netdev_tstamp_prequeue, skb);
 
 	trace_netif_receive_skb(skb);
 
+	/*
+	 * PFMEMALLOC skbs are special, they should
+	 * - be delivered to SOCK_MEMALLOC sockets only
+	 * - stay away from userspace
+	 * - have bounded memory usage
+	 *
+	 * Use PF_MEMALLOC as this saves us from propagating the allocation
+	 * context down to all allocation sites.
+	 */
+	if (sk_memalloc_socks() && skb_pfmemalloc(skb))
+		current->flags |= PF_MEMALLOC;
+
 	/* if we've gotten here through NAPI, check netpoll */
 	if (netpoll_receive_skb(skb))
-		return NET_RX_DROP;
+		goto out;
 
 	orig_dev = skb->dev;
 
@@ -3191,7 +3221,7 @@ static int __netif_receive_skb(struct sk_buff *skb)
 	if (skb->protocol == cpu_to_be16(ETH_P_8021Q)) {
 		skb = vlan_untag(skb);
 		if (unlikely(!skb))
-			goto out;
+			goto unlock;
 	}
 
 #ifdef CONFIG_NET_CLS_ACT
@@ -3201,6 +3231,9 @@ static int __netif_receive_skb(struct sk_buff *skb)
 	}
 #endif
 
+	if (sk_memalloc_socks() && skb_pfmemalloc(skb))
+		goto skip_taps;
+
 	list_for_each_entry_rcu(ptype, &ptype_all, list) {
 		if (!ptype->dev || ptype->dev == skb->dev) {
 			if (pt_prev)
@@ -3209,13 +3242,18 @@ static int __netif_receive_skb(struct sk_buff *skb)
 		}
 	}
 
+skip_taps:
 #ifdef CONFIG_NET_CLS_ACT
 	skb = handle_ing(skb, &pt_prev, &ret, orig_dev);
 	if (!skb)
-		goto out;
+		goto unlock;
 ncls:
 #endif
 
+	if (sk_memalloc_socks() && skb_pfmemalloc(skb)
+				&& !skb_pfmemalloc_protocol(skb))
+		goto drop;
+
 	rx_handler = rcu_dereference(skb->dev->rx_handler);
 	if (vlan_tx_tag_present(skb)) {
 		if (pt_prev) {
@@ -3225,7 +3263,7 @@ static int __netif_receive_skb(struct sk_buff *skb)
 		if (vlan_do_receive(&skb, !rx_handler))
 			goto another_round;
 		else if (unlikely(!skb))
-			goto out;
+			goto unlock;
 	}
 
 	if (rx_handler) {
@@ -3235,7 +3273,7 @@ static int __netif_receive_skb(struct sk_buff *skb)
 		}
 		switch (rx_handler(&skb)) {
 		case RX_HANDLER_CONSUMED:
-			goto out;
+			goto unlock;
 		case RX_HANDLER_ANOTHER:
 			goto another_round;
 		case RX_HANDLER_EXACT:
@@ -3268,6 +3306,7 @@ static int __netif_receive_skb(struct sk_buff *skb)
 		else
 			ret = pt_prev->func(skb, skb->dev, pt_prev, orig_dev);
 	} else {
+drop:
 		atomic_long_inc(&skb->dev->rx_dropped);
 		kfree_skb(skb);
 		/* Jamal, now you will not able to escape explaining
@@ -3276,8 +3315,10 @@ static int __netif_receive_skb(struct sk_buff *skb)
 		ret = NET_RX_DROP;
 	}
 
-out:
+unlock:
 	rcu_read_unlock();
+out:
+	tsk_restore_flags(current, pflags, PF_MEMALLOC);
 	return ret;
 }
 
diff --git a/net/core/sock.c b/net/core/sock.c
index c8c5816289fed78279b2169fd00e375b4c2315d4..32fdcd2d6e8f60760ecf57a2aaa14a6a931b28cc 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -298,6 +298,22 @@ void sk_clear_memalloc(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(sk_clear_memalloc);
 
+int __sk_backlog_rcv(struct sock *sk, struct sk_buff *skb)
+{
+	int ret;
+	unsigned long pflags = current->flags;
+
+	/* these should have been dropped before queueing */
+	BUG_ON(!sock_flag(sk, SOCK_MEMALLOC));
+
+	current->flags |= PF_MEMALLOC;
+	ret = sk->sk_backlog_rcv(sk, skb);
+	tsk_restore_flags(current, pflags, PF_MEMALLOC);
+
+	return ret;
+}
+EXPORT_SYMBOL(__sk_backlog_rcv);
+
 #if defined(CONFIG_CGROUPS)
 #if !defined(CONFIG_NET_CLS_CGROUP)
 int net_cls_subsys_id = -1;