diff --git a/include/net/sock.h b/include/net/sock.h
index d89f0582b6b6f1a907d108712ddf44eaa13b08d6..4a45216995635cccc4a919b5e5506a87fe90a49c 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -46,6 +46,7 @@
 #include <linux/list_nulls.h>
 #include <linux/timer.h>
 #include <linux/cache.h>
+#include <linux/bitops.h>
 #include <linux/lockdep.h>
 #include <linux/netdevice.h>
 #include <linux/skbuff.h>	/* struct sk_buff */
@@ -921,12 +922,23 @@ struct proto {
 #endif
 };
 
+/*
+ * Bits in struct cg_proto.flags
+ */
+enum cg_proto_flags {
+	/* Currently active: new sockets should be assigned to cgroups */
+	MEMCG_SOCK_ACTIVE,
+	/* It was ever activated; we must disarm static keys on destruction */
+	MEMCG_SOCK_ACTIVATED,
+};
+
 struct cg_proto {
 	void			(*enter_memory_pressure)(struct sock *sk);
 	struct res_counter	*memory_allocated;	/* Current allocated memory. */
 	struct percpu_counter	*sockets_allocated;	/* Current number of sockets. */
 	int			*memory_pressure;
 	long			*sysctl_mem;
+	unsigned long		flags;
 	/*
 	 * memcg field is used to find which memcg we belong directly
 	 * Each memcg struct can hold more than one cg_proto, so container_of
@@ -942,6 +954,16 @@ struct cg_proto {
 extern int proto_register(struct proto *prot, int alloc_slab);
 extern void proto_unregister(struct proto *prot);
 
+static inline bool memcg_proto_active(struct cg_proto *cg_proto)
+{
+	return test_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+}
+
+static inline bool memcg_proto_activated(struct cg_proto *cg_proto)
+{
+	return test_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags);
+}
+
 #ifdef SOCK_REFCNT_DEBUG
 static inline void sk_refcnt_debug_inc(struct sock *sk)
 {
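
For context, here is a minimal sketch of how the two accessors above are meant to be consumed. The function names are invented for illustration only and are not part of this patch; the real reader is sock_update_memcg() and the real teardown path is disarm_sock_keys(), both in the hunks that follow.

/* Illustration only, not part of the patch: invented helpers showing the
 * intended split between the two bits. */
static void example_attach_sock(struct sock *sk, struct cg_proto *cg_proto)
{
	/* ACTIVE gates whether newly created sockets get accounted at all. */
	if (cg_proto && memcg_proto_active(cg_proto))
		sk->sk_cgrp = cg_proto;
}

static void example_teardown(struct cg_proto *cg_proto)
{
	/* ACTIVATED records that the static key was bumped once, so the
	 * matching decrement happens exactly once at destruction. */
	if (memcg_proto_activated(cg_proto))
		static_key_slow_dec(&memcg_socket_limit_enabled);
}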
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 6fbf50977f777a985a2306464a9b798b7b0c3c3c..ac35bccadb7b9f53606d445a961e442e891aa94a 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -417,6 +417,7 @@ void sock_update_memcg(struct sock *sk)
 {
 	if (mem_cgroup_sockets_enabled) {
 		struct mem_cgroup *memcg;
+		struct cg_proto *cg_proto;
 
 		BUG_ON(!sk->sk_prot->proto_cgroup);
 
@@ -436,9 +437,10 @@ void sock_update_memcg(struct sock *sk)
 
 		rcu_read_lock();
 		memcg = mem_cgroup_from_task(current);
-		if (!mem_cgroup_is_root(memcg)) {
+		cg_proto = sk->sk_prot->proto_cgroup(memcg);
+		if (!mem_cgroup_is_root(memcg) && memcg_proto_active(cg_proto)) {
 			mem_cgroup_get(memcg);
-			sk->sk_cgrp = sk->sk_prot->proto_cgroup(memcg);
+			sk->sk_cgrp = cg_proto;
 		}
 		rcu_read_unlock();
 	}
@@ -467,6 +469,19 @@ EXPORT_SYMBOL(tcp_proto_cgroup);
 #endif /* CONFIG_INET */
 #endif /* CONFIG_CGROUP_MEM_RES_CTLR_KMEM */
 
+#if defined(CONFIG_INET) && defined(CONFIG_CGROUP_MEM_RES_CTLR_KMEM)
+static void disarm_sock_keys(struct mem_cgroup *memcg)
+{
+	if (!memcg_proto_activated(&memcg->tcp_mem.cg_proto))
+		return;
+	static_key_slow_dec(&memcg_socket_limit_enabled);
+}
+#else
+static void disarm_sock_keys(struct mem_cgroup *memcg)
+{
+}
+#endif
+
 static void drain_all_stock_async(struct mem_cgroup *memcg);
 
 static struct mem_cgroup_per_zone *
@@ -4712,6 +4727,18 @@ static void free_work(struct work_struct *work)
 	int size = sizeof(struct mem_cgroup);
 
 	memcg = container_of(work, struct mem_cgroup, work_freeing);
+	/*
+	 * We need to make sure that (at least for now) the jump label
+	 * destruction code runs outside of the cgroup_lock. This is because
+	 * get_online_cpus(), which is called from the static_branch update,
+	 * can't be called inside the cgroup_lock. cpusets are the ones
+	 * enforcing this dependency, so if they ever change, so may this code.
+	 *
+	 * schedule_work() will guarantee this happens. Be careful if you need
+	 * to move this code around, and make sure it is outside
+	 * the cgroup_lock.
+	 */
+	disarm_sock_keys(memcg);
 	if (size < PAGE_SIZE)
 		kfree(memcg);
 	else
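
To make the deferral pattern above concrete, here is a hedged, generic sketch of pushing a static key decrement out of a path that may hold cgroup_lock and into process context via the workqueue, which is what routing disarm_sock_keys() through free_work() achieves. The struct, key, and helper names below are invented and not part of the patch.

/* Illustrative sketch only; my_obj, my_feature_key and the helpers are
 * invented names, not part of the patch. */
#include <linux/workqueue.h>
#include <linux/jump_label.h>
#include <linux/slab.h>

struct my_obj {
	struct work_struct work_freeing;
	bool key_was_raised;
};

extern struct static_key my_feature_key;	/* hypothetical key */

static void my_obj_free_work(struct work_struct *work)
{
	struct my_obj *obj = container_of(work, struct my_obj, work_freeing);

	/* Runs in process context, outside cgroup_lock, so the
	 * get_online_cpus() taken by the static key update is safe here. */
	if (obj->key_was_raised)
		static_key_slow_dec(&my_feature_key);
	kfree(obj);
}

static void my_obj_release(struct my_obj *obj)
{
	/* May be called with cgroup_lock held: only queue the work here. */
	INIT_WORK(&obj->work_freeing, my_obj_free_work);
	schedule_work(&obj->work_freeing);
}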
diff --git a/net/ipv4/tcp_memcontrol.c b/net/ipv4/tcp_memcontrol.c
index 151703791bb0d43818700349fb0a585736c7dc19..b6f3583ddfe83eae73370c9070c567e452518f57 100644
--- a/net/ipv4/tcp_memcontrol.c
+++ b/net/ipv4/tcp_memcontrol.c
@@ -74,9 +74,6 @@ void tcp_destroy_cgroup(struct mem_cgroup *memcg)
 	percpu_counter_destroy(&tcp->tcp_sockets_allocated);
 
 	val = res_counter_read_u64(&tcp->tcp_memory_allocated, RES_LIMIT);
-
-	if (val != RESOURCE_MAX)
-		static_key_slow_dec(&memcg_socket_limit_enabled);
 }
 EXPORT_SYMBOL(tcp_destroy_cgroup);
 
@@ -107,10 +104,33 @@ static int tcp_update_limit(struct mem_cgroup *memcg, u64 val)
 		tcp->tcp_prot_mem[i] = min_t(long, val >> PAGE_SHIFT,
 					     net->ipv4.sysctl_tcp_mem[i]);
 
-	if (val == RESOURCE_MAX && old_lim != RESOURCE_MAX)
-		static_key_slow_dec(&memcg_socket_limit_enabled);
-	else if (old_lim == RESOURCE_MAX && val != RESOURCE_MAX)
-		static_key_slow_inc(&memcg_socket_limit_enabled);
+	if (val == RESOURCE_MAX)
+		clear_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+	else {
+		/*
+		 * The active bit needs to be written after the static_key
+		 * update. This is what guarantees that the socket activation
+		 * function is the last one to run. See sock_update_memcg() for
+		 * details, and note that we don't mark any socket as belonging
+		 * to this memcg until that flag is up.
+		 *
+		 * We need to do this because static_keys will span multiple
+		 * sites, but we can't control their order. If we mark a socket
+		 * as accounted, but the accounting functions are not patched in
+		 * yet, we'll lose accounting.
+		 *
+		 * We never race with the readers in sock_update_memcg(),
+		 * because when this value changes, the code to process it is
+		 * not patched in yet.
+		 *
+		 * The activated bit is used to guarantee that no two writers
+		 * will do the update in the same memcg. Without that, we can't
+		 * properly shut down the static key.
+		 */
+		if (!test_and_set_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags))
+			static_key_slow_inc(&memcg_socket_limit_enabled);
+		set_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+	}
 
 	return 0;
 }
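
As a sanity check of the writer-side protocol above (raise the key once on first activation, toggle only ACTIVE afterwards, drop the key at most once at teardown), the following standalone userspace model can be compiled and run. The names, the __atomic builtins, and the plain counter standing in for the static key are stand-ins, not kernel code.

/* Standalone userspace model of the two-bit protocol; not kernel code. */
#include <stdbool.h>
#include <stdio.h>

enum { MODEL_SOCK_ACTIVE, MODEL_SOCK_ACTIVATED };

static unsigned long flags;
static int key_count;	/* stands in for the static key reference count */

static bool model_test_and_set_bit(int nr, unsigned long *addr)
{
	unsigned long mask = 1UL << nr;

	return __atomic_fetch_or(addr, mask, __ATOMIC_SEQ_CST) & mask;
}

static void model_update_limit(bool limited)
{
	if (!limited) {
		/* Dropping the limit only clears ACTIVE; the key stays up. */
		__atomic_and_fetch(&flags, ~(1UL << MODEL_SOCK_ACTIVE),
				   __ATOMIC_SEQ_CST);
		return;
	}
	/* Only the first writer ever raises the key... */
	if (!model_test_and_set_bit(MODEL_SOCK_ACTIVATED, &flags))
		key_count++;
	/* ...and ACTIVE becomes visible only after the key is up. */
	__atomic_or_fetch(&flags, 1UL << MODEL_SOCK_ACTIVE, __ATOMIC_SEQ_CST);
}

static void model_destroy(void)
{
	/* Drop the key exactly once, and only if it was ever raised. */
	if (__atomic_load_n(&flags, __ATOMIC_SEQ_CST) &
	    (1UL << MODEL_SOCK_ACTIVATED))
		key_count--;
}

int main(void)
{
	model_update_limit(true);	/* key_count: 0 -> 1 */
	model_update_limit(false);	/* ACTIVE cleared, key stays raised */
	model_update_limit(true);	/* key_count stays 1, not 2 */
	model_destroy();		/* key_count: 1 -> 0 */
	printf("key_count = %d\n", key_count);
	return 0;
}

Built with a plain C compiler (e.g. gcc), it prints key_count = 0, showing that the single inc/dec pair stays balanced even when the limit is set and cleared repeatedly.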