Skip to content

Commit

Permalink
set the xhci device context into a configured state with interrupt in…
Browse files Browse the repository at this point in the history
… endpoint enabled
  • Loading branch information
FlareCoding committed Oct 4, 2024
1 parent 7afc665 commit 12c36d3
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 10 deletions.
Binary file modified efi/OVMF_VARS.fd
Binary file not shown.
157 changes: 147 additions & 10 deletions kernel/src/drivers/usb/xhci/xhci.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,45 @@ void XhciDriver::_configureDeviceInputContext(XhciDevice* device, uint16_t maxPa
controlEndpointContext->averageTrbLength = 8;
}

void XhciDriver::_configureDeviceInterruptEndpoint(XhciDevice* device, UsbEndpointDescriptor* epDesc) {
XhciInputControlContext32* inputControlContext = device->getInputControlContext(m_64ByteContextSize);
XhciSlotContext32* slotContext = device->getInputSlotContext(m_64ByteContextSize);

uint8_t endpointNumber = epDesc->bEndpointAddress & 0x0F;
kprint("endpointnumber: %i\n", endpointNumber);
uint8_t endpointDirectionIn = (epDesc->bEndpointAddress & 0x80) ? 1 : 0;
uint8_t endpointId = (endpointNumber * 2) + endpointDirectionIn;
kprint("endpointId: %i\n", endpointId);

// Enable the input control context flags
inputControlContext->addFlags = (1 << endpointId) | (1 << 0);
if (endpointId > slotContext->contextEntries) {
slotContext->contextEntries = endpointId;
}

// Configure the endpoint context
XhciEndpointContext32* interruptEndpointContext = device->getInputEndpointContext(m_64ByteContextSize, endpointId);
zeromem(interruptEndpointContext, sizeof(XhciEndpointContext32));
interruptEndpointContext->endpointState = XHCI_ENDPOINT_STATE_DISABLED;
interruptEndpointContext->endpointType = XHCI_ENDPOINT_TYPE_INTERRUPT_IN;
interruptEndpointContext->maxPacketSize = epDesc->wMaxPacketSize;
interruptEndpointContext->errorCount = 3;
interruptEndpointContext->maxBurstSize = 0;
interruptEndpointContext->averageTrbLength = 8;
interruptEndpointContext->transferRingDequeuePtr = device->getInterruptInEndpointTransferRing()->getPhysicalDequeuePointerBase();
interruptEndpointContext->dcs = device->getInterruptInEndpointTransferRing()->getCycleBit();

if (device->speed == XHCI_USB_SPEED_HIGH_SPEED || device->speed == XHCI_USB_SPEED_SUPER_SPEED) {
interruptEndpointContext->interval = epDesc->bInterval - 1;
} else {
interruptEndpointContext->interval = epDesc->bInterval;
}

kprint("transferRingDequeuePtr: 0x%llx\n", interruptEndpointContext->transferRingDequeuePtr);
const int intervalInMs = ((2 << (interruptEndpointContext->interval - 1)) * 125);
kprint("interval: %i (%i us / %i ms)\n", interruptEndpointContext->interval, intervalInMs, intervalInMs / 1000);
}

void XhciDriver::_setupDevice(uint8_t port) {
XhciDevice* device = new XhciDevice();
device->portRegSet = port;
Expand Down Expand Up @@ -724,6 +763,9 @@ void XhciDriver::_setupDevice(uint8_t port) {
// Send the address device command again with BSR=0 this time
_addressDevice(device, false);

// Copy the output device context into the device's input context
device->copyOutputDeviceContextToInputDeviceContext(m_64ByteContextSize, (void*)m_dcbaa[device->slotId]);

// Read the full device descriptor
if (!_getDeviceDescriptor(device, deviceDescriptor, deviceDescriptor->header.bLength)) {
kprintError("[XHCI] Failed to get full device descriptor\n");
Expand Down Expand Up @@ -760,11 +802,6 @@ void XhciDriver::_setupDevice(uint8_t port) {
if (!_getConfigurationDescriptor(device, configurationDescriptor)) {
return;
}

// Set device configuration
if (!_setDeviceConfiguration(device, configurationDescriptor->bConfigurationValue)) {
return;
}

UsbInterfaceDescriptor* iface = nullptr;
UsbEndpointDescriptor* epDescriptor = nullptr;
Expand Down Expand Up @@ -798,11 +835,6 @@ void XhciDriver::_setupDevice(uint8_t port) {
return;
}

const uint8_t bootProtocol = 0;
if (!_setProtocol(device, iface->bInterfaceNumber, bootProtocol)) {
return;
}

kprint("---- USB Device Info ----\n");
kprint(" Product Name : %s\n", product);
kprint(" Manufacturer : %s\n", manufacturer);
Expand All @@ -829,6 +861,41 @@ void XhciDriver::_setupDevice(uint8_t port) {
kprint(" wMaxPacketSize - %i\n", epDescriptor->wMaxPacketSize);
kprint(" bInterval - %i\n", epDescriptor->bInterval);
kprint("\n");

// Allocate a transfer ring for the interrupt endpoint
device->allocateInterruptInEndpointTransferRing();

// Re-configure the input context to enable the interrupt endpoint
_configureDeviceInterruptEndpoint(device, epDescriptor);

// Evaluate the new input context
if (!_configureEndpoint(device)) {
return;
}

device->copyOutputDeviceContextToInputDeviceContext(m_64ByteContextSize, (void*)m_dcbaa[device->slotId]);

// Sanity-check the actual device context entry in DCBAA
XhciDeviceContext32* deviceContext = &virtbase(device->getInputContextPhysicalBase(), XhciInputContext32)->deviceContext;

kprint(" DeviceContext[slotId=%i] address:0x%llx slotState:%s epSate:%s epType:%i\n maxPacketSize:%i\n",
device->slotId, deviceContext->slotContext.deviceAddress,
xhciSlotStateToString(deviceContext->slotContext.slotState),
xhciEndpointStateToString(deviceContext->ep[1].endpointState),
deviceContext->ep[1].endpointType,
deviceContext->ep[1].maxPacketSize
);

// Set device configuration
if (!_setDeviceConfiguration(device, configurationDescriptor->bConfigurationValue)) {
return;
}

// Set BOOT protocol
const uint8_t bootProtocol = 0;
if (!_setProtocol(device, iface->bInterfaceNumber, bootProtocol)) {
return;
}
}

bool XhciDriver::_addressDevice(XhciDevice* device, bool bsr) {
Expand Down Expand Up @@ -872,6 +939,76 @@ bool XhciDriver::_addressDevice(XhciDevice* device, bool bsr) {
return true;
}

bool XhciDriver::_configureEndpoint(XhciDevice* device) {
XhciConfigureEndpointCommandTrb_t configureEndpointTrb;
zeromem(&configureEndpointTrb, sizeof(XhciConfigureEndpointCommandTrb_t));
configureEndpointTrb.trbType = XHCI_TRB_TYPE_CONFIGURE_ENDPOINT_CMD;
configureEndpointTrb.inputContextPhysicalBase = device->getInputContextPhysicalBase();
configureEndpointTrb.slotId = device->slotId;

// Send the Configure Endpoint command
XhciCommandCompletionTrb_t* completionTrb = _sendCommand((XhciTrb_t*)&configureEndpointTrb, 200);
if (!completionTrb) {
kprintError("[*] Failed to send Configure Endpoint command\n");
return false;
}

// Check the completion code
if (completionTrb->completionCode != XHCI_TRB_COMPLETION_CODE_SUCCESS) {
kprintError("[*] Evaluate Context command failed with completion code: %s\n",
trbCompletionCodeToString(completionTrb->completionCode));
return false;
}

return true;
}

bool XhciDriver::_evaluateContext(XhciDevice* device) {
// Construct the Evaluate Context Command TRB
XhciEvaluateContextCommandTrb_t evaluateContextTrb;
zeromem(&evaluateContextTrb, sizeof(XhciEvaluateContextCommandTrb_t));
evaluateContextTrb.trbType = XHCI_TRB_TYPE_EVALUATE_CONTEXT_CMD;
evaluateContextTrb.inputContextPhysicalBase = device->getInputContextPhysicalBase();
evaluateContextTrb.slotId = device->slotId;

// Send the Evaluate Context command
XhciCommandCompletionTrb_t* completionTrb = _sendCommand((XhciTrb_t*)&evaluateContextTrb, 200);
if (!completionTrb) {
kprintError("[*] Failed to send Evaluate Context command\n");
return false;
}

// Check the completion code
if (completionTrb->completionCode != XHCI_TRB_COMPLETION_CODE_SUCCESS) {
kprintError("[*] Evaluate Context command failed with completion code: %s\n",
trbCompletionCodeToString(completionTrb->completionCode));
return false;
}

// Optionally, perform a sanity check similar to _addressDevice
if (m_64ByteContextSize) {
XhciDeviceContext64* deviceContext = virtbase(m_dcbaa[device->slotId], XhciDeviceContext64);

kprint(" DeviceContext[slotId=%i] address:0x%llx slotState:%s epState:%s maxPacketSize:%i\n",
device->slotId, deviceContext->slotContext.deviceAddress,
xhciSlotStateToString(deviceContext->slotContext.slotState),
xhciEndpointStateToString(deviceContext->controlEndpointContext.endpointState),
deviceContext->controlEndpointContext.maxPacketSize
);
} else {
XhciDeviceContext32* deviceContext = virtbase(m_dcbaa[device->slotId], XhciDeviceContext32);

kprint(" DeviceContext[slotId=%i] address:0x%llx slotState:%s epState:%s maxPacketSize:%i\n",
device->slotId, deviceContext->slotContext.deviceAddress,
xhciSlotStateToString(deviceContext->slotContext.slotState),
xhciEndpointStateToString(deviceContext->controlEndpointContext.endpointState),
deviceContext->controlEndpointContext.maxPacketSize
);
}

return true;
}

bool XhciDriver::_sendUsbRequestPacket(XhciDevice* device, XhciDeviceRequestPacket& req, void* outputBuffer, uint32_t length) {
XhciTransferRing* transferRing = device->getControlEndpointTransferRing();

Expand Down
4 changes: 4 additions & 0 deletions kernel/src/drivers/usb/xhci/xhci.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,12 @@ class XhciDriver : public DeviceDriver {
uint8_t _enableDeviceSlot();
void _configureDeviceInputContext(XhciDevice* device, uint16_t maxPacketSize);

void _configureDeviceInterruptEndpoint(XhciDevice* device, UsbEndpointDescriptor* epDesc);

void _setupDevice(uint8_t port);
bool _addressDevice(XhciDevice* device, bool bsr);
bool _configureEndpoint(XhciDevice* device);
bool _evaluateContext(XhciDevice* device);

bool _sendUsbRequestPacket(XhciDevice* device, XhciDeviceRequestPacket& req, void* outputBuffer, uint32_t length);

Expand Down
29 changes: 29 additions & 0 deletions kernel/src/drivers/usb/xhci/xhci_device_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ void XhciDevice::allocateControlEndpointTransferRing() {
m_controlEndpointTransferRing = new XhciTransferRing(XHCI_TRANSFER_RING_TRB_COUNT, slotId);
}

void XhciDevice::allocateInterruptInEndpointTransferRing() {
m_interruptInEndpointTransferRing = new XhciTransferRing(XHCI_TRANSFER_RING_TRB_COUNT, slotId);
}

XhciInputControlContext32* XhciDevice::getInputControlContext(bool use64ByteContexts) {
if (use64ByteContexts) {
XhciInputContext64* inputCtx = static_cast<XhciInputContext64*>(m_inputContext);
Expand Down Expand Up @@ -52,6 +56,31 @@ XhciEndpointContext32* XhciDevice::getInputControlEndpointContext(bool use64Byte
}
}

XhciEndpointContext32* XhciDevice::getInputEndpointContext(bool use64ByteContexts, uint8_t endpointID) {
uint8_t endpointIndex = endpointID - 2;
kprint("endpointContextIndex: %i\n", endpointIndex);

if (use64ByteContexts) {
XhciInputContext64* inputCtx = static_cast<XhciInputContext64*>(m_inputContext);
return reinterpret_cast<XhciEndpointContext32*>(&inputCtx->deviceContext.ep[endpointIndex]);
} else {
XhciInputContext32* inputCtx = static_cast<XhciInputContext32*>(m_inputContext);
return &inputCtx->deviceContext.ep[endpointIndex];
}
}

void XhciDevice::copyOutputDeviceContextToInputDeviceContext(bool use64ByteContexts, void* outputDeviceContext) {
if (use64ByteContexts) {
XhciInputContext64* inputCtx = static_cast<XhciInputContext64*>(m_inputContext);
XhciDeviceContext64* inputDeviceCtx = &inputCtx->deviceContext;
memcpy(inputDeviceCtx, outputDeviceContext, sizeof(XhciDeviceContext64));
} else {
XhciInputContext32* inputCtx = static_cast<XhciInputContext32*>(m_inputContext);
XhciDeviceContext32* inputDeviceCtx = &inputCtx->deviceContext;
memcpy(inputDeviceCtx, outputDeviceContext, sizeof(XhciDeviceContext32));
}
}

void printUsbDeviceDescriptor(const UsbDeviceDescriptor *desc) {
kprint("USB Device Descriptor:\n");
kprint("-------------------------------\n");
Expand Down
9 changes: 9 additions & 0 deletions kernel/src/drivers/usb/xhci/xhci_device_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -592,18 +592,27 @@ class XhciDevice {
uint64_t getInputContextPhysicalBase();

void allocateControlEndpointTransferRing();
void allocateInterruptInEndpointTransferRing();

__force_inline__ XhciTransferRing* getControlEndpointTransferRing() {
return m_controlEndpointTransferRing;
}

__force_inline__ XhciTransferRing* getInterruptInEndpointTransferRing() {
return m_interruptInEndpointTransferRing;
}

XhciInputControlContext32* getInputControlContext(bool use64ByteContexts);
XhciSlotContext32* getInputSlotContext(bool use64ByteContexts);
XhciEndpointContext32* getInputControlEndpointContext(bool use64ByteContexts);
XhciEndpointContext32* getInputEndpointContext(bool use64ByteContexts, uint8_t endpointID);

void copyOutputDeviceContextToInputDeviceContext(bool use64ByteContexts, void* outputDeviceContext);

private:
void* m_inputContext = nullptr;
XhciTransferRing* m_controlEndpointTransferRing = nullptr;
XhciTransferRing* m_interruptInEndpointTransferRing = nullptr;
};

void printUsbDeviceDescriptor(const UsbDeviceDescriptor *desc);
Expand Down
28 changes: 28 additions & 0 deletions kernel/src/drivers/usb/xhci/xhci_trb.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ typedef struct XhciAddressDeviceRequestBlock {
} XhciAddressDeviceCommandTrb_t;
static_assert(sizeof(XhciAddressDeviceCommandTrb_t) == sizeof(uint32_t) * 4);

typedef struct XhciEvaluateContextCommandRequestBlock {
uint64_t inputContextPhysicalBase;
uint32_t rsvd0;
struct {
uint32_t cycleBit : 1;
uint32_t rsvd1 : 8;
uint32_t rsvd2 : 1; // Block Set Address Request bit in the Address Device TRB
uint32_t trbType : 6;
uint32_t rsvd3 : 8;
uint32_t slotId : 8;
};
} XhciEvaluateContextCommandTrb_t;
static_assert(sizeof(XhciEvaluateContextCommandTrb_t) == sizeof(uint32_t) * 4);

typedef struct XhciConfigureEndpointCommandRequestBlock {
uint64_t inputContextPhysicalBase;
uint32_t rsvd0;
struct {
uint32_t cycleBit : 1;
uint32_t rsvd1 : 8;
uint32_t deconfigure : 1;
uint32_t trbType : 6;
uint32_t rsvd3 : 8;
uint32_t slotId : 8;
};
} XhciConfigureEndpointCommandTrb_t;
static_assert(sizeof(XhciConfigureEndpointCommandTrb_t) == sizeof(uint32_t) * 4);

typedef struct XhciCommandCompletionRequestBlock {
uint64_t commandTrbPointer;
struct {
Expand Down

0 comments on commit 12c36d3

Please sign in to comment.