diff --git a/include/net/sock.h b/include/net/sock.h
index 97fc0ad47da053ce276274e0ba5de5b7aaaa3063..0e7a9b05f92bb891e1e766e034ff3d6d5af019fc 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -922,7 +922,7 @@ inline void sk_refcnt_debug_release(const struct sock *sk)
 #define sk_refcnt_debug_release(sk) do { } while (0)
 #endif /* SOCK_REFCNT_DEBUG */
 
-#ifdef CONFIG_CGROUP_MEM_RES_CTLR_KMEM
+#if defined(CONFIG_CGROUP_MEM_RES_CTLR_KMEM) && defined(CONFIG_NET)
 extern struct jump_label_key memcg_socket_limit_enabled;
 static inline struct cg_proto *parent_cg_proto(struct proto *proto,
 					       struct cg_proto *cg_proto)
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 3dbff4dcde35191a62b0245ac7ee185bc8fc7084..c3688dfd9a5fefe695e4bccbbb4be51c284d9aea 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -379,7 +379,7 @@ static void mem_cgroup_put(struct mem_cgroup *memcg);
 static bool mem_cgroup_is_root(struct mem_cgroup *memcg);
 void sock_update_memcg(struct sock *sk)
 {
-	if (static_branch(&memcg_socket_limit_enabled)) {
+	if (mem_cgroup_sockets_enabled) {
 		struct mem_cgroup *memcg;
 
 		BUG_ON(!sk->sk_prot->proto_cgroup);
@@ -411,7 +411,7 @@ EXPORT_SYMBOL(sock_update_memcg);
 
 void sock_release_memcg(struct sock *sk)
 {
-	if (static_branch(&memcg_socket_limit_enabled) && sk->sk_cgrp) {
+	if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
 		struct mem_cgroup *memcg;
 		WARN_ON(!sk->sk_cgrp->memcg);
 		memcg = sk->sk_cgrp->memcg;