diff --git a/README.md b/README.md index 830b323..097b35a 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![JSR](https://jsr.io/badges/@deco/warp)](https://jsr.io/@deco/warp) +[![JSR](https://jsr.io/badges/@deco/warp)](https://jsr.io/@deco/warp) [![JSR Score](https://jsr.io/badges/@deco/warp/score)](https://jsr.io/@deco/warp) # Warp diff --git a/client.ts b/client.ts index c8618b7..eb512b0 100644 --- a/client.ts +++ b/client.ts @@ -1,4 +1,4 @@ -import { type Channel, makeWebSocket } from "./channel.ts"; +import { makeWebSocket } from "./channel.ts"; import denoJSON from "./deno.json" with { type: "json" }; import { handleServerMessage } from "./handlers.client.ts"; import type { ClientMessage, ClientState, ServerMessage } from "./messages.ts"; @@ -62,7 +62,6 @@ export const connectMainThread = async ( apiKey: opts.apiKey, domain: opts.domain, }); - const requestBody: Record> = {}; const wsSockets: Record = {}; (async () => { @@ -71,7 +70,7 @@ export const connectMainThread = async ( client, localAddr: opts.localAddr, live: false, - requestBody, + requests: {}, wsSockets, ch, }; diff --git a/deno.json b/deno.json index 01ae145..1b92ee8 100644 --- a/deno.json +++ b/deno.json @@ -1,6 +1,6 @@ { "name": "@deco/warp", - "version": "0.3.4", + "version": "0.3.5", "exports": "./mod.ts", "tasks": { "check": "deno fmt && deno lint && deno check mod.ts" diff --git a/handlers.client.ts b/handlers.client.ts index 96b4440..8f97b7a 100644 --- a/handlers.client.ts +++ b/handlers.client.ts @@ -1,6 +1,7 @@ import { type Channel, ignoreIfClosed, + link, makeChan, makeChanStream, makeReadableStream, @@ -11,6 +12,7 @@ import type { ClientState, ErrorMessage, RegisteredMessage, + RequestAbortedMessage, RequestDataEndMessage, RequestDataMessage, RequestStartMessage, @@ -49,17 +51,22 @@ const onRequestStart: ServerMessageHandler = async ( await handleWebSocket(message, state); return; } + const abortCtrl = new AbortController(); + state.requests[message.id] = { abortCtrl }; if (!message.hasBody) { - doFetch(message, state, state.ch.out).catch(ignoreIfClosed); + doFetch(message, state, state.ch.out, abortCtrl.signal).catch( + ignoreIfClosed, + ); } else { const bodyData = makeChan(); - state.requestBody[message.id] = bodyData; + state.requests[message.id]!.body = bodyData; doFetch( { ...message, body: makeReadableStream(bodyData) }, state, state.ch.out, + abortCtrl.signal, ).catch(ignoreIfClosed).finally(() => { - delete state.requestBody[message.id]; + delete state.requests[message.id]; }); } }; @@ -73,7 +80,7 @@ const onRequestData: ServerMessageHandler = async ( state, message, ) => { - const reqBody = state.requestBody[message.id]; + const reqBody = state.requests[message.id]?.body; if (!reqBody) { console.info("[req-data] req not found", message.id); return; @@ -81,6 +88,18 @@ const onRequestData: ServerMessageHandler = async ( await reqBody.send?.(message.chunk); }; +/** + * Handler for the 'request-aborted' server message. + * @param {ClientState} state - The client state. + * @param {RequestAbortedMessage} message - The message data. + */ +const onRequestAborted: ServerMessageHandler = ( + state, + message, +) => { + state.requests[message.id]?.abortCtrl?.abort(); +}; + /** * Handler for the 'request-data-end' server message. * @param {ClientState} state - The client state. @@ -90,7 +109,7 @@ const onRequestDataEnd: ServerMessageHandler = ( state, message, ) => { - const reqBody = state.requestBody[message.id]; + const reqBody = state.requests[message.id]?.body; if (!reqBody) { return; } @@ -129,6 +148,7 @@ const handlersByType: Record> = { registered, error, + "request-aborted": onRequestAborted, "request-start": onRequestStart, "request-data": onRequestData, "request-end": onRequestDataEnd, @@ -204,9 +224,10 @@ async function doFetch( request: RequestStartMessage & { body?: ReadableStream }, state: ClientState, clientCh: Channel, + reqSignal: AbortSignal, ) { // Read from the stream - const signal = clientCh.signal; + const signal = link(clientCh.signal, reqSignal); try { const response = await fetch( new URL(request.url, state.localAddr), diff --git a/messages.ts b/messages.ts index 4466704..c83dd56 100644 --- a/messages.ts +++ b/messages.ts @@ -72,6 +72,11 @@ export interface RequestDataMessage { id: string; chunk: Uint8Array; } + +export interface RequestAbortedMessage { + type: "request-aborted"; + id: string; +} export interface RegisteredMessage { type: "registered"; id: string; @@ -82,6 +87,7 @@ export interface ErrorMessage { message: string; } export type ServerMessage = + | RequestAbortedMessage | WSMessage | WSConnectionClosed | RequestStartMessage @@ -90,9 +96,13 @@ export type ServerMessage = | RegisteredMessage | ErrorMessage; +export interface RequestState { + body?: Channel; + abortCtrl: AbortController; +} export interface ClientState { ch: DuplexChannel; - requestBody: Record>; + requests: Record; wsSockets: Record; live: boolean; localAddr: string; diff --git a/server.ts b/server.ts index 0818a7d..68bd542 100644 --- a/server.ts +++ b/server.ts @@ -165,6 +165,14 @@ export const serveHandler = ( await ch.out.send(requestForward); const dataChan = req.body ? makeChanStream(req.body) : undefined; const linked = link(ch.out.signal, req.signal); + req.signal.addEventListener("abort", () => { + if (!ch.out.signal.aborted) { + ch.out.send({ + type: "request-aborted", + id: messageId, + }).catch(() => {}); + } + }); (async () => { try { for await (const chunk of dataChan?.recv(linked) ?? []) {