diff --git a/drivers/base/power/domain.c b/drivers/base/power/domain.c
index 45eb3b155b6d0f74012245b0f55b1727a426542d..9727bc56320ad43286de61f32cd75356848ad2fb 100644
--- a/drivers/base/power/domain.c
+++ b/drivers/base/power/domain.c
@@ -1603,33 +1603,52 @@ int pm_genpd_remove_subdomain(struct generic_pm_domain *genpd,
  * @dev: Device to add the callbacks to.
  * @ops: Set of callbacks to add.
  * @td: Timing data to add to the device along with the callbacks (optional).
+ *
+ * Every call to this routine should be balanced with a call to
+ * __pm_genpd_remove_callbacks() and they must not be nested.
  */
 int pm_genpd_add_callbacks(struct device *dev, struct gpd_dev_ops *ops,
 			   struct gpd_timing_data *td)
 {
-	struct pm_domain_data *pdd;
+	struct generic_pm_domain_data *gpd_data_new, *gpd_data = NULL;
 	int ret = 0;
 
-	if (!(dev && dev->power.subsys_data && ops))
+	if (!(dev && ops))
 		return -EINVAL;
 
+	gpd_data_new = __pm_genpd_alloc_dev_data(dev);
+	if (!gpd_data_new)
+		return -ENOMEM;
+
 	pm_runtime_disable(dev);
 	device_pm_lock();
 
-	pdd = dev->power.subsys_data->domain_data;
-	if (pdd) {
-		struct generic_pm_domain_data *gpd_data = to_gpd_data(pdd);
+	ret = dev_pm_get_subsys_data(dev);
+	if (ret)
+		goto out;
+
+	spin_lock_irq(&dev->power.lock);
 
-		gpd_data->ops = *ops;
-		if (td)
-			gpd_data->td = *td;
+	if (dev->power.subsys_data->domain_data) {
+		gpd_data = to_gpd_data(dev->power.subsys_data->domain_data);
 	} else {
-		ret = -EINVAL;
+		gpd_data = gpd_data_new;
+		dev->power.subsys_data->domain_data = &gpd_data->base;
 	}
+	gpd_data->refcount++;
+	gpd_data->ops = *ops;
+	if (td)
+		gpd_data->td = *td;
 
+	spin_unlock_irq(&dev->power.lock);
+
+ out:
 	device_pm_unlock();
 	pm_runtime_enable(dev);
 
+	if (gpd_data != gpd_data_new)
+		__pm_genpd_free_dev_data(dev, gpd_data_new);
+
 	return ret;
 }
 EXPORT_SYMBOL_GPL(pm_genpd_add_callbacks);
@@ -1638,10 +1657,13 @@ EXPORT_SYMBOL_GPL(pm_genpd_add_callbacks);
  * __pm_genpd_remove_callbacks - Remove PM domain callbacks from a given device.
  * @dev: Device to remove the callbacks from.
  * @clear_td: If set, clear the device's timing data too.
+ *
+ * This routine can only be called after pm_genpd_add_callbacks().
  */
 int __pm_genpd_remove_callbacks(struct device *dev, bool clear_td)
 {
-	struct pm_domain_data *pdd;
+	struct generic_pm_domain_data *gpd_data = NULL;
+	bool remove = false;
 	int ret = 0;
 
 	if (!(dev && dev->power.subsys_data))
@@ -1650,21 +1672,35 @@ int __pm_genpd_remove_callbacks(struct device *dev, bool clear_td)
 	pm_runtime_disable(dev);
 	device_pm_lock();
 
-	pdd = dev->power.subsys_data->domain_data;
-	if (pdd) {
-		struct generic_pm_domain_data *gpd_data = to_gpd_data(pdd);
+	spin_lock_irq(&dev->power.lock);
 
+	if (dev->power.subsys_data->domain_data) {
+		gpd_data = to_gpd_data(dev->power.subsys_data->domain_data);
 		gpd_data->ops = (struct gpd_dev_ops){ 0 };
 		if (clear_td)
 			gpd_data->td = (struct gpd_timing_data){ 0 };
+
+		if (--gpd_data->refcount == 0) {
+			dev->power.subsys_data->domain_data = NULL;
+			remove = true;
+		}
 	} else {
 		ret = -EINVAL;
 	}
 
+	spin_unlock_irq(&dev->power.lock);
+
 	device_pm_unlock();
 	pm_runtime_enable(dev);
 
-	return ret;
+	if (ret)
+		return ret;
+
+	dev_pm_put_subsys_data(dev);
+	if (remove)
+		__pm_genpd_free_dev_data(dev, gpd_data);
+
+	return 0;
 }
 EXPORT_SYMBOL_GPL(__pm_genpd_remove_callbacks);