feat(spmc/mem): support multiple endpoints in memory transactions

Enable FFA_MEM_LEND and FFA_MEM_SHARE transactions to support multiple
borrowers and add the appropriate validation. Since we currently
only support a single S-EL1 partition, this functionality is to
support the use case where a VM shares or lends memory to one or
more VMs in the normal world as part of the same transaction to
the SP.

Signed-off-by: Marc Bonnici <marc.bonnici@arm.com>
Change-Id: Ia12c4357e9d015cb5f9b38e518b7a25b1ea2e30e
diff --git a/include/services/el3_spmc_ffa_memory.h b/include/services/el3_spmc_ffa_memory.h
index d4738a1..6c86732 100644
--- a/include/services/el3_spmc_ffa_memory.h
+++ b/include/services/el3_spmc_ffa_memory.h
@@ -196,8 +196,7 @@
  * @reserved_24_27:
  *         Reserved bytes 24-27. Must be 0.
  * @emad_count:
- *         Number of entries in @emad. Must be 1 in current implementation.
- *         FFA spec allows more entries.
+ *         Number of entries in @emad.
  * @emad:
  *         Endpoint memory access descriptor array (see @struct ffa_emad_v1_0).
  */
diff --git a/services/std_svc/spm/el3_spmc/spmc.h b/services/std_svc/spm/el3_spmc/spmc.h
index 22a745e..18b71bb 100644
--- a/services/std_svc/spm/el3_spmc/spmc.h
+++ b/services/std_svc/spm/el3_spmc/spmc.h
@@ -273,4 +273,10 @@
  */
 struct mailbox *spmc_get_mbox_desc(bool secure_origin);
 
+/*
+ * Helper function to obtain the context of an SP with a given partition ID.
+ */
+struct secure_partition_desc *spmc_get_sp_ctx(uint16_t id);
+
+
 #endif /* SPMC_H */
diff --git a/services/std_svc/spm/el3_spmc/spmc_shared_mem.c b/services/std_svc/spm/el3_spmc/spmc_shared_mem.c
index b9ca2fe..227d7cf 100644
--- a/services/std_svc/spm/el3_spmc/spmc_shared_mem.c
+++ b/services/std_svc/spm/el3_spmc/spmc_shared_mem.c
@@ -166,18 +166,27 @@
  * spmc_shmem_check_obj - Check that counts in descriptor match overall size.
  * @obj:    Object containing ffa_memory_region_descriptor.
  *
- * Return: 0 if object is valid, -EINVAL if memory region attributes count is
- * not 1, -EINVAL if constituent_memory_region_descriptor offset or count is
- * invalid.
+ * Return: 0 if object is valid, -EINVAL if constituent_memory_region_descriptor
+ * offset or count is invalid.
  */
 static int spmc_shmem_check_obj(struct spmc_shmem_obj *obj)
 {
-	if (obj->desc.emad_count != 1) {
-		WARN("%s: unsupported attribute desc count %u != 1\n",
+	if (obj->desc.emad_count == 0U) {
+		WARN("%s: unsupported attribute desc count %u.\n",
 		     __func__, obj->desc.emad_count);
 		return -EINVAL;
 	}
 
+	/*
+	 * Ensure the emad array lies within the bounds of the descriptor by
+	 * checking the address of the element past the end of the array.
+	 */
+	if ((uintptr_t) &obj->desc.emad[obj->desc.emad_count] >
+	    (uintptr_t)((uint8_t *) &obj->desc + obj->desc_size)) {
+		WARN("Invalid emad access.\n");
+		return -EINVAL;
+	}
+
 	for (size_t emad_num = 0; emad_num < obj->desc.emad_count; emad_num++) {
 		size_t size;
 		size_t count;
@@ -330,6 +339,38 @@
 			 (uint32_t)obj->desc.sender_id << 16, 0, 0, 0);
 	}
 
+	/* The full descriptor has been received, perform any final checks. */
+
+	/*
+	 * If a partition ID resides in the secure world validate that the
+	 * partition ID is for a known partition. Ignore any partition ID
+	 * belonging to the normal world as it is assumed the Hypervisor will
+	 * have validated these.
+	 */
+	for (size_t i = 0; i < obj->desc.emad_count; i++) {
+		ffa_endpoint_id16_t ep_id = obj->desc.emad[i].mapd.endpoint_id;
+
+		if (ffa_is_secure_world_id(ep_id)) {
+			if (spmc_get_sp_ctx(ep_id) == NULL) {
+				WARN("%s: Invalid receiver id 0x%x\n",
+				     __func__, ep_id);
+				ret = FFA_ERROR_INVALID_PARAMETER;
+				goto err_bad_desc;
+			}
+		}
+	}
+
+	/* Ensure partition IDs are not duplicated. */
+	for (size_t i = 0; i < obj->desc.emad_count; i++) {
+		for (size_t j = i + 1; j < obj->desc.emad_count; j++) {
+			if (obj->desc.emad[i].mapd.endpoint_id ==
+				obj->desc.emad[j].mapd.endpoint_id) {
+				ret = FFA_ERROR_INVALID_PARAMETER;
+				goto err_bad_desc;
+			}
+		}
+	}
+
 	SMC_RET8(smc_handle, FFA_SUCCESS_SMC32, 0, handle_low, handle_high, 0,
 		 0, 0, 0);
 
@@ -565,15 +606,10 @@
 		goto err_unlock_mailbox;
 	}
 
-	/*
-	 * Ensure endpoint count is 1, additional receivers not currently
-	 * supported.
-	 */
-	if (req->emad_count != 1U) {
-		WARN("%s: unsupported retrieve descriptor count: %u\n",
-		     __func__, req->emad_count);
-		ret = FFA_ERROR_INVALID_PARAMETER;
-		goto err_unlock_mailbox;
+	if (req->emad_count == 0U) {
+		WARN("%s: unsupported attribute desc count %u.\n",
+		     __func__, obj->desc.emad_count);
+		return -EINVAL;
 	}
 
 	if (total_length < sizeof(*req)) {
@@ -612,6 +648,13 @@
 		goto err_unlock_all;
 	}
 
+	if (req->emad_count != 0U && req->emad_count != obj->desc.emad_count) {
+		WARN("%s: mistmatch of endpoint counts %u != %u\n",
+		     __func__, req->emad_count, obj->desc.emad_count);
+		ret = FFA_ERROR_INVALID_PARAMETER;
+		goto err_unlock_all;
+	}
+
 	if (req->flags != 0U) {
 		if ((req->flags & FFA_MTD_FLAG_TYPE_MASK) !=
 		    (obj->desc.flags & FFA_MTD_FLAG_TYPE_MASK)) {
@@ -637,15 +680,39 @@
 		}
 	}
 
-	/* TODO: support more than one endpoint ids. */
-	if (req->emad_count != 0U &&
-	    req->emad[0].mapd.endpoint_id !=
-	    obj->desc.emad[0].mapd.endpoint_id) {
-		WARN("%s: wrong receiver id 0x%x != 0x%x\n",
-		     __func__, req->emad[0].mapd.endpoint_id,
-		       obj->desc.emad[0].mapd.endpoint_id);
-		ret = FFA_ERROR_INVALID_PARAMETER;
-		goto err_unlock_all;
+	/*
+	 * Ensure the emad array lies within the bounds of the descriptor by
+	 * checking the address of the element past the end of the array.
+	 */
+	if ((uintptr_t) &req->emad[req->emad_count] >
+	    (uintptr_t)((uint8_t *) &req + total_length)) {
+		WARN("Invalid emad access.\n");
+		return -EINVAL;
+	}
+
+	/*
+	 * Validate all the endpoints match in the case of multiple
+	 * borrowers. We don't mandate that the order of the borrowers
+	 * must match in the descriptors therefore check to see if the
+	 * endpoints match in any order.
+	 */
+	for (size_t i = 0; i < req->emad_count; i++) {
+		bool found = false;
+
+		for (size_t j = 0; j < obj->desc.emad_count; j++) {
+			if (req->emad[i].mapd.endpoint_id ==
+			    obj->desc.emad[j].mapd.endpoint_id) {
+				found = true;
+				break;
+			}
+		}
+
+		if (!found) {
+			WARN("%s: invalid receiver id (0x%x).\n",
+			     __func__, req->emad[i].mapd.endpoint_id);
+			ret = FFA_ERROR_INVALID_PARAMETER;
+			goto err_unlock_all;
+		}
 	}
 
 	mbox->state = MAILBOX_STATE_FULL;
@@ -822,6 +889,12 @@
 		goto err_unlock_mailbox;
 	}
 
+	if (req->endpoint_count == 0) {
+		WARN("%s: endpoint count cannot be 0.\n", __func__);
+		ret = FFA_ERROR_INVALID_PARAMETER;
+		goto err_unlock_mailbox;
+	}
+
 	spin_lock(&spmc_shmem_obj_state.lock);
 
 	obj = spmc_shmem_obj_lookup(&spmc_shmem_obj_state, req->handle);
@@ -831,16 +904,32 @@
 	}
 
 	if (obj->desc.emad_count != req->endpoint_count) {
+		WARN("%s: mismatch of endpoint count %u != %u\n", __func__,
+		     obj->desc.emad_count, req->endpoint_count);
 		ret = FFA_ERROR_INVALID_PARAMETER;
 		goto err_unlock_all;
 	}
+
+	/* Validate requested endpoint IDs match descriptor. */
 	for (size_t i = 0; i < req->endpoint_count; i++) {
-		if (req->endpoint_array[i] !=
-		    obj->desc.emad[i].mapd.endpoint_id) {
+		bool found = false;
+
+		for (unsigned int j = 0; j < obj->desc.emad_count; j++) {
+			if (req->endpoint_array[i] ==
+			    obj->desc.emad[j].mapd.endpoint_id) {
+				found = true;
+				break;
+			}
+		}
+
+		if (!found) {
+			WARN("%s: Invalid endpoint ID (0x%x).\n",
+			     __func__, req->endpoint_array[i]);
 			ret = FFA_ERROR_INVALID_PARAMETER;
 			goto err_unlock_all;
 		}
 	}
+
 	if (obj->in_use == 0U) {
 		ret = FFA_ERROR_INVALID_PARAMETER;
 		goto err_unlock_all;