diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index c0bff8976a69fb5b526fffaa425a4e83d02c6b36..2a80544aec99b7eeee63493176801cfa4a9b6ccd 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -180,7 +180,8 @@ static inline void mem_cgroup_dec_page_stat(struct page *page,
 unsigned long mem_cgroup_soft_limit_reclaim(struct zone *zone, int order,
 						gfp_t gfp_mask,
 						unsigned long *total_scanned);
-u64 mem_cgroup_get_limit(struct mem_cgroup *memcg);
+extern void __mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
+				       int order);
 
 void mem_cgroup_count_vm_event(struct mm_struct *mm, enum vm_event_item idx);
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
@@ -364,12 +365,6 @@ unsigned long mem_cgroup_soft_limit_reclaim(struct zone *zone, int order,
 	return 0;
 }
 
-static inline
-u64 mem_cgroup_get_limit(struct mem_cgroup *memcg)
-{
-	return 0;
-}
-
 static inline void mem_cgroup_split_huge_fixup(struct page *head)
 {
 }
diff --git a/include/linux/oom.h b/include/linux/oom.h
index eb9dc14362c552063cf0dc04e3e57eb01521ba33..5dc0e384ae9ee793572028ac63f21a1987da0c93 100644
--- a/include/linux/oom.h
+++ b/include/linux/oom.h
@@ -40,17 +40,33 @@ enum oom_constraint {
 	CONSTRAINT_MEMCG,
 };
 
+enum oom_scan_t {
+	OOM_SCAN_OK,		/* scan thread and find its badness */
+	OOM_SCAN_CONTINUE,	/* do not consider thread for oom kill */
+	OOM_SCAN_ABORT,		/* abort the iteration and return */
+	OOM_SCAN_SELECT,	/* always select this thread first */
+};
+
 extern void compare_swap_oom_score_adj(int old_val, int new_val);
 extern int test_set_oom_score_adj(int new_val);
 
 extern unsigned long oom_badness(struct task_struct *p,
 		struct mem_cgroup *memcg, const nodemask_t *nodemask,
 		unsigned long totalpages);
+extern void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
+			     unsigned int points, unsigned long totalpages,
+			     struct mem_cgroup *memcg, nodemask_t *nodemask,
+			     const char *message);
+
 extern int try_set_zonelist_oom(struct zonelist *zonelist, gfp_t gfp_flags);
 extern void clear_zonelist_oom(struct zonelist *zonelist, gfp_t gfp_flags);
 
+extern enum oom_scan_t oom_scan_process_thread(struct task_struct *task,
+		unsigned long totalpages, const nodemask_t *nodemask,
+		bool force_kill);
 extern void mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
 				     int order);
+
 extern void out_of_memory(struct zonelist *zonelist, gfp_t gfp_mask,
 		int order, nodemask_t *mask, bool force_kill);
 extern int register_oom_notifier(struct notifier_block *nb);
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 4f73c823c59f32602c32e28aeba61a4fff97f83a..b78972e2f43fa2ab992ee25b6054394e37accb9b 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1453,7 +1453,7 @@ static int mem_cgroup_count_children(struct mem_cgroup *memcg)
 /*
  * Return the memory (and swap, if configured) limit for a memcg.
  */
-u64 mem_cgroup_get_limit(struct mem_cgroup *memcg)
+static u64 mem_cgroup_get_limit(struct mem_cgroup *memcg)
 {
 	u64 limit;
 	u64 memsw;
@@ -1469,6 +1469,65 @@ u64 mem_cgroup_get_limit(struct mem_cgroup *memcg)
 	return min(limit, memsw);
 }
 
+void __mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
+				int order)
+{
+	struct mem_cgroup *iter;
+	unsigned long chosen_points = 0;
+	unsigned long totalpages;
+	unsigned int points = 0;
+	struct task_struct *chosen = NULL;
+
+	totalpages = mem_cgroup_get_limit(memcg) >> PAGE_SHIFT ? : 1;
+	for_each_mem_cgroup_tree(iter, memcg) {
+		struct cgroup *cgroup = iter->css.cgroup;
+		struct cgroup_iter it;
+		struct task_struct *task;
+
+		cgroup_iter_start(cgroup, &it);
+		while ((task = cgroup_iter_next(cgroup, &it))) {
+			switch (oom_scan_process_thread(task, totalpages, NULL,
+							false)) {
+			case OOM_SCAN_SELECT:
+				if (chosen)
+					put_task_struct(chosen);
+				chosen = task;
+				chosen_points = ULONG_MAX;
+				get_task_struct(chosen);
+				/* fall through */
+			case OOM_SCAN_CONTINUE:
+				continue;
+			case OOM_SCAN_ABORT:
+				cgroup_iter_end(cgroup, &it);
+				mem_cgroup_iter_break(memcg, iter);
+				if (chosen)
+					put_task_struct(chosen);
+				return;
+			case OOM_SCAN_OK:
+				break;
+			};
+			points = oom_badness(task, memcg, NULL, totalpages);
+			if (points > chosen_points) {
+				if (chosen)
+					put_task_struct(chosen);
+				chosen = task;
+				chosen_points = points;
+				get_task_struct(chosen);
+			}
+		}
+		cgroup_iter_end(cgroup, &it);
+	}
+
+	if (!chosen)
+		return;
+	points = chosen_points * 1000 / totalpages;
+	read_lock(&tasklist_lock);
+	oom_kill_process(chosen, gfp_mask, order, points, totalpages, memcg,
+			 NULL, "Memory cgroup out of memory");
+	read_unlock(&tasklist_lock);
+	put_task_struct(chosen);
+}
+
 static unsigned long mem_cgroup_reclaim(struct mem_cgroup *memcg,
 					gfp_t gfp_mask,
 					unsigned long flags)
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index f8eba9651c0c28ee88452e2fd05bcadcbf629ef0..c0c97aea837f15b8fae520ee95296614bc429507 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -288,20 +288,13 @@ static enum oom_constraint constrained_alloc(struct zonelist *zonelist,
 }
 #endif
 
-enum oom_scan_t {
-	OOM_SCAN_OK,		/* scan thread and find its badness */
-	OOM_SCAN_CONTINUE,	/* do not consider thread for oom kill */
-	OOM_SCAN_ABORT,		/* abort the iteration and return */
-	OOM_SCAN_SELECT,	/* always select this thread first */
-};
-
-static enum oom_scan_t oom_scan_process_thread(struct task_struct *task,
-		struct mem_cgroup *memcg, unsigned long totalpages,
-		const nodemask_t *nodemask, bool force_kill)
+enum oom_scan_t oom_scan_process_thread(struct task_struct *task,
+		unsigned long totalpages, const nodemask_t *nodemask,
+		bool force_kill)
 {
 	if (task->exit_state)
 		return OOM_SCAN_CONTINUE;
-	if (oom_unkillable_task(task, memcg, nodemask))
+	if (oom_unkillable_task(task, NULL, nodemask))
 		return OOM_SCAN_CONTINUE;
 
 	/*
@@ -348,8 +341,8 @@ static enum oom_scan_t oom_scan_process_thread(struct task_struct *task,
  * (not docbooked, we don't want this one cluttering up the manual)
  */
 static struct task_struct *select_bad_process(unsigned int *ppoints,
-		unsigned long totalpages, struct mem_cgroup *memcg,
-		const nodemask_t *nodemask, bool force_kill)
+		unsigned long totalpages, const nodemask_t *nodemask,
+		bool force_kill)
 {
 	struct task_struct *g, *p;
 	struct task_struct *chosen = NULL;
@@ -358,7 +351,7 @@ static struct task_struct *select_bad_process(unsigned int *ppoints,
 	do_each_thread(g, p) {
 		unsigned int points;
 
-		switch (oom_scan_process_thread(p, memcg, totalpages, nodemask,
+		switch (oom_scan_process_thread(p, totalpages, nodemask,
 						force_kill)) {
 		case OOM_SCAN_SELECT:
 			chosen = p;
@@ -371,7 +364,7 @@ static struct task_struct *select_bad_process(unsigned int *ppoints,
 		case OOM_SCAN_OK:
 			break;
 		};
-		points = oom_badness(p, memcg, nodemask, totalpages);
+		points = oom_badness(p, NULL, nodemask, totalpages);
 		if (points > chosen_points) {
 			chosen = p;
 			chosen_points = points;
@@ -443,10 +436,10 @@ static void dump_header(struct task_struct *p, gfp_t gfp_mask, int order,
 }
 
 #define K(x) ((x) << (PAGE_SHIFT-10))
-static void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
-			     unsigned int points, unsigned long totalpages,
-			     struct mem_cgroup *memcg, nodemask_t *nodemask,
-			     const char *message)
+void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
+		      unsigned int points, unsigned long totalpages,
+		      struct mem_cgroup *memcg, nodemask_t *nodemask,
+		      const char *message)
 {
 	struct task_struct *victim = p;
 	struct task_struct *child;
@@ -564,10 +557,6 @@ static void check_panic_on_oom(enum oom_constraint constraint, gfp_t gfp_mask,
 void mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
 			      int order)
 {
-	unsigned long limit;
-	unsigned int points = 0;
-	struct task_struct *p;
-
 	/*
 	 * If current has a pending SIGKILL, then automatically select it.  The
 	 * goal is to allow it to allocate so that it may quickly exit and free
@@ -579,13 +568,7 @@ void mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
 	}
 
 	check_panic_on_oom(CONSTRAINT_MEMCG, gfp_mask, order, NULL);
-	limit = mem_cgroup_get_limit(memcg) >> PAGE_SHIFT ? : 1;
-	read_lock(&tasklist_lock);
-	p = select_bad_process(&points, limit, memcg, NULL, false);
-	if (p && PTR_ERR(p) != -1UL)
-		oom_kill_process(p, gfp_mask, order, points, limit, memcg, NULL,
-				 "Memory cgroup out of memory");
-	read_unlock(&tasklist_lock);
+	__mem_cgroup_out_of_memory(memcg, gfp_mask, order);
 }
 #endif
 
@@ -710,7 +693,7 @@ void out_of_memory(struct zonelist *zonelist, gfp_t gfp_mask,
 	struct task_struct *p;
 	unsigned long totalpages;
 	unsigned long freed = 0;
-	unsigned int points;
+	unsigned int uninitialized_var(points);
 	enum oom_constraint constraint = CONSTRAINT_NONE;
 	int killed = 0;
 
@@ -748,8 +731,7 @@ void out_of_memory(struct zonelist *zonelist, gfp_t gfp_mask,
 		goto out;
 	}
 
-	p = select_bad_process(&points, totalpages, NULL, mpol_mask,
-			       force_kill);
+	p = select_bad_process(&points, totalpages, mpol_mask, force_kill);
 	/* Found nothing?!?! Either we hang forever, or we panic. */
 	if (!p) {
 		dump_header(NULL, gfp_mask, order, NULL, mpol_mask);