diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 0d8ad1ebd1d1b6164bd8088d982b6ef6da5a653d..054ff47c4478ddb6d6ef77cb66f6fd8330bdbea4 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -102,14 +102,19 @@ struct task_struct *find_lock_task_mm(struct task_struct *p)
 {
 	struct task_struct *t;
 
+	rcu_read_lock();
+
 	for_each_thread(p, t) {
 		task_lock(t);
 		if (likely(t->mm))
-			return t;
+			goto found;
 		task_unlock(t);
 	}
+	t = NULL;
+found:
+	rcu_read_unlock();
 
-	return NULL;
+	return t;
 }
 
 /* return true if the task is not adequate as candidate victim task. */
@@ -461,10 +466,8 @@ void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
 	}
 	read_unlock(&tasklist_lock);
 
-	rcu_read_lock();
 	p = find_lock_task_mm(victim);
 	if (!p) {
-		rcu_read_unlock();
 		put_task_struct(victim);
 		return;
 	} else if (victim != p) {
@@ -490,6 +493,7 @@ void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
 	 * That thread will now get access to memory reserves since it has a
 	 * pending fatal signal.
 	 */
+	rcu_read_lock();
 	for_each_process(p)
 		if (p->mm == mm && !same_thread_group(p, victim) &&
 		    !(p->flags & PF_KTHREAD)) {