dm: core: Avoid partially removing devices

At present if device_remove() decides that the device should not actually
be removed, it still calls the uclass pre_remove() method and powers the
device down.

Signed-off-by: Simon Glass <sjg@chromium.org>
diff --git a/drivers/core/device-remove.c b/drivers/core/device-remove.c
index 35b625e..bc99ef0 100644
--- a/drivers/core/device-remove.c
+++ b/drivers/core/device-remove.c
@@ -8,6 +8,8 @@
  * Pavel Herrmann <morpheus.ibis@gmail.com>
  */
 
+#define LOG_CATEGORY	LOGC_DM
+
 #include <common.h>
 #include <errno.h>
 #include <log.h>
@@ -54,7 +56,7 @@
 			continue;
 
 		ret = device_remove(pos, flags);
-		if (ret)
+		if (ret && ret != -EKEYREJECTED)
 			return ret;
 	}
 
@@ -149,13 +151,24 @@
 	devres_release_probe(dev);
 }
 
-static bool flags_remove(uint flags, uint drv_flags)
+/**
+ * flags_remove() - Figure out whether to remove a device
+ *
+ * @flags: Flags passed to device_remove()
+ * @drv_flags: Driver flags
+ * @return 0 if the device should be removed,
+ * -EKEYREJECTED if @flags includes a flag in DM_REMOVE_ACTIVE_ALL but
+ *	@drv_flags does not (indicates that this device has nothing to do for
+ *	DMA shutdown or OS prepare)
+ */
+static int flags_remove(uint flags, uint drv_flags)
 {
-	if ((flags & DM_REMOVE_NORMAL) ||
-	    (flags && (drv_flags & (DM_FLAG_ACTIVE_DMA | DM_FLAG_OS_PREPARE))))
-		return true;
+	if (flags & DM_REMOVE_NORMAL)
+		return 0;
+	if (flags && (drv_flags & DM_REMOVE_ACTIVE_ALL))
+		return 0;
 
-	return false;
+	return -EKEYREJECTED;
 }
 
 int device_remove(struct udevice *dev, uint flags)
@@ -169,22 +182,32 @@
 	if (!(dev_get_flags(dev) & DM_FLAG_ACTIVATED))
 		return 0;
 
+	/*
+	 * If the child returns EKEYREJECTED, continue. It just means that it
+	 * didn't match the flags.
+	 */
+	ret = device_chld_remove(dev, NULL, flags);
+	if (ret && ret != -EKEYREJECTED)
+		return ret;
+
+	/*
+	 * Remove the device if called with the "normal" remove flag set,
+	 * or if the remove flag matches any of the drivers remove flags
+	 */
 	drv = dev->driver;
 	assert(drv);
-
-	ret = device_chld_remove(dev, NULL, flags);
-	if (ret)
+	ret = flags_remove(flags, drv->flags);
+	if (ret) {
+		log_debug("%s: When removing: flags=%x, drv->flags=%x, err=%d\n",
+			  dev->name, flags, drv->flags, ret);
 		return ret;
+	}
 
 	ret = uclass_pre_remove_device(dev);
 	if (ret)
 		return ret;
 
-	/*
-	 * Remove the device if called with the "normal" remove flag set,
-	 * or if the remove flag matches any of the drivers remove flags
-	 */
-	if (drv->remove && flags_remove(flags, drv->flags)) {
+	if (drv->remove) {
 		ret = drv->remove(dev);
 		if (ret)
 			goto err_remove;
@@ -204,13 +227,11 @@
 	    dev != gd->cur_serial_dev)
 		dev_power_domain_off(dev);
 
-	if (flags_remove(flags, drv->flags)) {
-		device_free(dev);
+	device_free(dev);
 
-		dev_bic_flags(dev, DM_FLAG_ACTIVATED);
-	}
+	dev_bic_flags(dev, DM_FLAG_ACTIVATED);
 
-	return ret;
+	return 0;
 
 err_remove:
 	/* We can't put the children back */
diff --git a/include/dm/device-internal.h b/include/dm/device-internal.h
index 639bbd2..b513b68 100644
--- a/include/dm/device-internal.h
+++ b/include/dm/device-internal.h
@@ -123,7 +123,8 @@
  *
  * @dev: Pointer to device to remove
  * @flags: Flags for selective device removal (DM_REMOVE_...)
- * @return 0 if OK, -ve on error (an error here is normally a very bad thing)
+ * @return 0 if OK, -EKEYREJECTED if not removed due to flags, other -ve on
+ *	error (such an error here is normally a very bad thing)
  */
 #if CONFIG_IS_ENABLED(DM_DEVICE_REMOVE)
 int device_remove(struct udevice *dev, uint flags);
@@ -173,6 +174,12 @@
 
 /**
  * device_chld_remove() - Stop all device's children
+ *
+ * This continues through all children recursively stopping part-way through if
+ * an error occurs. Return values of -EKEYREJECTED are ignored and processing
+ * continues, since they just indicate that the child did not elect to be
+ * removed based on the value of @flags.
+ *
  * @dev:	The device whose children are to be removed
  * @drv:	The targeted driver
  * @flags:	Flag, if this functions is called in the pre-OS stage
diff --git a/test/dm/virtio.c b/test/dm/virtio.c
index ad35598..9a7e658 100644
--- a/test/dm/virtio.c
+++ b/test/dm/virtio.c
@@ -123,7 +123,9 @@
 
 	/* check the device can be successfully removed */
 	dev_or_flags(dev, DM_FLAG_ACTIVATED);
-	ut_assertok(device_remove(bus, DM_REMOVE_ACTIVE_ALL));
+	ut_asserteq(-EKEYREJECTED, device_remove(bus, DM_REMOVE_ACTIVE_ALL));
+
+	ut_asserteq(false, device_active(dev));
 
 	return 0;
 }