diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 306b90cf051f63b9bd79337d94b804309288aca4..b28e66c4376ab3c320e144714f6f38a02e05fc9a 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -43,6 +43,10 @@ static int vfio_pci_enable(struct vfio_pci_device *vdev)
 	u16 cmd;
 	u8 msix_pos;
 
+	ret = pci_enable_device(pdev);
+	if (ret)
+		return ret;
+
 	vdev->reset_works = (pci_reset_function(pdev) == 0);
 	pci_save_state(pdev);
 	vdev->pci_saved_state = pci_store_saved_state(pdev);
@@ -51,8 +55,11 @@ static int vfio_pci_enable(struct vfio_pci_device *vdev)
 			 __func__, dev_name(&pdev->dev));
 
 	ret = vfio_config_init(vdev);
-	if (ret)
-		goto out;
+	if (ret) {
+		pci_load_and_free_saved_state(pdev, &vdev->pci_saved_state);
+		pci_disable_device(pdev);
+		return ret;
+	}
 
 	if (likely(!nointxmask))
 		vdev->pci_2_3 = pci_intx_mask_supported(pdev);
@@ -77,17 +84,7 @@ static int vfio_pci_enable(struct vfio_pci_device *vdev)
 	} else
 		vdev->msix_bar = 0xFF;
 
-	ret = pci_enable_device(pdev);
-	if (ret)
-		goto out;
-
-	return ret;
-
-out:
-	kfree(vdev->pci_saved_state);
-	vdev->pci_saved_state = NULL;
-	vfio_config_free(vdev);
-	return ret;
+	return 0;
 }
 
 static void vfio_pci_disable(struct vfio_pci_device *vdev)