diff --git a/include/linux/sched/mm.h b/include/linux/sched/mm.h
index 1a80fb128e74..d5ece7a9a403 100644
--- a/include/linux/sched/mm.h
+++ b/include/linux/sched/mm.h
@@ -279,6 +279,7 @@ static inline void memalloc_nocma_restore(unsigned int flags)
 #endif
 
 #ifdef CONFIG_MEMCG
+DECLARE_PER_CPU(struct mem_cgroup *, int_active_memcg);
 /**
  * set_active_memcg - Starts the remote memcg charging scope.
  * @memcg: memcg to charge.
@@ -293,8 +294,16 @@ static inline void memalloc_nocma_restore(unsigned int flags)
 static inline struct mem_cgroup *
 set_active_memcg(struct mem_cgroup *memcg)
 {
-	struct mem_cgroup *old = current->active_memcg;
-	current->active_memcg = memcg;
+	struct mem_cgroup *old;
+
+	if (in_interrupt()) {
+		old = this_cpu_read(int_active_memcg);
+		this_cpu_write(int_active_memcg, memcg);
+	} else {
+		old = current->active_memcg;
+		current->active_memcg = memcg;
+	}
+
 	return old;
 }
 
 #else
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 51b1698bf06c..a3318b66e41e 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -73,6 +73,9 @@ EXPORT_SYMBOL(memory_cgrp_subsys);
 
 struct mem_cgroup *root_mem_cgroup __read_mostly;
 
+/* Active memory cgroup to use from an interrupt context */
+DEFINE_PER_CPU(struct mem_cgroup *, int_active_memcg);
+
 /* Socket memory accounting disabled? */
 static bool cgroup_memory_nosocket;
 
@@ -1061,26 +1064,43 @@ struct mem_cgroup *get_mem_cgroup_from_page(struct page *page)
 }
 EXPORT_SYMBOL(get_mem_cgroup_from_page);
 
+static __always_inline struct mem_cgroup *active_memcg(void)
+{
+	if (in_interrupt())
+		return this_cpu_read(int_active_memcg);
+	else
+		return current->active_memcg;
+}
+
+static __always_inline struct mem_cgroup *get_active_memcg(void)
+{
+	struct mem_cgroup *memcg;
+
+	rcu_read_lock();
+	memcg = active_memcg();
+	if (memcg) {
+		/* current->active_memcg must hold a ref. */
+		if (WARN_ON_ONCE(!css_tryget(&memcg->css)))
+			memcg = root_mem_cgroup;
+		else
+			memcg = current->active_memcg;
+	}
+	rcu_read_unlock();
+
+	return memcg;
+}
+
 /**
- * If current->active_memcg is non-NULL, do not fallback to current->mm->memcg.
+ * If active memcg is set, do not fallback to current->mm->memcg.
  */
 static __always_inline struct mem_cgroup *get_mem_cgroup_from_current(void)
 {
 	if (memcg_kmem_bypass())
 		return NULL;
 
-	if (unlikely(current->active_memcg)) {
-		struct mem_cgroup *memcg;
+	if (unlikely(active_memcg()))
+		return get_active_memcg();
 
-		rcu_read_lock();
-		/* current->active_memcg must hold a ref. */
-		if (WARN_ON_ONCE(!css_tryget(&current->active_memcg->css)))
-			memcg = root_mem_cgroup;
-		else
-			memcg = current->active_memcg;
-		rcu_read_unlock();
-		return memcg;
-	}
 	return get_mem_cgroup_from_mm(current->mm);
 }
 
@@ -2940,8 +2960,8 @@ __always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
 		return NULL;
 
 	rcu_read_lock();
-	if (unlikely(current->active_memcg))
-		memcg = rcu_dereference(current->active_memcg);
+	if (unlikely(active_memcg()))
+		memcg = active_memcg();
 	else
 		memcg = mem_cgroup_from_task(current);
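
For readers who want to poke at the save/restore semantics outside the kernel, below is a minimal user-space sketch of the pattern this patch implements: set_active_memcg() stashes the new memcg in current->active_memcg when called from task context, but in the per-CPU int_active_memcg slot when called from interrupt context, and always returns the previous value so the caller can restore it. Everything in the sketch is a hypothetical stand-in (a toy struct mem_cgroup, in_irq_ctx standing in for in_interrupt(), a single global standing in for the per-CPU slot on one CPU); only the save/restore shape mirrors the real interface, it is not kernel code.

/*
 * Toy user-space model of the set_active_memcg() save/restore pattern.
 * task_active_memcg models current->active_memcg, cpu_int_active_memcg
 * models the per-CPU int_active_memcg slot, in_irq_ctx models
 * in_interrupt().  All names here are illustrative stand-ins.
 */
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

struct mem_cgroup { const char *name; };

static struct mem_cgroup *task_active_memcg;	/* task-context slot */
static struct mem_cgroup *cpu_int_active_memcg;	/* "per-CPU" irq slot */
static bool in_irq_ctx;				/* fake in_interrupt() */

static struct mem_cgroup *set_active_memcg(struct mem_cgroup *memcg)
{
	struct mem_cgroup *old;

	if (in_irq_ctx) {
		old = cpu_int_active_memcg;
		cpu_int_active_memcg = memcg;
	} else {
		old = task_active_memcg;
		task_active_memcg = memcg;
	}
	return old;	/* caller restores this when the scope ends */
}

static struct mem_cgroup *active_memcg(void)
{
	return in_irq_ctx ? cpu_int_active_memcg : task_active_memcg;
}

int main(void)
{
	struct mem_cgroup a = { "memcg-A" }, b = { "memcg-B" };
	struct mem_cgroup *old;

	/* Task context: open a remote-charging scope, then restore. */
	old = set_active_memcg(&a);
	assert(active_memcg() == &a);
	set_active_memcg(old);
	assert(active_memcg() == NULL);

	/* "Interrupt" context: the per-CPU slot is used instead. */
	in_irq_ctx = true;
	old = set_active_memcg(&b);
	assert(active_memcg() == &b);
	assert(task_active_memcg == NULL);	/* task state untouched */
	set_active_memcg(old);
	in_irq_ctx = false;

	puts("save/restore pattern behaves as expected");
	return 0;
}

The separate interrupt-context slot is what lets an interrupt handler open its own charging scope without clobbering the active_memcg of whichever task it happened to preempt: each context saves and restores its own pointer, so nested scopes on the same CPU unwind cleanly.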