Merge pull request #1194 from robertovargas-arm/io-fix

io: block: fix block_read/write may read/write overlap buffer
diff --git a/drivers/io/io_block.c b/drivers/io/io_block.c
index 128246f..8226554 100644
--- a/drivers/io/io_block.c
+++ b/drivers/io/io_block.c
@@ -167,15 +167,98 @@
 	return 0;
 }
 
+/*
+ * This function allows the caller to read any number of bytes
+ * from any position. It hides from the caller that the low-level
+ * driver can only read aligned blocks of data. For this reason
+ * we need to handle the use case where the first byte to be read is not
+ * aligned to the start of a block, the last byte to be read is also not
+ * aligned to the end of a block, and there are zero or more blocks-worth
+ * of data in between.
+ *
+ * In such a case we need to read more bytes than requested (i.e. full
+ * blocks) and strip out the leading bytes (aka skip) and the trailing
+ * bytes (aka padding). See the diagram below.
+ *
+ * cur->file_pos ------------
+ *                          |
+ * cur->base                |
+ *  |                       |
+ *  v                       v<----  length   ---->
+ *  --------------------------------------------------------------
+ * |           |         block#1    |        |   block#n          |
+ * |  block#0  |            +       |   ...  |     +              |
+ * |           | <- skip -> +       |        |     + <- padding ->|
+ *  ------------------------+----------------------+--------------
+ *             ^                                                  ^
+ *             |                                                  |
+ *             v    iteration#1                iteration#n        v
+ *              --------------------------------------------------
+ *             |                    |        |                    |
+ *             |<----  request ---->|  ...   |<----- request ---->|
+ *             |                    |        |                    |
+ *              --------------------------------------------------
+ *            /                   /          |                    |
+ *           /                   /           |                    |
+ *          /                   /            |                    |
+ *         /                   /             |                    |
+ *        /                   /              |                    |
+ *       /                   /               |                    |
+ *      /                   /                |                    |
+ *     /                   /                 |                    |
+ *    /                   /                  |                    |
+ *   /                   /                   |                    |
+ *  <---- request ------>                    <------ request  ----->
+ *  ---------------------                    -----------------------
+ *  |        |          |                    |          |           |
+ *  |<-skip->|<-nbytes->|           -------->|<-nbytes->|<-padding->|
+ *  |        |          |           |        |          |           |
+ *  ---------------------           |        -----------------------
+ *  ^        \           \          |        |          |
+ *  |         \           \         |        |          |
+ *  |          \           \        |        |          |
+ *  buf->offset \           \   buf->offset  |          |
+ *               \           \               |          |
+ *                \           \              |          |
+ *                 \           \             |          |
+ *                  \           \            |          |
+ *                   \           \           |          |
+ *                    \           \          |          |
+ *                     \           \         |          |
+ *                      --------------------------------
+ *                      |           |        |         |
+ * buffer-------------->|           | ...    |         |
+ *                      |           |        |         |
+ *                      --------------------------------
+ *                      <-count#1->|                   |
+ *                      <----------  count#n   -------->
+ *                      <----------  length  ---------->
+ *
+ * Additionally, the IO driver has an underlying buffer that is at least
+ * one block in size and may be large enough to read several blocks at once.
+ */
 static int block_read(io_entity_t *entity, uintptr_t buffer, size_t length,
 		      size_t *length_read)
 {
 	block_dev_state_t *cur;
 	io_block_spec_t *buf;
 	io_block_ops_t *ops;
-	size_t aligned_length, skip, count, left, padding, block_size;
 	int lba;
-	int buffer_not_aligned;
+	size_t block_size, left;
+	size_t nbytes;  /* number of bytes read in one iteration */
+	size_t request; /* number of requested bytes in one iteration */
+	size_t count;   /* number of bytes already read */
+	/*
+	 * number of leading bytes from start of the block
+	 * to the first byte to be read
+	 */
+	size_t skip;
+
+	/*
+	 * number of trailing bytes between the last byte
+	 * to be read and the end of the block
+	 */
+	size_t padding;
 
 	assert(entity->info != (uintptr_t)NULL);
 	cur = (block_dev_state_t *)entity->info;
@@ -186,102 +269,107 @@
 	       (length > 0) &&
 	       (ops->read != 0));
 
-	if ((buffer & (block_size - 1)) != 0) {
+	/*
+	 * We don't know the number of bytes that we are going
+	 * to read in every iteration, because it will depend
+	 * on the low level driver.
+	 */
+	count = 0;
+	for (left = length; left > 0; left -= nbytes) {
 		/*
-		 * buffer isn't aligned with block size.
-		 * Block device always relies on DMA operation.
-		 * It's better to make the buffer as block size aligned.
+		 * We must only request operations aligned to the block
+		 * size. Therefore if file_pos is not block-aligned,
+		 * we have to request the operation to start at the
+		 * previous block boundary and skip the leading bytes. And
+		 * similarly, the number of bytes requested must be a
+		 * block size multiple
 		 */
-		buffer_not_aligned = 1;
-	} else {
-		buffer_not_aligned = 0;
-	}
+		skip = cur->file_pos & (block_size - 1);
 
-	skip = cur->file_pos % block_size;
-	aligned_length = ((skip + length) + (block_size - 1)) &
-			 ~(block_size - 1);
-	padding = aligned_length - (skip + length);
-	left = aligned_length;
-	do {
+		/*
+		 * Calculate the block number containing file_pos
+		 * - e.g. block 3.
+		 */
 		lba = (cur->file_pos + cur->base) / block_size;
-		if (left >= buf->length) {
+
+		if (skip + left > buf->length) {
 			/*
-			 * Since left is larger, it's impossible to padding.
-			 *
-			 * If buffer isn't aligned, we need to use aligned
-			 * buffer instead.
+			 * The underlying read buffer is too small to
+			 * read all the required data - limit to just
+			 * fill the buffer, and then read again.
 			 */
-			if (skip || buffer_not_aligned) {
-				/*
-				 * The beginning address (file_pos) isn't
-				 * aligned with block size, we need to use
-				 * block buffer to read block. Since block
-				 * device is always relied on DMA operation.
-				 */
-				count = ops->read(lba, buf->offset,
-						  buf->length);
-			} else {
-				count = ops->read(lba, buffer, buf->length);
-			}
-			assert(count == buf->length);
-			cur->file_pos += count - skip;
-			if (skip || buffer_not_aligned) {
-				/*
-				 * Since there's not aligned block size caused
-				 * by skip or not aligned buffer, block buffer
-				 * is used to store data.
-				 */
-				memcpy((void *)buffer,
-				       (void *)(buf->offset + skip),
-				       count - skip);
-			}
-			left = left - (count - skip);
+			request = buf->length;
 		} else {
-			if (skip || padding || buffer_not_aligned) {
-				/*
-				 * The beginning address (file_pos) isn't
-				 * aligned with block size, we have to read
-				 * full block by block buffer instead.
-				 * The size isn't aligned with block size.
-				 * Use block buffer to avoid overflow.
-				 *
-				 * If buffer isn't aligned, use block buffer
-				 * to avoid DMA error.
-				 */
-				count = ops->read(lba, buf->offset, left);
-			} else
-				count = ops->read(lba, buffer, left);
-			assert(count == left);
-			left = left - (skip + padding);
-			cur->file_pos += left;
-			if (skip || padding || buffer_not_aligned) {
-				/*
-				 * Since there's not aligned block size or
-				 * buffer, block buffer is used to store data.
-				 */
-				memcpy((void *)buffer,
-				       (void *)(buf->offset + skip),
-				       left);
-			}
-			/* It's already the last block operation */
-			left = 0;
+			/*
+			 * The underlying read buffer is big enough to
+			 * read all the required data. Calculate the
+			 * number of bytes to read to align with the
+			 * block size.
+			 */
+			request = skip + left;
+			request = (request + (block_size - 1)) & ~(block_size - 1);
+		}
+		request = ops->read(lba, buf->offset, request);
+
+		if (request <= skip) {
+			/*
+			 * We couldn't read enough bytes to jump over
+			 * the skip bytes, so we should have to read
+			 * again the same block, thus generating
+			 * the same error.
+			 */
+			return -EIO;
 		}
-		skip = cur->file_pos % block_size;
-	} while (left > 0);
-	*length_read = length;
+
+		/*
+		 * Need to remove skip and padding bytes,if any, from
+		 * the read data when copying to the user buffer.
+		 */
+		nbytes = request - skip;
+		padding = (nbytes > left) ? nbytes - left : 0;
+		nbytes -= padding;
+
+		memcpy((void *)(buffer + count),
+		       (void *)(buf->offset + skip),
+		       nbytes);
+
+		cur->file_pos += nbytes;
+		count += nbytes;
+	}
+	assert(count == length);
+	*length_read = count;
 
 	return 0;
 }
 
+/*
+ * This function allows the caller to write any number of bytes
+ * at any position. It hides from the caller that the low-level
+ * driver can only write aligned blocks of data.
+ * See the comments for block_read for more details.
+ */
 static int block_write(io_entity_t *entity, const uintptr_t buffer,
 		       size_t length, size_t *length_written)
 {
 	block_dev_state_t *cur;
 	io_block_spec_t *buf;
 	io_block_ops_t *ops;
-	size_t aligned_length, skip, count, left, padding, block_size;
 	int lba;
-	int buffer_not_aligned;
+	size_t block_size, left;
+	size_t nbytes;  /* number of bytes read in one iteration */
+	size_t request; /* number of requested bytes in one iteration */
+	size_t count;   /* number of bytes already read */
+	/*
+	 * number of leading bytes from start of the block
+	 * to the first byte to be read
+	 */
+	size_t skip;
+
+	/*
+	 * number of trailing bytes between the last byte
+	 * to be read and the end of the block
+	 */
+	size_t padding;
 
 	assert(entity->info != (uintptr_t)NULL);
 	cur = (block_dev_state_t *)entity->info;
@@ -293,75 +381,107 @@
 	       (ops->read != 0) &&
 	       (ops->write != 0));
 
-	if ((buffer & (block_size - 1)) != 0) {
+	/*
+	 * We don't know the number of bytes that we are going
+	 * to write in every iteration, because it will depend
+	 * on the low level driver.
+	 */
+	count = 0;
+	for (left = length; left > 0; left -= nbytes) {
 		/*
-		 * buffer isn't aligned with block size.
-		 * Block device always relies on DMA operation.
-		 * It's better to make the buffer as block size aligned.
+		 * We must only request operations aligned to the block
+		 * size. Therefore if file_pos is not block-aligned,
+		 * we have to request the operation to start at the
+		 * previous block boundary and skip the leading bytes. And
+		 * similarly, the number of bytes requested must be a
+		 * block size multiple
 		 */
-		buffer_not_aligned = 1;
-	} else {
-		buffer_not_aligned = 0;
-	}
+		skip = cur->file_pos & (block_size - 1);
 
-	skip = cur->file_pos % block_size;
-	aligned_length = ((skip + length) + (block_size - 1)) &
-			 ~(block_size - 1);
-	padding = aligned_length - (skip + length);
-	left = aligned_length;
-	do {
+		/*
+		 * Calculate the block number containing file_pos
+		 * - e.g. block 3.
+		 */
 		lba = (cur->file_pos + cur->base) / block_size;
-		if (left >= buf->length) {
-			/* Since left is larger, it's impossible to padding. */
-			if (skip || buffer_not_aligned) {
-				/*
-				 * The beginning address (file_pos) isn't
-				 * aligned with block size or buffer isn't
-				 * aligned, we need to use block buffer to
-				 * write block.
-				 */
-				count = ops->read(lba, buf->offset,
-						  buf->length);
-				assert(count == buf->length);
-				memcpy((void *)(buf->offset + skip),
-				       (void *)buffer,
-				       count - skip);
-				count = ops->write(lba, buf->offset,
-						   buf->length);
-			} else
-				count = ops->write(lba, buffer, buf->length);
-			assert(count == buf->length);
-			cur->file_pos += count - skip;
-			left = left - (count - skip);
+
+		if (skip + left > buf->length) {
+			/*
+			 * The underlying read buffer is too small to
+			 * read all the required data - limit to just
+			 * fill the buffer, and then read again.
+			 */
+			request = buf->length;
 		} else {
-			if (skip || padding || buffer_not_aligned) {
+			/*
+			 * The underlying read buffer is big enough to
+			 * read all the required data. Calculate the
+			 * number of bytes to read to align with the
+			 * block size.
+			 */
+			request = skip + left;
+			request = (request + (block_size - 1)) & ~(block_size - 1);
+		}
+
+		/*
+		 * The number of bytes that we are going to write
+		 * from the user buffer will depend of the size
+		 * of the current request.
+		 */
+		nbytes = request - skip;
+		padding = (nbytes > left) ? nbytes - left : 0;
+		nbytes -= padding;
+
+		/*
+		 * If we have skip or padding bytes then we have to preserve
+		 * some content and it means that we have to read before
+		 * writing
+		 */
+		if (skip > 0 || padding > 0) {
+			request = ops->read(lba, buf->offset, request);
+			/*
+			 * The read may return size less than
+			 * requested. Round down to the nearest block
+			 * boundary
+			 */
+			request &= ~(block_size-1);
+			if (request <= skip) {
 				/*
-				 * The beginning address (file_pos) isn't
-				 * aligned with block size, we need to avoid
-				 * poluate data in the beginning. Reading and
-				 * skipping the beginning is the only way.
-				 * The size isn't aligned with block size.
-				 * Use block buffer to avoid overflow.
-				 *
-				 * If buffer isn't aligned, use block buffer
-				 * to avoid DMA error.
+				 * We couldn't read enough bytes to jump over
+				 * the skip bytes, so we should have to read
+				 * again the same block, thus generating
+				 * the same error.
 				 */
-				count = ops->read(lba, buf->offset, left);
-				assert(count == left);
-				memcpy((void *)(buf->offset + skip),
-				       (void *)buffer,
-				       left - skip - padding);
-				count = ops->write(lba, buf->offset, left);
-			} else
-				count = ops->write(lba, buffer, left);
-			assert(count == left);
-			cur->file_pos += left - (skip + padding);
-			/* It's already the last block operation */
-			left = 0;
+				return -EIO;
+			}
+			nbytes = request - skip;
+			padding = (nbytes > left) ? nbytes - left : 0;
+			nbytes -= padding;
 		}
-		skip = cur->file_pos % block_size;
-	} while (left > 0);
-	*length_written = length;
+
+		memcpy((void *)(buf->offset + skip),
+		       (void *)(buffer + count),
+		       nbytes);
+
+		request = ops->write(lba, buf->offset, request);
+		if (request <= skip)
+			return -EIO;
+
+		/*
+		 * And the previous write operation may modify the size
+		 * of the request, so again, we have to calculate the
+		 * number of bytes that we consumed from the user
+		 * buffer
+		 */
+		nbytes = request - skip;
+		padding = (nbytes > left) ? nbytes - left : 0;
+		nbytes -= padding;
+
+		cur->file_pos += nbytes;
+		count += nbytes;
+	}
+	assert(count == length);
+	*length_written = count;
+
 	return 0;
 }