diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 85599662bd9032ce35b6118d775eb096168571e2..95d6c256b54c461b8e3544355cbe16423390e05b 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -246,7 +246,10 @@ struct mem_cgroup {
 	 * Should the accounting and control be hierarchical, per subtree?
 	 */
 	bool use_hierarchy;
-	atomic_t	oom_lock;
+
+	bool		oom_lock;
+	atomic_t	under_oom;
+
 	atomic_t	refcnt;
 
 	int	swappiness;
@@ -1722,37 +1725,83 @@ static int mem_cgroup_hierarchical_reclaim(struct mem_cgroup *root_mem,
 /*
  * Check OOM-Killer is already running under our hierarchy.
  * If someone is running, return false.
+ * Has to be called with memcg_oom_mutex
  */
 static bool mem_cgroup_oom_lock(struct mem_cgroup *mem)
 {
-	int x, lock_count = 0;
-	struct mem_cgroup *iter;
+	int lock_count = -1;
+	struct mem_cgroup *iter, *failed = NULL;
+	bool cond = true;
 
-	for_each_mem_cgroup_tree(iter, mem) {
-		x = atomic_inc_return(&iter->oom_lock);
-		lock_count = max(x, lock_count);
+	for_each_mem_cgroup_tree_cond(iter, mem, cond) {
+		bool locked = iter->oom_lock;
+
+		iter->oom_lock = true;
+		if (lock_count == -1)
+			lock_count = iter->oom_lock;
+		else if (lock_count != locked) {
+			/*
+			 * this subtree of our hierarchy is already locked
+			 * so we cannot give a lock.
+			 */
+			lock_count = 0;
+			failed = iter;
+			cond = false;
+		}
 	}
 
-	if (lock_count == 1)
-		return true;
-	return false;
+	if (!failed)
+		goto done;
+
+	/*
+	 * OK, we failed to lock the whole subtree so we have to clean up
+	 * what we set up to the failing subtree
+	 */
+	cond = true;
+	for_each_mem_cgroup_tree_cond(iter, mem, cond) {
+		if (iter == failed) {
+			cond = false;
+			continue;
+		}
+		iter->oom_lock = false;
+	}
+done:
+	return lock_count;
 }
 
+/*
+ * Has to be called with memcg_oom_mutex
+ */
 static int mem_cgroup_oom_unlock(struct mem_cgroup *mem)
 {
 	struct mem_cgroup *iter;
 
+	for_each_mem_cgroup_tree(iter, mem)
+		iter->oom_lock = false;
+	return 0;
+}
+
+static void mem_cgroup_mark_under_oom(struct mem_cgroup *mem)
+{
+	struct mem_cgroup *iter;
+
+	for_each_mem_cgroup_tree(iter, mem)
+		atomic_inc(&iter->under_oom);
+}
+
+static void mem_cgroup_unmark_under_oom(struct mem_cgroup *mem)
+{
+	struct mem_cgroup *iter;
+
 	/*
 	 * When a new child is created while the hierarchy is under oom,
 	 * mem_cgroup_oom_lock() may not be called. We have to use
 	 * atomic_add_unless() here.
 	 */
 	for_each_mem_cgroup_tree(iter, mem)
-		atomic_add_unless(&iter->oom_lock, -1, 0);
-	return 0;
+		atomic_add_unless(&iter->under_oom, -1, 0);
 }
 
-
 static DEFINE_MUTEX(memcg_oom_mutex);
 static DECLARE_WAIT_QUEUE_HEAD(memcg_oom_waitq);
 
@@ -1794,7 +1843,7 @@ static void memcg_wakeup_oom(struct mem_cgroup *mem)
 
 static void memcg_oom_recover(struct mem_cgroup *mem)
 {
-	if (mem && atomic_read(&mem->oom_lock))
+	if (mem && atomic_read(&mem->under_oom))
 		memcg_wakeup_oom(mem);
 }
 
@@ -1812,6 +1861,8 @@ bool mem_cgroup_handle_oom(struct mem_cgroup *mem, gfp_t mask)
 	owait.wait.private = current;
 	INIT_LIST_HEAD(&owait.wait.task_list);
 	need_to_kill = true;
+	mem_cgroup_mark_under_oom(mem);
+
 	/* At first, try to OOM lock hierarchy under mem.*/
 	mutex_lock(&memcg_oom_mutex);
 	locked = mem_cgroup_oom_lock(mem);
@@ -1835,10 +1886,13 @@ bool mem_cgroup_handle_oom(struct mem_cgroup *mem, gfp_t mask)
 		finish_wait(&memcg_oom_waitq, &owait.wait);
 	}
 	mutex_lock(&memcg_oom_mutex);
-	mem_cgroup_oom_unlock(mem);
+	if (locked)
+		mem_cgroup_oom_unlock(mem);
 	memcg_wakeup_oom(mem);
 	mutex_unlock(&memcg_oom_mutex);
 
+	mem_cgroup_unmark_under_oom(mem);
+
 	if (test_thread_flag(TIF_MEMDIE) || fatal_signal_pending(current))
 		return false;
 	/* Give chance to dying process */
@@ -4505,7 +4559,7 @@ static int mem_cgroup_oom_register_event(struct cgroup *cgrp,
 	list_add(&event->list, &memcg->oom_notify);
 
 	/* already in OOM ? */
-	if (atomic_read(&memcg->oom_lock))
+	if (atomic_read(&memcg->under_oom))
 		eventfd_signal(eventfd, 1);
 	mutex_unlock(&memcg_oom_mutex);
 
@@ -4540,7 +4594,7 @@ static int mem_cgroup_oom_control_read(struct cgroup *cgrp,
 
 	cb->fill(cb, "oom_kill_disable", mem->oom_kill_disable);
 
-	if (atomic_read(&mem->oom_lock))
+	if (atomic_read(&mem->under_oom))
 		cb->fill(cb, "under_oom", 1);
 	else
 		cb->fill(cb, "under_oom", 0);