usb: xhci: Implement DMA mapping

An XHCI controller that sits behind an IOMMU needs to map and unmap
its memory buffers to do DMA.  Implement this by inroducing new
xhci_dma_map() and xhci_dma_unmap() helper functions.  The
xhci_dma_map() function replaces the existing xhci_virt_to_bus()
function in the sense that it returns the bus address in the case
of simple address translation in the absence of an IOMMU.  The
xhci_bus_to_virt() function is eliminated by storing the CPU
address of the allocated scratchpad memory in struct xhci_ctrl.

Signed-off-by: Mark Kettenis <kettenis@openbsd.org>
Reviewed-by: Marek Vasut <marex@denx.de>
diff --git a/drivers/usb/host/xhci-mem.c b/drivers/usb/host/xhci-mem.c
index 21cd03b..72b7530 100644
--- a/drivers/usb/host/xhci-mem.c
+++ b/drivers/usb/host/xhci-mem.c
@@ -64,8 +64,9 @@
  * @param ptr	pointer to "segement" to be freed
  * Return: none
  */
-static void xhci_segment_free(struct xhci_segment *seg)
+static void xhci_segment_free(struct xhci_ctrl *ctrl, struct xhci_segment *seg)
 {
+	xhci_dma_unmap(ctrl, seg->dma, SEGMENT_SIZE);
 	free(seg->trbs);
 	seg->trbs = NULL;
 
@@ -78,7 +79,7 @@
  * @param ptr	pointer to "ring" to be freed
  * Return: none
  */
-static void xhci_ring_free(struct xhci_ring *ring)
+static void xhci_ring_free(struct xhci_ctrl *ctrl, struct xhci_ring *ring)
 {
 	struct xhci_segment *seg;
 	struct xhci_segment *first_seg;
@@ -89,10 +90,10 @@
 	seg = first_seg->next;
 	while (seg != first_seg) {
 		struct xhci_segment *next = seg->next;
-		xhci_segment_free(seg);
+		xhci_segment_free(ctrl, seg);
 		seg = next;
 	}
-	xhci_segment_free(first_seg);
+	xhci_segment_free(ctrl, first_seg);
 
 	free(ring);
 }
@@ -105,12 +106,20 @@
  */
 static void xhci_scratchpad_free(struct xhci_ctrl *ctrl)
 {
+	struct xhci_hccr *hccr = ctrl->hccr;
+	int num_sp;
+
 	if (!ctrl->scratchpad)
 		return;
 
+	num_sp = HCS_MAX_SCRATCHPAD(xhci_readl(&hccr->cr_hcsparams2));
+	xhci_dma_unmap(ctrl, ctrl->scratchpad->sp_array[0],
+		       num_sp * ctrl->page_size);
+	xhci_dma_unmap(ctrl, ctrl->dcbaa->dev_context_ptrs[0],
+		       num_sp * sizeof(u64));
 	ctrl->dcbaa->dev_context_ptrs[0] = 0;
 
-	free(xhci_bus_to_virt(ctrl, le64_to_cpu(ctrl->scratchpad->sp_array[0])));
+	free(ctrl->scratchpad->scratchpad);
 	free(ctrl->scratchpad->sp_array);
 	free(ctrl->scratchpad);
 	ctrl->scratchpad = NULL;
@@ -122,8 +131,10 @@
  * @param ptr	pointer to "xhci_container_ctx" to be freed
  * Return: none
  */
-static void xhci_free_container_ctx(struct xhci_container_ctx *ctx)
+static void xhci_free_container_ctx(struct xhci_ctrl *ctrl,
+				    struct xhci_container_ctx *ctx)
 {
+	xhci_dma_unmap(ctrl, ctx->dma, ctx->size);
 	free(ctx->bytes);
 	free(ctx);
 }
@@ -153,12 +164,12 @@
 
 		for (i = 0; i < 31; ++i)
 			if (virt_dev->eps[i].ring)
-				xhci_ring_free(virt_dev->eps[i].ring);
+				xhci_ring_free(ctrl, virt_dev->eps[i].ring);
 
 		if (virt_dev->in_ctx)
-			xhci_free_container_ctx(virt_dev->in_ctx);
+			xhci_free_container_ctx(ctrl, virt_dev->in_ctx);
 		if (virt_dev->out_ctx)
-			xhci_free_container_ctx(virt_dev->out_ctx);
+			xhci_free_container_ctx(ctrl, virt_dev->out_ctx);
 
 		free(virt_dev);
 		/* make sure we are pointing to NULL */
@@ -174,11 +185,15 @@
  */
 void xhci_cleanup(struct xhci_ctrl *ctrl)
 {
-	xhci_ring_free(ctrl->event_ring);
-	xhci_ring_free(ctrl->cmd_ring);
+	xhci_ring_free(ctrl, ctrl->event_ring);
+	xhci_ring_free(ctrl, ctrl->cmd_ring);
 	xhci_scratchpad_free(ctrl);
 	xhci_free_virt_devices(ctrl);
+	xhci_dma_unmap(ctrl, ctrl->erst.erst_dma_addr,
+		       sizeof(struct xhci_erst_entry) * ERST_NUM_SEGS);
 	free(ctrl->erst.entries);
+	xhci_dma_unmap(ctrl, ctrl->dcbaa->dma,
+		       sizeof(struct xhci_device_context_array));
 	free(ctrl->dcbaa);
 	memset(ctrl, '\0', sizeof(struct xhci_ctrl));
 }
@@ -218,15 +233,13 @@
 			       struct xhci_segment *next, bool link_trbs)
 {
 	u32 val;
-	u64 val_64 = 0;
 
 	if (!prev || !next)
 		return;
 	prev->next = next;
 	if (link_trbs) {
-		val_64 = xhci_virt_to_bus(ctrl, next->trbs);
 		prev->trbs[TRBS_PER_SEGMENT-1].link.segment_ptr =
-			cpu_to_le64(val_64);
+			cpu_to_le64(next->dma);
 
 		/*
 		 * Set the last TRB in the segment to
@@ -273,7 +286,7 @@
  * @param	none
  * Return: pointer to the newly allocated SEGMENT
  */
-static struct xhci_segment *xhci_segment_alloc(void)
+static struct xhci_segment *xhci_segment_alloc(struct xhci_ctrl *ctrl)
 {
 	struct xhci_segment *seg;
 
@@ -281,6 +294,7 @@
 	BUG_ON(!seg);
 
 	seg->trbs = xhci_malloc(SEGMENT_SIZE);
+	seg->dma = xhci_dma_map(ctrl, seg->trbs, SEGMENT_SIZE);
 
 	seg->next = NULL;
 
@@ -314,7 +328,7 @@
 	if (num_segs == 0)
 		return ring;
 
-	ring->first_seg = xhci_segment_alloc();
+	ring->first_seg = xhci_segment_alloc(ctrl);
 	BUG_ON(!ring->first_seg);
 
 	num_segs--;
@@ -323,7 +337,7 @@
 	while (num_segs > 0) {
 		struct xhci_segment *next;
 
-		next = xhci_segment_alloc();
+		next = xhci_segment_alloc(ctrl);
 		BUG_ON(!next);
 
 		xhci_link_segments(ctrl, prev, next, link_trbs);
@@ -372,7 +386,8 @@
 	if (!scratchpad->sp_array)
 		goto fail_sp2;
 
-	val_64 = xhci_virt_to_bus(ctrl, scratchpad->sp_array);
+	val_64 = xhci_dma_map(ctrl, scratchpad->sp_array,
+			      num_sp * sizeof(u64));
 	ctrl->dcbaa->dev_context_ptrs[0] = cpu_to_le64(val_64);
 
 	xhci_flush_cache((uintptr_t)&ctrl->dcbaa->dev_context_ptrs[0],
@@ -386,16 +401,18 @@
 	}
 	BUG_ON(i == 16);
 
-	page_size = 1 << (i + 12);
-	buf = memalign(page_size, num_sp * page_size);
+	ctrl->page_size = 1 << (i + 12);
+	buf = memalign(ctrl->page_size, num_sp * ctrl->page_size);
 	if (!buf)
 		goto fail_sp3;
-	memset(buf, '\0', num_sp * page_size);
-	xhci_flush_cache((uintptr_t)buf, num_sp * page_size);
+	memset(buf, '\0', num_sp * ctrl->page_size);
+	xhci_flush_cache((uintptr_t)buf, num_sp * ctrl->page_size);
 
+	scratchpad->scratchpad = buf;
+	val_64 = xhci_dma_map(ctrl, buf, num_sp * ctrl->page_size);
 	for (i = 0; i < num_sp; i++) {
-		val_64 = xhci_virt_to_bus(ctrl, buf + i * page_size);
 		scratchpad->sp_array[i] = cpu_to_le64(val_64);
+		val_64 += ctrl->page_size;
 	}
 
 	xhci_flush_cache((uintptr_t)scratchpad->sp_array,
@@ -437,6 +454,7 @@
 		ctx->size += CTX_SIZE(xhci_readl(&ctrl->hccr->cr_hccparams));
 
 	ctx->bytes = xhci_malloc(ctx->size);
+	ctx->dma = xhci_dma_map(ctrl, ctx->bytes, ctx->size);
 
 	return ctx;
 }
@@ -487,7 +505,7 @@
 	/* Allocate endpoint 0 ring */
 	virt_dev->eps[0].ring = xhci_ring_alloc(ctrl, 1, true);
 
-	byte_64 = xhci_virt_to_bus(ctrl, virt_dev->out_ctx->bytes);
+	byte_64 = virt_dev->out_ctx->dma;
 
 	/* Point to output device context in dcbaa. */
 	ctrl->dcbaa->dev_context_ptrs[slot_id] = cpu_to_le64(byte_64);
@@ -523,15 +541,16 @@
 		return -ENOMEM;
 	}
 
-	val_64 = xhci_virt_to_bus(ctrl, ctrl->dcbaa);
+	ctrl->dcbaa->dma = xhci_dma_map(ctrl, ctrl->dcbaa,
+				sizeof(struct xhci_device_context_array));
 	/* Set the pointer in DCBAA register */
-	xhci_writeq(&hcor->or_dcbaap, val_64);
+	xhci_writeq(&hcor->or_dcbaap, ctrl->dcbaa->dma);
 
 	/* Command ring control pointer register initialization */
 	ctrl->cmd_ring = xhci_ring_alloc(ctrl, 1, true);
 
 	/* Set the address in the Command Ring Control register */
-	trb_64 = xhci_virt_to_bus(ctrl, ctrl->cmd_ring->first_seg->trbs);
+	trb_64 = ctrl->cmd_ring->first_seg->dma;
 	val_64 = xhci_readq(&hcor->or_crcr);
 	val_64 = (val_64 & (u64) CMD_RING_RSVD_BITS) |
 		(trb_64 & (u64) ~CMD_RING_RSVD_BITS) |
@@ -555,6 +574,8 @@
 	ctrl->event_ring = xhci_ring_alloc(ctrl, ERST_NUM_SEGS, false);
 	ctrl->erst.entries = xhci_malloc(sizeof(struct xhci_erst_entry) *
 					 ERST_NUM_SEGS);
+	ctrl->erst.erst_dma_addr = xhci_dma_map(ctrl, ctrl->erst.entries,
+			sizeof(struct xhci_erst_entry) * ERST_NUM_SEGS);
 
 	ctrl->erst.num_entries = ERST_NUM_SEGS;
 
@@ -562,7 +583,7 @@
 			val < ERST_NUM_SEGS;
 			val++) {
 		struct xhci_erst_entry *entry = &ctrl->erst.entries[val];
-		trb_64 = xhci_virt_to_bus(ctrl, seg->trbs);
+		trb_64 = seg->dma;
 		entry->seg_addr = cpu_to_le64(trb_64);
 		entry->seg_size = cpu_to_le32(TRBS_PER_SEGMENT);
 		entry->rsvd = 0;
@@ -571,7 +592,8 @@
 	xhci_flush_cache((uintptr_t)ctrl->erst.entries,
 			 ERST_NUM_SEGS * sizeof(struct xhci_erst_entry));
 
-	deq = xhci_virt_to_bus(ctrl, ctrl->event_ring->dequeue);
+	deq = xhci_trb_virt_to_dma(ctrl->event_ring->deq_seg,
+				   ctrl->event_ring->dequeue);
 
 	/* Update HC event ring dequeue pointer */
 	xhci_writeq(&ctrl->ir_set->erst_dequeue,
@@ -586,7 +608,7 @@
 	/* this is the event ring segment table pointer */
 	val_64 = xhci_readq(&ctrl->ir_set->erst_base);
 	val_64 &= ERST_PTR_MASK;
-	val_64 |= xhci_virt_to_bus(ctrl, ctrl->erst.entries) & ~ERST_PTR_MASK;
+	val_64 |= ctrl->erst.erst_dma_addr & ~ERST_PTR_MASK;
 
 	xhci_writeq(&ctrl->ir_set->erst_base, val_64);
 
@@ -849,7 +871,7 @@
 	/* EP 0 can handle "burst" sizes of 1, so Max Burst Size field is 0 */
 	ep0_ctx->ep_info2 |= cpu_to_le32(MAX_BURST(0) | ERROR_COUNT(3));
 
-	trb_64 = xhci_virt_to_bus(ctrl, virt_dev->eps[0].ring->first_seg->trbs);
+	trb_64 = virt_dev->eps[0].ring->first_seg->dma;
 	ep0_ctx->deq = cpu_to_le64(trb_64 | virt_dev->eps[0].ring->cycle_state);
 
 	/*
diff --git a/drivers/usb/host/xhci-ring.c b/drivers/usb/host/xhci-ring.c
index eb6dfcd..c8260cb 100644
--- a/drivers/usb/host/xhci-ring.c
+++ b/drivers/usb/host/xhci-ring.c
@@ -24,6 +24,24 @@
 
 #include <usb/xhci.h>
 
+/*
+ * Returns zero if the TRB isn't in this segment, otherwise it returns the DMA
+ * address of the TRB.
+ */
+dma_addr_t xhci_trb_virt_to_dma(struct xhci_segment *seg,
+				union xhci_trb *trb)
+{
+	unsigned long segment_offset;
+
+	if (!seg || !trb || trb < seg->trbs)
+		return 0;
+	/* offset in TRBs */
+	segment_offset = trb - seg->trbs;
+	if (segment_offset >= TRBS_PER_SEGMENT)
+		return 0;
+	return seg->dma + (segment_offset * sizeof(*trb));
+}
+
 /**
  * Is this TRB a link TRB or was the last TRB the last TRB in this event ring
  * segment?  I.e. would the updated event TRB pointer step off the end of the
@@ -180,10 +198,8 @@
  * @param trb_fields	pointer to trb field array containing TRB contents
  * Return: pointer to the enqueued trb
  */
-static struct xhci_generic_trb *queue_trb(struct xhci_ctrl *ctrl,
-					  struct xhci_ring *ring,
-					  bool more_trbs_coming,
-					  unsigned int *trb_fields)
+static dma_addr_t queue_trb(struct xhci_ctrl *ctrl, struct xhci_ring *ring,
+			    bool more_trbs_coming, unsigned int *trb_fields)
 {
 	struct xhci_generic_trb *trb;
 	int i;
@@ -197,7 +213,7 @@
 
 	inc_enq(ctrl, ring, more_trbs_coming);
 
-	return trb;
+	return xhci_trb_virt_to_dma(ring->enq_seg, (union xhci_trb *)trb);
 }
 
 /**
@@ -271,19 +287,15 @@
  * @param cmd		Command type to enqueue
  * Return: none
  */
-void xhci_queue_command(struct xhci_ctrl *ctrl, u8 *ptr, u32 slot_id,
+void xhci_queue_command(struct xhci_ctrl *ctrl, dma_addr_t addr, u32 slot_id,
 			u32 ep_index, trb_type cmd)
 {
 	u32 fields[4];
-	u64 val_64 = 0;
 
 	BUG_ON(prepare_ring(ctrl, ctrl->cmd_ring, EP_STATE_RUNNING));
 
-	if (ptr)
-		val_64 = xhci_virt_to_bus(ctrl, ptr);
-
-	fields[0] = lower_32_bits(val_64);
-	fields[1] = upper_32_bits(val_64);
+	fields[0] = lower_32_bits(addr);
+	fields[1] = upper_32_bits(addr);
 	fields[2] = 0;
 	fields[3] = TRB_TYPE(cmd) | SLOT_ID_FOR_TRB(slot_id) |
 		    ctrl->cmd_ring->cycle_state;
@@ -399,12 +411,15 @@
  */
 void xhci_acknowledge_event(struct xhci_ctrl *ctrl)
 {
+	dma_addr_t deq;
+
 	/* Advance our dequeue pointer to the next event */
 	inc_deq(ctrl, ctrl->event_ring);
 
 	/* Inform the hardware */
-	xhci_writeq(&ctrl->ir_set->erst_dequeue,
-		    xhci_virt_to_bus(ctrl, ctrl->event_ring->dequeue) | ERST_EHB);
+	deq = xhci_trb_virt_to_dma(ctrl->event_ring->deq_seg,
+				   ctrl->event_ring->dequeue);
+	xhci_writeq(&ctrl->ir_set->erst_dequeue, deq | ERST_EHB);
 }
 
 /**
@@ -490,17 +505,19 @@
 	struct xhci_ctrl *ctrl = xhci_get_ctrl(udev);
 	struct xhci_ring *ring =  ctrl->devs[udev->slot_id]->eps[ep_index].ring;
 	union xhci_trb *event;
+	u64 addr;
 	u32 field;
 
 	printf("Resetting EP %d...\n", ep_index);
-	xhci_queue_command(ctrl, NULL, udev->slot_id, ep_index, TRB_RESET_EP);
+	xhci_queue_command(ctrl, 0, udev->slot_id, ep_index, TRB_RESET_EP);
 	event = xhci_wait_for_event(ctrl, TRB_COMPLETION);
 	field = le32_to_cpu(event->trans_event.flags);
 	BUG_ON(TRB_TO_SLOT_ID(field) != udev->slot_id);
 	xhci_acknowledge_event(ctrl);
 
-	xhci_queue_command(ctrl, (void *)((uintptr_t)ring->enqueue |
-		ring->cycle_state), udev->slot_id, ep_index, TRB_SET_DEQ);
+	addr = xhci_trb_virt_to_dma(ring->enq_seg,
+		(void *)((uintptr_t)ring->enqueue | ring->cycle_state));
+	xhci_queue_command(ctrl, addr, udev->slot_id, ep_index, TRB_SET_DEQ);
 	event = xhci_wait_for_event(ctrl, TRB_COMPLETION);
 	BUG_ON(TRB_TO_SLOT_ID(le32_to_cpu(event->event_cmd.flags))
 		!= udev->slot_id || GET_COMP_CODE(le32_to_cpu(
@@ -521,9 +538,10 @@
 	struct xhci_ctrl *ctrl = xhci_get_ctrl(udev);
 	struct xhci_ring *ring =  ctrl->devs[udev->slot_id]->eps[ep_index].ring;
 	union xhci_trb *event;
+	u64 addr;
 	u32 field;
 
-	xhci_queue_command(ctrl, NULL, udev->slot_id, ep_index, TRB_STOP_RING);
+	xhci_queue_command(ctrl, 0, udev->slot_id, ep_index, TRB_STOP_RING);
 
 	event = xhci_wait_for_event(ctrl, TRB_TRANSFER);
 	field = le32_to_cpu(event->trans_event.flags);
@@ -539,8 +557,9 @@
 		event->event_cmd.status)) != COMP_SUCCESS);
 	xhci_acknowledge_event(ctrl);
 
-	xhci_queue_command(ctrl, (void *)((uintptr_t)ring->enqueue |
-		ring->cycle_state), udev->slot_id, ep_index, TRB_SET_DEQ);
+	addr = xhci_trb_virt_to_dma(ring->enq_seg,
+		(void *)((uintptr_t)ring->enqueue | ring->cycle_state));
+	xhci_queue_command(ctrl, addr, udev->slot_id, ep_index, TRB_SET_DEQ);
 	event = xhci_wait_for_event(ctrl, TRB_COMPLETION);
 	BUG_ON(TRB_TO_SLOT_ID(le32_to_cpu(event->event_cmd.flags))
 		!= udev->slot_id || GET_COMP_CODE(le32_to_cpu(
@@ -609,8 +628,8 @@
 	u64 addr;
 	int ret;
 	u32 trb_fields[4];
-	u64 val_64 = xhci_virt_to_bus(ctrl, buffer);
-	void *last_transfer_trb_addr;
+	u64 buf_64 = xhci_dma_map(ctrl, buffer, length);
+	dma_addr_t last_transfer_trb_addr;
 	int available_length;
 
 	debug("dev=%p, pipe=%lx, buffer=%p, length=%d\n",
@@ -633,7 +652,7 @@
 	 * we send request in more than 1 TRB by chaining them.
 	 */
 	running_total = TRB_MAX_BUFF_SIZE -
-			(lower_32_bits(val_64) & (TRB_MAX_BUFF_SIZE - 1));
+			(lower_32_bits(buf_64) & (TRB_MAX_BUFF_SIZE - 1));
 	trb_buff_len = running_total;
 	running_total &= TRB_MAX_BUFF_SIZE - 1;
 
@@ -678,7 +697,7 @@
 	 * that the buffer should not span 64KB boundary. if so
 	 * we send request in more than 1 TRB by chaining them.
 	 */
-	addr = val_64;
+	addr = buf_64;
 
 	if (trb_buff_len > length)
 		trb_buff_len = length;
@@ -754,7 +773,7 @@
 	}
 
 	if ((uintptr_t)(le64_to_cpu(event->trans_event.buffer)) !=
-	    (uintptr_t)xhci_virt_to_bus(ctrl, last_transfer_trb_addr)) {
+	    (uintptr_t)last_transfer_trb_addr) {
 		available_length -=
 			(int)EVENT_TRB_LEN(le32_to_cpu(event->trans_event.transfer_len));
 		xhci_acknowledge_event(ctrl);
@@ -768,6 +787,7 @@
 	record_transfer_result(udev, event, available_length);
 	xhci_acknowledge_event(ctrl);
 	xhci_inval_cache((uintptr_t)buffer, length);
+	xhci_dma_unmap(ctrl, buf_64, length);
 
 	return (udev->status != USB_ST_NOT_PROC) ? 0 : -1;
 }
@@ -911,7 +931,7 @@
 	if (length > 0) {
 		if (req->requesttype & USB_DIR_IN)
 			field |= TRB_DIR_IN;
-		buf_64 = xhci_virt_to_bus(ctrl, buffer);
+		buf_64 = xhci_dma_map(ctrl, buffer, length);
 
 		trb_fields[0] = lower_32_bits(buf_64);
 		trb_fields[1] = upper_32_bits(buf_64);
@@ -961,8 +981,10 @@
 	}
 
 	/* Invalidate buffer to make it available to usb-core */
-	if (length > 0)
+	if (length > 0) {
 		xhci_inval_cache((uintptr_t)buffer, length);
+		xhci_dma_unmap(ctrl, buf_64, length);
+	}
 
 	if (GET_COMP_CODE(le32_to_cpu(event->trans_event.transfer_len))
 			== COMP_SHORT_TX) {
diff --git a/drivers/usb/host/xhci.c b/drivers/usb/host/xhci.c
index dbeb88a..440b022 100644
--- a/drivers/usb/host/xhci.c
+++ b/drivers/usb/host/xhci.c
@@ -448,7 +448,7 @@
 	in_ctx = virt_dev->in_ctx;
 
 	xhci_flush_cache((uintptr_t)in_ctx->bytes, in_ctx->size);
-	xhci_queue_command(ctrl, in_ctx->bytes, udev->slot_id, 0,
+	xhci_queue_command(ctrl, in_ctx->dma, udev->slot_id, 0,
 			   ctx_change ? TRB_EVAL_CONTEXT : TRB_CONFIG_EP);
 	event = xhci_wait_for_event(ctrl, TRB_COMPLETION);
 	BUG_ON(TRB_TO_SLOT_ID(le32_to_cpu(event->event_cmd.flags))
@@ -585,7 +585,8 @@
 			cpu_to_le32(MAX_BURST(max_burst) |
 			ERROR_COUNT(err_count));
 
-		trb_64 = xhci_virt_to_bus(ctrl, virt_dev->eps[ep_index].ring->enqueue);
+		trb_64 = xhci_trb_virt_to_dma(virt_dev->eps[ep_index].ring->enq_seg,
+				virt_dev->eps[ep_index].ring->enqueue);
 		ep_ctx[ep_index]->deq = cpu_to_le64(trb_64 |
 				virt_dev->eps[ep_index].ring->cycle_state);
 
@@ -643,7 +644,8 @@
 	ctrl_ctx->add_flags = cpu_to_le32(SLOT_FLAG | EP0_FLAG);
 	ctrl_ctx->drop_flags = 0;
 
-	xhci_queue_command(ctrl, (void *)ctrl_ctx, slot_id, 0, TRB_ADDR_DEV);
+	xhci_queue_command(ctrl, virt_dev->in_ctx->dma,
+			   slot_id, 0, TRB_ADDR_DEV);
 	event = xhci_wait_for_event(ctrl, TRB_COMPLETION);
 	BUG_ON(TRB_TO_SLOT_ID(le32_to_cpu(event->event_cmd.flags)) != slot_id);
 
@@ -718,7 +720,7 @@
 		return 0;
 	}
 
-	xhci_queue_command(ctrl, NULL, 0, 0, TRB_ENABLE_SLOT);
+	xhci_queue_command(ctrl, 0, 0, 0, TRB_ENABLE_SLOT);
 	event = xhci_wait_for_event(ctrl, TRB_COMPLETION);
 	BUG_ON(GET_COMP_CODE(le32_to_cpu(event->event_cmd.status))
 		!= COMP_SUCCESS);
diff --git a/include/usb/xhci.h b/include/usb/xhci.h
index ea4cf3f..85c359f 100644
--- a/include/usb/xhci.h
+++ b/include/usb/xhci.h
@@ -16,6 +16,7 @@
 #ifndef HOST_XHCI_H_
 #define HOST_XHCI_H_
 
+#include <iommu.h>
 #include <phys2bus.h>
 #include <asm/types.h>
 #include <asm/cache.h>
@@ -490,6 +491,7 @@
 
 	int size;
 	u8 *bytes;
+	dma_addr_t dma;
 };
 
 /**
@@ -688,6 +690,8 @@
 struct xhci_device_context_array {
 	/* 64-bit device addresses; we only write 32-bit addresses */
 	__le64			dev_context_ptrs[MAX_HC_SLOTS];
+	/* private xHCD pointers */
+	dma_addr_t	dma;
 };
 /* TODO: write function to set the 64-bit device DMA address */
 /*
@@ -997,6 +1001,7 @@
 	union xhci_trb		*trbs;
 	/* private to HCD */
 	struct xhci_segment	*next;
+	dma_addr_t		dma;
 };
 
 struct xhci_ring {
@@ -1025,11 +1030,14 @@
 struct xhci_erst {
 	struct xhci_erst_entry	*entries;
 	unsigned int		num_entries;
+	/* xhci->event_ring keeps track of segment dma addresses */
+	dma_addr_t		erst_dma_addr;
 	/* Num entries the ERST can contain */
 	unsigned int		erst_size;
 };
 
 struct xhci_scratchpad {
+	void *scratchpad;
 	u64 *sp_array;
 };
 
@@ -1216,6 +1224,7 @@
 	struct xhci_virt_device *devs[MAX_HC_SLOTS];
 	int rootdev;
 	u16 hci_version;
+	int page_size;
 	u32 quirks;
 #define XHCI_MTK_HOST		BIT(0)
 };
@@ -1226,7 +1235,7 @@
 #define xhci_to_dev(_ctrl)	NULL
 #endif
 
-unsigned long trb_addr(struct xhci_segment *seg, union xhci_trb *trb);
+dma_addr_t xhci_trb_virt_to_dma(struct xhci_segment *seg, union xhci_trb *trb);
 struct xhci_input_control_ctx
 		*xhci_get_input_control_ctx(struct xhci_container_ctx *ctx);
 struct xhci_slot_ctx *xhci_get_slot_ctx(struct xhci_ctrl *ctrl,
@@ -1243,7 +1252,7 @@
 		    struct xhci_container_ctx *out_ctx);
 void xhci_setup_addressable_virt_dev(struct xhci_ctrl *ctrl,
 				     struct usb_device *udev, int hop_portnr);
-void xhci_queue_command(struct xhci_ctrl *ctrl, u8 *ptr,
+void xhci_queue_command(struct xhci_ctrl *ctrl, dma_addr_t addr,
 			u32 slot_id, u32 ep_index, trb_type cmd);
 void xhci_acknowledge_event(struct xhci_ctrl *ctrl);
 union xhci_trb *xhci_wait_for_event(struct xhci_ctrl *ctrl, trb_type expected);
@@ -1284,14 +1293,22 @@
 
 struct xhci_ctrl *xhci_get_ctrl(struct usb_device *udev);
 
-static inline dma_addr_t xhci_virt_to_bus(struct xhci_ctrl *ctrl, void *addr)
+static inline dma_addr_t xhci_dma_map(struct xhci_ctrl *ctrl, void *addr,
+				      size_t size)
 {
+#if CONFIG_IS_ENABLED(IOMMU)
+	return dev_iommu_dma_map(xhci_to_dev(ctrl), addr, size);
+#else
 	return dev_phys_to_bus(xhci_to_dev(ctrl), virt_to_phys(addr));
+#endif
 }
 
-static inline void *xhci_bus_to_virt(struct xhci_ctrl *ctrl, dma_addr_t addr)
+static inline void xhci_dma_unmap(struct xhci_ctrl *ctrl, dma_addr_t addr,
+				  size_t size)
 {
-	return phys_to_virt(dev_bus_to_phys(xhci_to_dev(ctrl), addr));
+#if CONFIG_IS_ENABLED(IOMMU)
+	dev_iommu_dma_unmap(xhci_to_dev(ctrl), addr, size);
+#endif
 }
 
 #endif /* HOST_XHCI_H_ */