diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 26a38b7c7739d36c5ed734ec17f0dfcaecc2bd2e..408a5c75d77d3dc309b4dbf37589e8574b0e261a 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -152,12 +152,15 @@ struct mem_cgroup_stat_cpu {
 };
 
 struct mem_cgroup_reclaim_iter {
-	/* last scanned hierarchy member with elevated css ref count */
+	/*
+	 * last scanned hierarchy member. Valid only if last_dead_count
+	 * matches memcg->dead_count of the hierarchy root group.
+	 */
 	struct mem_cgroup *last_visited;
+	unsigned long last_dead_count;
+
 	/* scan generation, increased every round-trip */
 	unsigned int generation;
-	/* lock to protect the position and generation */
-	spinlock_t iter_lock;
 };
 
 /*
@@ -337,6 +340,7 @@ struct mem_cgroup {
 	struct mem_cgroup_stat_cpu nocpu_base;
 	spinlock_t pcp_counter_lock;
 
+	atomic_t	dead_count;
 #if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_INET)
 	struct tcp_memcontrol tcp_mem;
 #endif
@@ -1092,6 +1096,7 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 {
 	struct mem_cgroup *memcg = NULL;
 	struct mem_cgroup *last_visited = NULL;
+	unsigned long uninitialized_var(dead_count);
 
 	if (mem_cgroup_disabled())
 		return NULL;
@@ -1120,16 +1125,33 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 
 			mz = mem_cgroup_zoneinfo(root, nid, zid);
 			iter = &mz->reclaim_iter[reclaim->priority];
-			spin_lock(&iter->iter_lock);
 			last_visited = iter->last_visited;
 			if (prev && reclaim->generation != iter->generation) {
-				if (last_visited) {
-					css_put(&last_visited->css);
-					iter->last_visited = NULL;
-				}
-				spin_unlock(&iter->iter_lock);
+				iter->last_visited = NULL;
 				goto out_unlock;
 			}
+
+			/*
+			 * If the dead_count mismatches, a destruction
+			 * has happened or is happening concurrently.
+			 * If the dead_count matches, a destruction
+			 * might still happen concurrently, but since
+			 * we checked under RCU, that destruction
+			 * won't free the object until we release the
+			 * RCU reader lock.  Thus, the dead_count
+			 * check verifies the pointer is still valid,
+			 * css_tryget() verifies the cgroup pointed to
+			 * is alive.
+			 */
+			dead_count = atomic_read(&root->dead_count);
+			smp_rmb();
+			last_visited = iter->last_visited;
+			if (last_visited) {
+				if ((dead_count != iter->last_dead_count) ||
+					!css_tryget(&last_visited->css)) {
+					last_visited = NULL;
+				}
+			}
 		}
 
 		/*
@@ -1169,16 +1191,14 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 			if (css && !memcg)
 				curr = mem_cgroup_from_css(css);
 
-			/* make sure that the cached memcg is not removed */
-			if (curr)
-				css_get(&curr->css);
 			iter->last_visited = curr;
+			smp_wmb();
+			iter->last_dead_count = dead_count;
 
 			if (!css)
 				iter->generation++;
 			else if (!prev && memcg)
 				reclaim->generation = iter->generation;
-			spin_unlock(&iter->iter_lock);
 		} else if (css && !memcg) {
 			last_visited = mem_cgroup_from_css(css);
 		}
@@ -5975,12 +5995,8 @@ static int alloc_mem_cgroup_per_zone_info(struct mem_cgroup *memcg, int node)
 		return 1;
 
 	for (zone = 0; zone < MAX_NR_ZONES; zone++) {
-		int prio;
-
 		mz = &pn->zoneinfo[zone];
 		lruvec_init(&mz->lruvec);
-		for (prio = 0; prio < DEF_PRIORITY + 1; prio++)
-			spin_lock_init(&mz->reclaim_iter[prio].iter_lock);
 		mz->usage_in_excess = 0;
 		mz->on_tree = false;
 		mz->memcg = memcg;
@@ -6235,10 +6251,29 @@ mem_cgroup_css_online(struct cgroup *cont)
 	return error;
 }
 
+/*
+ * Announce all parents that a group from their hierarchy is gone.
+ */
+static void mem_cgroup_invalidate_reclaim_iterators(struct mem_cgroup *memcg)
+{
+	struct mem_cgroup *parent = memcg;
+
+	while ((parent = parent_mem_cgroup(parent)))
+		atomic_inc(&parent->dead_count);
+
+	/*
+	 * if the root memcg is not hierarchical we have to check it
+	 * explicitely.
+	 */
+	if (!root_mem_cgroup->use_hierarchy)
+		atomic_inc(&root_mem_cgroup->dead_count);
+}
+
 static void mem_cgroup_css_offline(struct cgroup *cont)
 {
 	struct mem_cgroup *memcg = mem_cgroup_from_cont(cont);
 
+	mem_cgroup_invalidate_reclaim_iterators(memcg);
 	mem_cgroup_reparent_charges(memcg);
 	mem_cgroup_destroy_all_caches(memcg);
 }