fix(el3-spm): improve direct messaging validation

Perform additional validation of the source and destination
IDs of direct messages.
Additionally track the sender of a direct request to allow
validating the target of the corresponding direct response.

Signed-off-by: Marc Bonnici <marc.bonnici@arm.com>
Change-Id: I8d39d53a02b8333246f1500c79ba04f149459c16
diff --git a/services/std_svc/spm/el3_spmc/spmc.h b/services/std_svc/spm/el3_spmc/spmc.h
index 61afee3..13875b9 100644
--- a/services/std_svc/spm/el3_spmc/spmc.h
+++ b/services/std_svc/spm/el3_spmc/spmc.h
@@ -131,6 +131,9 @@
 
 	/* Track the current runtime model of the SP. */
 	enum sp_runtime_model rt_model;
+
+	/* Track the source partition ID to validate a direct response. */
+	uint16_t dir_req_origin_id;
 };
 
 /*
diff --git a/services/std_svc/spm/el3_spmc/spmc_main.c b/services/std_svc/spm/el3_spmc/spmc_main.c
index 08e7218..ada6f45 100644
--- a/services/std_svc/spm/el3_spmc/spmc_main.c
+++ b/services/std_svc/spm/el3_spmc/spmc_main.c
@@ -260,6 +260,65 @@
 }
 
 /*******************************************************************************
+ * Helper function to validate the destination ID of a direct response.
+ ******************************************************************************/
+static bool direct_msg_validate_dst_id(uint16_t dst_id)
+{
+	struct secure_partition_desc *sp;
+
+	/* Check if we're targeting a normal world partition. */
+	if (ffa_is_normal_world_id(dst_id)) {
+		return true;
+	}
+
+	/* Or directed to the SPMC itself.*/
+	if (dst_id == FFA_SPMC_ID) {
+		return true;
+	}
+
+	/* Otherwise ensure the SP exists. */
+	sp = spmc_get_sp_ctx(dst_id);
+	if (sp != NULL) {
+		return true;
+	}
+
+	return false;
+}
+
+/*******************************************************************************
+ * Helper function to validate the response from a Logical Partition.
+ ******************************************************************************/
+static bool direct_msg_validate_lp_resp(uint16_t origin_id, uint16_t lp_id,
+					void *handle)
+{
+	/* Retrieve populated Direct Response Arguments. */
+	uint64_t x1 = SMC_GET_GP(handle, CTX_GPREG_X1);
+	uint64_t x2 = SMC_GET_GP(handle, CTX_GPREG_X2);
+	uint16_t src_id = ffa_endpoint_source(x1);
+	uint16_t dst_id = ffa_endpoint_destination(x1);
+
+	if (src_id != lp_id) {
+		ERROR("Invalid EL3 LP source ID (0x%x).\n", src_id);
+		return false;
+	}
+
+	/*
+	 * Check the destination ID is valid and ensure the LP is responding to
+	 * the original request.
+	 */
+	if ((!direct_msg_validate_dst_id(dst_id)) || (dst_id != origin_id)) {
+		ERROR("Invalid EL3 LP destination ID (0x%x).\n", dst_id);
+		return false;
+	}
+
+	if (!direct_msg_validate_arg2(x2)) {
+		ERROR("Invalid EL3 LP message encoding.\n");
+		return false;
+	}
+	return true;
+}
+
+/*******************************************************************************
  * Handle direct request messages and route to the appropriate destination.
  ******************************************************************************/
 static uint64_t direct_req_smc_handler(uint32_t smc_fid,
@@ -272,6 +331,7 @@
 				       void *handle,
 				       uint64_t flags)
 {
+	uint16_t src_id = ffa_endpoint_source(x1);
 	uint16_t dst_id = ffa_endpoint_destination(x1);
 	struct el3_lp_desc *el3_lp_descs;
 	struct secure_partition_desc *sp;
@@ -283,14 +343,29 @@
 					     FFA_ERROR_INVALID_PARAMETER);
 	}
 
+	/* Validate Sender is either the current SP or from the normal world. */
+	if ((secure_origin && src_id != spmc_get_current_sp_ctx()->sp_id) ||
+		(!secure_origin && !ffa_is_normal_world_id(src_id))) {
+		ERROR("Invalid direct request source ID (0x%x).\n", src_id);
+		return spmc_ffa_error_return(handle,
+					FFA_ERROR_INVALID_PARAMETER);
+	}
+
 	el3_lp_descs = get_el3_lp_array();
 
 	/* Check if the request is destined for a Logical Partition. */
 	for (unsigned int i = 0U; i < MAX_EL3_LP_DESCS_COUNT; i++) {
 		if (el3_lp_descs[i].sp_id == dst_id) {
-			return el3_lp_descs[i].direct_req(
-					smc_fid, secure_origin, x1, x2, x3, x4,
-					cookie, handle, flags);
+			uint64_t ret = el3_lp_descs[i].direct_req(
+						smc_fid, secure_origin, x1, x2,
+						x3, x4, cookie, handle, flags);
+			if (!direct_msg_validate_lp_resp(src_id, dst_id,
+							 handle)) {
+				panic();
+			}
+
+			/* Message checks out. */
+			return ret;
 		}
 	}
 
@@ -332,6 +407,7 @@
 	 */
 	sp->ec[idx].rt_state = RT_STATE_RUNNING;
 	sp->ec[idx].rt_model = RT_MODEL_DIR_REQ;
+	sp->ec[idx].dir_req_origin_id = src_id;
 	return spmc_smc_return(smc_fid, secure_origin, x1, x2, x3, x4,
 			       handle, cookie, flags, dst_id);
 }
@@ -370,7 +446,7 @@
 	 * Check that the response is either targeted to the Normal world or the
 	 * SPMC e.g. a PM response.
 	 */
-	if ((dst_id != FFA_SPMC_ID) && ffa_is_secure_world_id(dst_id)) {
+	if (!direct_msg_validate_dst_id(dst_id)) {
 		VERBOSE("Direct response to invalid partition ID (0x%x).\n",
 			dst_id);
 		return spmc_ffa_error_return(handle,
@@ -397,9 +473,18 @@
 		return spmc_ffa_error_return(handle, FFA_ERROR_DENIED);
 	}
 
+	if (sp->ec[idx].dir_req_origin_id != dst_id) {
+		WARN("Invalid direct resp partition ID 0x%x != 0x%x on core%u.\n",
+		     dst_id, sp->ec[idx].dir_req_origin_id, idx);
+		return spmc_ffa_error_return(handle, FFA_ERROR_DENIED);
+	}
+
 	/* Update the state of the SP execution context. */
 	sp->ec[idx].rt_state = RT_STATE_WAITING;
 
+	/* Clear the ongoing direct request ID. */
+	sp->ec[idx].dir_req_origin_id = INV_SP_ID;
+
 	/*
 	 * If the receiver is not the SPMC then forward the response to the
 	 * Normal world.
diff --git a/services/std_svc/spm/el3_spmc/spmc_pm.c b/services/std_svc/spm/el3_spmc/spmc_pm.c
index d25344c..c7e864f 100644
--- a/services/std_svc/spm/el3_spmc/spmc_pm.c
+++ b/services/std_svc/spm/el3_spmc/spmc_pm.c
@@ -83,6 +83,7 @@
 	/* Update the runtime model and state of the partition. */
 	ec->rt_model = RT_MODEL_INIT;
 	ec->rt_state = RT_STATE_RUNNING;
+	ec->dir_req_origin_id = INV_SP_ID;
 
 	INFO("SP (0x%x) init start on core%u.\n", sp->sp_id, linear_id);
 
@@ -132,6 +133,7 @@
 	/* Update the runtime model and state of the partition. */
 	ec->rt_model = RT_MODEL_DIR_REQ;
 	ec->rt_state = RT_STATE_RUNNING;
+	ec->dir_req_origin_id = FFA_SPMC_ID;
 
 	rc = spmc_sp_synchronous_entry(ec);
 	if (rc != 0ULL) {