From 6499be6aa0ba3212728b4ba1eae88544f5f923ef Mon Sep 17 00:00:00 2001 From: Fredrik Lindberg Date: Tue, 17 Oct 2023 20:13:27 +0200 Subject: [PATCH] refactor: http listener - Convert to typescript - Add unit test --- src/controller/admin-api-controller.ts | 10 +- src/controller/api-controller.ts | 10 +- src/controller/koa-controller.ts | 40 ++-- src/ingress/http-ingress.js | 23 +- src/listener/http-listener.js | 201 ------------------ src/listener/http-listener.ts | 255 +++++++++++++++++++++++ src/listener/index.js | 42 ---- src/listener/listener-interface.js | 65 ------ src/listener/listener.ts | 122 +++++++++++ src/transport/transport-service.ts | 10 +- src/transport/ws/ws-endpoint.ts | 30 ++- test/unit/listener/test_http-listener.ts | 255 +++++++++++++++++++++++ 12 files changed, 692 insertions(+), 371 deletions(-) delete mode 100644 src/listener/http-listener.js create mode 100644 src/listener/http-listener.ts delete mode 100644 src/listener/index.js delete mode 100644 src/listener/listener-interface.js create mode 100644 src/listener/listener.ts create mode 100644 test/unit/listener/test_http-listener.ts diff --git a/src/controller/admin-api-controller.ts b/src/controller/admin-api-controller.ts index 8126ac2..1970db3 100644 --- a/src/controller/admin-api-controller.ts +++ b/src/controller/admin-api-controller.ts @@ -122,7 +122,7 @@ class AdminApiController extends KoaController { } }; - const tunnelProps = (tunnel: Tunnel, baseUrl: string) => { + const tunnelProps = (tunnel: Tunnel, baseUrl: URL | undefined) => { return { tunnel_id: tunnel.id, account_id: tunnel.account, @@ -154,10 +154,6 @@ class AdminApiController extends KoaController { } }; - const getBaseUrl = (req: any) => { - return req._exposrBaseUrl; - } - router.route({ method: 'post', path: '/v1/admin/account', @@ -276,7 +272,7 @@ class AdminApiController extends KoaController { try { const tunnel = await this._tunnelService.lookup(ctx.params.tunnel_id); ctx.status = 200; - ctx.body = tunnelProps(tunnel, getBaseUrl(ctx.req)); + ctx.body = tunnelProps(tunnel, this.getBaseUrl(ctx.req)); } catch (e: any) { if (e.message == 'no_such_tunnel') { ctx.status = 404; @@ -370,7 +366,7 @@ class AdminApiController extends KoaController { ctx.body = { cursor: res.cursor, tunnels: res.tunnels.map((t) => { - return ctx.query.verbose ? tunnelProps(t, getBaseUrl(ctx.req)) : t.id; + return ctx.query.verbose ? tunnelProps(t, this.getBaseUrl(ctx.req)) : t.id; }), }; }] diff --git a/src/controller/api-controller.ts b/src/controller/api-controller.ts index ac1861a..0ad3841 100644 --- a/src/controller/api-controller.ts +++ b/src/controller/api-controller.ts @@ -98,7 +98,7 @@ class ApiController extends KoaController { return next(); }; - const tunnelInfo = (tunnel: Tunnel, baseUrl: string) => { + const tunnelInfo = (tunnel: Tunnel, baseUrl: URL | undefined) => { const info = { id: tunnel.id, connection: { @@ -131,10 +131,6 @@ class ApiController extends KoaController { return info; }; - const getBaseUrl = (req: any) => { - return req._exposrBaseUrl; - }; - router.route({ method: ['put', 'patch'], path: '/v1/tunnel/:tunnel_id', @@ -209,7 +205,7 @@ class ApiController extends KoaController { tunnel.transport.ssh.enabled = body?.transport?.ssh?.enabled ?? tunnel.transport.ssh.enabled; }); - ctx.body = tunnelInfo(updatedTunnel, getBaseUrl(ctx.req)); + ctx.body = tunnelInfo(updatedTunnel, this.getBaseUrl(ctx.req)); ctx.status = 200; } catch (e: any) { if (e.message == 'permission_denied') { @@ -278,7 +274,7 @@ class ApiController extends KoaController { try { const tunnel = await this.tunnelService.get(tunnelId, account.id); ctx.status = 200; - ctx.body = tunnelInfo(tunnel, getBaseUrl(ctx.req)); + ctx.body = tunnelInfo(tunnel, this.getBaseUrl(ctx.req)); } catch (e: any) { ctx.status = 404; ctx.body = { diff --git a/src/controller/koa-controller.ts b/src/controller/koa-controller.ts index 869684e..1855ec2 100644 --- a/src/controller/koa-controller.ts +++ b/src/controller/koa-controller.ts @@ -1,7 +1,8 @@ +import { strict as assert } from 'assert'; import Koa from 'koa'; -import Router, { FullHandler } from 'koa-joi-router'; -import Listener from '../listener/index.js'; -import HttpListener from '../listener/http-listener.js'; +import Router from 'koa-joi-router'; +import Listener from '../listener/listener.js'; +import HttpListener, { HttpRequestCallback, HttpRequestType } from '../listener/http-listener.js'; import { IncomingMessage, ServerResponse } from 'http'; abstract class KoaController { @@ -9,14 +10,12 @@ abstract class KoaController { public readonly _name: string = 'controller' private _port!: number; private httpListener!: HttpListener; - private _requestHandler: any; + private _requestHandler!: HttpRequestCallback; private router!: Router.Router; private app!: Koa; constructor(opts: any) { - if (opts == undefined) { - return; - } + assert(opts != undefined); const {port, callback, logger, host, prio} = opts; if (opts?.enable === false) { @@ -26,8 +25,8 @@ abstract class KoaController { this._port = port; - const useCallback: FullHandler = async (ctx, next) => { - const setBaseUrl = (req: any, baseUrl: string) => { + const useCallback: HttpRequestCallback = this._requestHandler = async (ctx, next) => { + const setBaseUrl = (req: any, baseUrl: URL | undefined) => { req._exposrBaseUrl = baseUrl; }; setBaseUrl(ctx.req, ctx.baseUrl) @@ -36,17 +35,12 @@ abstract class KoaController { } } - const httpListener = this.httpListener = Listener.acquire('http', port, { app: new Koa() }); - this._requestHandler = httpListener.use('request', { host, logger, prio, logBody: true }, useCallback); - - httpListener.setState({ - app: new Koa(), - ...httpListener.state, - }); - this.app = httpListener.state.app; + const httpListener = this.httpListener = Listener.acquire(HttpListener, port); + httpListener.use(HttpRequestType.request, { host, logger, prio, logBody: true }, useCallback); this.router = Router(); this._initializeRoutes(this.router); + this.app = new Koa(); this.app.use(this.router.middleware()); this.httpListener.listen() @@ -73,13 +67,17 @@ abstract class KoaController { protected abstract _destroy(): Promise; - public async destroy() { - this.httpListener.removeHandler('request', this._requestHandler); - return Promise.allSettled([ - Listener.release('http', this._port), + public async destroy(): Promise { + this.httpListener?.removeHandler(HttpRequestType.request, this._requestHandler); + await Promise.allSettled([ + Listener.release(this._port), this._destroy(), ]); } + + protected getBaseUrl(req: IncomingMessage): URL | undefined { + return ((req as any)._exposrBaseUrl as (URL | undefined)); + } } export default KoaController; \ No newline at end of file diff --git a/src/ingress/http-ingress.js b/src/ingress/http-ingress.js index a69f85d..f4f8737 100644 --- a/src/ingress/http-ingress.js +++ b/src/ingress/http-ingress.js @@ -3,7 +3,7 @@ import http, { Agent } from 'http'; import net from 'net'; import NodeCache from 'node-cache'; import EventBus from '../cluster/eventbus.js'; -import Listener from '../listener/index.js'; +import Listener from '../listener/listener.js'; import IngressUtils from './utils.js'; import { Logger } from '../logger.js'; import TunnelService from '../tunnel/tunnel-service.js'; @@ -28,6 +28,7 @@ import { HTTP_HEADER_X_FORWARDED_PROTO, HTTP_HEADER_FORWARDED } from '../utils/http-headers.js'; +import HttpListener, { HttpRequestType } from '../listener/http-listener.js'; class HttpIngress { @@ -45,17 +46,21 @@ class HttpIngress { this.altNameService = new AltNameService(); this.tunnelService = opts.tunnelService; assert(this.tunnelService instanceof TunnelService); - this.httpListener = Listener.acquire('http', opts.port); - this._requestHandler = this.httpListener.use('request', { logger: this.logger, prio: 1 }, async (ctx, next) => { + this.httpListener = Listener.acquire(HttpListener, opts.port); + + this._requestHandler = async (ctx, next) => { if (!await this.handleRequest(ctx.req, ctx.res, ctx.baseUrl)) { next(); } - }); - this._upgradeHandler = this.httpListener.use('upgrade', { logger: this.logger }, async (ctx, next) => { + }; + this.httpListener.use(HttpRequestType.request, { logger: this.logger, prio: 1 }, this._requestHandler); + + this._upgradeHandler = async (ctx, next) => { if (!await this.handleUpgradeRequest(ctx.req, ctx.sock, ctx.head, ctx.baseUrl)) { next(); } - }); + }; + this.httpListener.use(HttpRequestType.upgrade, { logger: this.logger }, this._upgradeHandler); this._agentCache = new NodeCache({ useClones: false, @@ -467,12 +472,12 @@ class HttpIngress { return; } this.destroyed = true; - this.httpListener.removeHandler('request', this._requestHandler); - this.httpListener.removeHandler('upgrade', this._upgradeHandler); + this.httpListener.removeHandler(HttpRequestType.request, this._requestHandler); + this.httpListener.removeHandler(HttpRequestType.upgrade, this._upgradeHandler); return Promise.allSettled([ this.altNameService.destroy(), this.eventBus.destroy(), - Listener.release('http', this.opts.port), + Listener.release(this.opts.port), ]); } diff --git a/src/listener/http-listener.js b/src/listener/http-listener.js deleted file mode 100644 index 025ef7d..0000000 --- a/src/listener/http-listener.js +++ /dev/null @@ -1,201 +0,0 @@ -import http from 'http'; -import ListenerInterface from './listener-interface.js'; -import { Logger } from '../logger.js'; -import HttpCaptor from '../utils/http-captor.js'; -import { - HTTP_HEADER_FORWARDED, - HTTP_HEADER_HOST, - HTTP_HEADER_X_FORWARDED_PORT, - HTTP_HEADER_X_FORWARDED_PROTO, - HTTP_HEADER_X_SCHEME -} from '../utils/http-headers.js'; - -class HttpListener extends ListenerInterface { - constructor(opts) { - super(); - this.logger = Logger("http-listener"); - this.opts = opts; - this.callbacks = { - 'request': [], - 'upgrade': [] - }; - this.state = opts.state || {}; - - const parseForwarded = (forwarded) => { - return Object.fromEntries(forwarded - .split(';') - .map(x => x.trim()) - .filter(x => x.length > 0) - .map(x => x.split('=') - .map(y => y.trim()) - ) - ) - }; - - const getBaseUrl = (req) => { - const headers = req.headers || {}; - - const forwarded = parseForwarded(headers[HTTP_HEADER_FORWARDED] || ''); - const proto = forwarded?.proto - || headers[HTTP_HEADER_X_FORWARDED_PROTO] - || headers[HTTP_HEADER_X_SCHEME] - || req.protocol || 'http'; - const host = (forwarded?.host || headers[HTTP_HEADER_HOST])?.split(':')[0]; - const port = forwarded?.host?.split(':')[1] - || headers[HTTP_HEADER_X_FORWARDED_PORT] - || headers[HTTP_HEADER_HOST]?.split(':')[1]; - - try { - return new URL(`${proto}://${host.toLowerCase()}${port ? `:${port}` : ''}`); - } catch (e) { - this.logger.isTraceEnabled() && this.logger.trace({e}); - return undefined; - } - }; - - const handleRequest = async (event, ctx) => { - const captor = new HttpCaptor({ - request: ctx.req, - response: ctx.res, - opts: { - limit: 4*1024, - } - }); - - let next = true; - let customLogger; - const capture = captor.capture(); - - ctx.baseUrl = getBaseUrl(ctx.req); - if (ctx.baseUrl !== undefined) { - for (const obj of this.callbacks[event]) { - if (obj.opts.host && obj.opts?.host?.toLowerCase() !== ctx.baseUrl.host) { - next = true; - continue; - } - captor.captureRequestBody = obj.opts?.logBody || false; - captor.captureResponseBody = obj.opts?.logBody || false; - try { - next = false; - await obj.callback(ctx, () => { next = true }); - if (!next) { - customLogger = obj.opts?.logger; - break; - } - } catch (e) { - this.logger.error(e.message); - this.logger.debug(e.stack); - ctx.res.statusCode = 500; - ctx.res.end(); - } - } - } else { - ctx.res.statusCode = 400; - ctx.res.end(); - } - - customLogger ??= this.logger; - setImmediate(() => { - capture.then((res) => { - if (customLogger === false) { - return; - } - const logEntry = { - operation: 'http-request', - request: res.request, - response: res.response, - client: { - ip: res.client.ip, - remote: res.client.remoteAddr, - }, - duration: res.duration, - }; - customLogger.info(logEntry); - }); - }); - return !next; - } - - const server = this.server = http.createServer(); - this._clients = new Set(); - server.on('connection', (sock) => { - this._clients.add(sock); - - sock.once('close', () => { - this._clients.delete(sock); - }); - }); - - server.on('request', async (req, res) => { - if (!await handleRequest('request', {req, res})) { - res.statusCode = 404; - res.end(); - } - }); - - server.on('upgrade', async (req, sock, head) => { - if (!await handleRequest('upgrade', {req, sock, head})) { - sock.write(`HTTP/${req.httpVersion} 404 Not found\r\n`); - sock.end(); - sock.destroy(); - } - }); - } - - setState(state) { - this.state = state; - } - - getPort() { - return this.opts.port; - } - - use(event, opts, callback) { - if (typeof opts === 'function') { - callback = opts; - opts = {}; - } - - if (this.callbacks[event] === undefined) { - throw new Error("Unknown event " + event); - } - - opts.prio ??= 2**32; - - const pos = this.callbacks[event].reduce((pos, x) => x.opts.prio <= opts.prio ? pos + 1 : pos, 0); - this.callbacks[event].splice(pos, 0, {callback, opts}) - return callback; - } - - removeHandler(event, callback) { - this.callbacks[event] = this.callbacks[event].filter(obj => obj.callback != callback); - } - - async _listen() { - const listenError = (err) => { - this.logger.error(`Failed to start http listener: ${err.message}`); - }; - this.server.once('error', listenError); - return new Promise((resolve, reject) => { - this.server.listen({port: this.opts.port}, (err) => { - if (err) { - return reject(err); - } - this.server.removeListener('error', listenError); - resolve(); - }); - }); - } - - async _destroy() { - return new Promise((resolve) => { - this.server.once('close', () => { - resolve(); - }); - this.server.close(); - this._clients.forEach((sock) => sock.destroy()); - }); - } -} - -export default HttpListener; \ No newline at end of file diff --git a/src/listener/http-listener.ts b/src/listener/http-listener.ts new file mode 100644 index 0000000..e34ea58 --- /dev/null +++ b/src/listener/http-listener.ts @@ -0,0 +1,255 @@ +import http from 'http'; +import { Logger } from '../logger.js'; +import HttpCaptor from '../utils/http-captor.js'; +import { + HTTP_HEADER_FORWARDED, + HTTP_HEADER_HOST, + HTTP_HEADER_X_FORWARDED_PORT, + HTTP_HEADER_X_FORWARDED_PROTO, + HTTP_HEADER_X_SCHEME +} from '../utils/http-headers.js'; +import { Duplex } from 'stream'; +import { ListenerBase } from './listener.js'; + +interface HttpListenerArguments { + port: number, +} + +interface HttpUseOptions { + logger?: any, + prio?: number, + logBody?: boolean, + host?: string, +} + +export type HttpRequestCallback = (ctx: HttpRequestContext, next: () => void) => Promise; +export type HttpUpgradeCallback = (ctx: HttpUpgradeContext, next: () => void) => Promise; +type HttpCallback = (ctx: HttpCallbackContext, next: () => void) => Promise; + +interface HttpCallbackOptions { + logger?: any, + prio: number, + host?: string, + logBody?: boolean, +} + +interface _HttpCallback { + callback: HttpCallback, + opts: HttpCallbackOptions, +} + +export enum HttpRequestType { + request = "request", + upgrade = "upgrade" +} + +interface HttpCallbackContext { + req: http.IncomingMessage, + baseUrl?: URL, +} + +interface HttpRequestContext extends HttpCallbackContext { + res: http.ServerResponse, +} + +interface HttpUpgradeContext extends HttpCallbackContext { + sock: Duplex, + head: Buffer, +} + +export default class HttpListener extends ListenerBase { + private logger: any; + private server: http.Server; + private callbacks: { [ key in HttpRequestType ]: Array<_HttpCallback> }; + + constructor(port: number) { + super(port); + this.logger = Logger("http-listener"); + this.callbacks = { + 'request': [], + 'upgrade': [] + }; + + const server = this.server = http.createServer(); + + server.on('request', async (req, res) => { + const [success, statusCode] = await this.handleRequest(HttpRequestType.request, {req, res}); + if (!success) { + res.statusCode = statusCode || 500; + res.end(); + } + }); + + server.on('upgrade', async (req, sock, head) => { + let [success, statusCode] = await this.handleRequest(HttpRequestType.upgrade, {req, sock, head}); + if (!success) { + statusCode ??= 500; + sock.write(`HTTP/${req.httpVersion} ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n`); + sock.end(`\r\n`); + sock.destroy(); + } + }); + } + + protected async _destroy(): Promise { + return this.close(); + } + + protected async _close(): Promise { + return new Promise((resolve) => { + this.server.once('close', () => { + this.removeHandler(HttpRequestType.request); + this.removeHandler(HttpRequestType.upgrade); + this.server.removeAllListeners(); + resolve(); + }); + this.server.close(); + this.server.closeAllConnections(); + }); + } + + private static parseForwarded(forwarded: string): any { + return Object.fromEntries(forwarded + .split(';') + .map(x => x.trim()) + .filter(x => x.length > 0) + .map(x => x.split('=') + .map(y => y.trim()) + ) + ) + } + + private getBaseUrl(req: http.IncomingMessage): URL | undefined { + const headers = req.headers || {}; + + const forwarded = HttpListener.parseForwarded(headers[HTTP_HEADER_FORWARDED] || ''); + const proto = forwarded?.proto + || headers[HTTP_HEADER_X_FORWARDED_PROTO] + || headers[HTTP_HEADER_X_SCHEME] + || 'http'; + const host = (forwarded?.host || headers[HTTP_HEADER_HOST])?.split(':')[0]; + const port = forwarded?.host?.split(':')[1] + || headers[HTTP_HEADER_X_FORWARDED_PORT] + || headers[HTTP_HEADER_HOST]?.split(':')[1]; + + try { + return new URL(`${proto}://${host.toLowerCase()}${port ? `:${port}` : ''}`); + } catch (e) { + this.logger.isTraceEnabled() && this.logger.trace({e}); + return undefined; + } + }; + + private async handleRequest(event: HttpRequestType.upgrade, ctx: HttpUpgradeContext): Promise<[boolean, number | undefined]>; + private async handleRequest(event: HttpRequestType.request, ctx: HttpRequestContext): Promise<[boolean, number | undefined]>; + private async handleRequest(event: HttpRequestType, ctx: HttpCallbackContext): Promise<[boolean, number | undefined]> { + + const captor = new HttpCaptor({ + request: ctx.req, + response: (ctx as HttpRequestContext).res, + opts: { + limit: 4*1024, + } + }); + + let statusCode: number | undefined = undefined; + let next = true; + let customLogger: any; + const capture = captor.capture(); + + ctx.baseUrl = this.getBaseUrl(ctx.req); + if (ctx.baseUrl !== undefined) { + for (const obj of this.callbacks[event]) { + if (obj.opts.host && obj.opts?.host?.toLowerCase() !== ctx.baseUrl.host) { + next = true; + continue; + } + captor.captureRequestBody = obj.opts?.logBody || false; + captor.captureResponseBody = obj.opts?.logBody || false; + try { + next = false; + await obj.callback(ctx, () => { next = true }); + if (!next) { + customLogger = obj.opts?.logger; + break; + } + } catch (e: any) { + this.logger.error(e.message); + this.logger.debug(e.stack); + statusCode = 500; + } + } + } else { + statusCode = 400; + } + + customLogger ??= this.logger; + setImmediate(() => { + capture.then((res) => { + if (customLogger === false) { + return; + } + const logEntry = { + operation: 'http-request', + request: res.request, + response: res.response, + client: { + ip: res.client.ip, + remote: res.client.remoteAddr, + }, + duration: res.duration, + }; + customLogger.info(logEntry); + }); + }); + return [!next, statusCode]; + } + + public use(event: HttpRequestType.request, callback: HttpRequestCallback): this; + public use(event: HttpRequestType.request, opts: HttpUseOptions, callback: HttpRequestCallback): this; + public use(event: HttpRequestType.upgrade, callback: HttpUpgradeCallback): this; + public use(event: HttpRequestType.upgrade, opts: HttpUseOptions, callback: HttpUpgradeCallback): this; + public use(event: HttpRequestType, opts: any, callback?: any): this { + if (typeof opts === 'function') { + callback = opts; + opts = {}; + } + + if (this.callbacks[event] === undefined) { + throw new Error("Unknown event " + event); + } + + opts.prio ??= 2**32; + + const pos = this.callbacks[event].reduce((pos, x) => x.opts.prio <= opts.prio ? pos + 1 : pos, 0); + this.callbacks[event].splice(pos, 0, {callback, opts: { + logger: opts.logger, + prio: opts.prio, + logBody: opts.logBody, + host: opts.host, + }}) + return this; + } + + public removeHandler(event: HttpRequestType.request, callback?: HttpRequestCallback): void; + public removeHandler(event: HttpRequestType.upgrade, callback?: HttpUpgradeCallback): void; + public removeHandler(event: HttpRequestType, callback?: any): void { + this.callbacks[event] = this.callbacks[event].filter(obj => callback != undefined && obj.callback != callback); + } + + protected async _listen(): Promise { + return new Promise((resolve, reject) => { + const listenError = (err: Error) => { + this.logger.error(`Failed to start http listener: ${err.message}`); + reject(); + }; + this.server.once('error', listenError); + + this.server.listen({port: this.port}, () => { + this.server.off('error', listenError); + resolve(); + }); + }); + } + +} \ No newline at end of file diff --git a/src/listener/index.js b/src/listener/index.js deleted file mode 100644 index 98c5b07..0000000 --- a/src/listener/index.js +++ /dev/null @@ -1,42 +0,0 @@ -import assert from 'assert/strict'; -import HttpListener from './http-listener.js'; - -class Listener { - - static { - this._listeners = {} - } - - static _getNewListener(method, port, state) { - switch (method) { - case 'http': - return new HttpListener({port, state}); - default: - assert.fail(`unknown listener method ${method}`); - } - } - - static acquire(listener, port, state = {}) { - const k = `${listener}-${port}`; - if (!this._listeners[k]) { - this._listeners[k] = this._getNewListener(listener, port, state); - } else { - this._listeners[k].acquire(); - } - return this._listeners[k]; - } - - static async release(listener, port) { - const k = `${listener}-${port}`; - if (!this._listeners[k]) { - return; - } - const released = await this._listeners[k].destroy(); - if (released) { - delete this._listeners[k]; - } - } - -} - -export default Listener; \ No newline at end of file diff --git a/src/listener/listener-interface.js b/src/listener/listener-interface.js deleted file mode 100644 index 3fb6e98..0000000 --- a/src/listener/listener-interface.js +++ /dev/null @@ -1,65 +0,0 @@ -import assert from 'assert/strict'; - -class ListenerInterface { - constructor() { - this._ref = 1; - } - - acquire() { - this._ref++; - } - - async _listen() { - assert.fail("_listen not implemented"); - } - - async _destroy() { - assert.fail("_destroy not implemented"); - } - - async listen() { - if (this._listening) { - return new Promise((resolve) => { resolve() }); - } - - if (this._pending) { - return new Promise((resolve, reject) => { - const pending = (_err) => { - _err ? reject(_err) : resolve(); - }; - this._pending.push(pending); - }) - } - - return new Promise(async (resolve, reject) => { - this._listening = false; - this._pending = []; - - let err = undefined; - try { - await this._listen(); - this._listening = true; - } catch (e) { - err = e; - } - - this._pending.push((_err) => { - _err ? reject(_err) : resolve(); - }); - - this._pending.map((fn) => fn(err)); - delete this._pending; - }) - } - - async destroy() { - if (--this._ref == 0) { - this._destroyed = true; - await this._destroy(); - return true; - } - return false; - } -} - -export default ListenerInterface; \ No newline at end of file diff --git a/src/listener/listener.ts b/src/listener/listener.ts new file mode 100644 index 0000000..ac346f8 --- /dev/null +++ b/src/listener/listener.ts @@ -0,0 +1,122 @@ + +type ListenerPending = (err?: Error) => void; + +export default class Listener { + private static instances: Map> = new Map(); + + public static acquire(type: { new(port:number): T}, port: number): T { + if (this.instances.has(port)) { + const instance = this.instances.get(port) as T; + instance.acquire(); + return instance; + } else { + const instance = new type(port); + this.instances.set(port, instance as ListenerBase); + return instance; + } + } + + public static async release(port: number): Promise { + const instance = this.instances.get(port) as T; + if (!instance) { + return; + } + const release = await instance["destroy"](); + if (release) { + this.instances.delete(port); + } + } +} + +export abstract class ListenerBase { + private _ref: number; + private _listen_ref: number; + private _listening: boolean; + private _pending: Array | undefined; + private _destroyed: boolean; + public readonly port: number; + + constructor(port: number) { + this.port = port; + this._ref = 1; + this._listen_ref = 0; + this._listening = false; + this._pending = undefined; + this._destroyed = false; + } + + public getPort(): number { + return this.port; + } + + public acquire(): void { + this._ref++; + } + + protected abstract _listen(): Promise; + + protected abstract _destroy(): Promise; + + protected abstract _close(): Promise; + + public async listen(): Promise { + this._listen_ref++; + if (this._listening) { + return; + } + + if (this._pending != undefined) { + return new Promise((resolve, reject) => { + const pending = (_err?: Error) => { + _err ? reject(_err) : resolve(); + }; + this._pending!.push(pending); + }) + } + + return new Promise(async (resolve, reject) => { + this._listening = false; + this._pending = []; + + let err: Error | undefined = undefined; + try { + await this._listen(); + this._listening = true; + } catch (e: any) { + err = e; + } + + this._pending.push((_err) => { + _err ? reject(_err) : resolve(); + }); + + this._pending.map((fn) => fn(err)); + this._pending = undefined; + }); + } + + public async close(): Promise { + if (!this._listening) { + return; + } + if (--this._listen_ref == 0) { + await this._close(); + this._listening = false; + } + } + + protected async destroy(): Promise { + if (this._destroyed) { + return false; + } + if (--this._ref == 0) { + this._destroyed = true; + await this._close(); + this._listen_ref = 0; + this._listening = false; + await this._destroy(); + return true; + } + return false; + } +} \ No newline at end of file diff --git a/src/transport/transport-service.ts b/src/transport/transport-service.ts index 58556b6..b7ed1f6 100644 --- a/src/transport/transport-service.ts +++ b/src/transport/transport-service.ts @@ -92,9 +92,9 @@ class TransportService { } public getTransports(tunnel: Tunnel, baseUrl: string): TunnelTransports; - public getTransports(tunnel: Tunnel, baseUrl: URL): TunnelTransports; + public getTransports(tunnel: Tunnel, baseUrl: URL | undefined): TunnelTransports; public getTransports(tunnel: Tunnel, baseUrl: any): TunnelTransports { - let _baseUrl: URL; + let _baseUrl: URL | undefined; const transports: TunnelTransports = { max_connections: this.max_connections, @@ -106,12 +106,16 @@ class TransportService { try { _baseUrl = new URL(baseUrl); } catch (e: any) { - return transports; + _baseUrl = undefined; } } else { _baseUrl = baseUrl; } + if (_baseUrl == undefined) { + return transports; + } + if (this.transports.ws instanceof WebSocketEndpoint) { transports.ws = { enabled: tunnel.config.transport?.ws?.enabled || false, diff --git a/src/transport/ws/ws-endpoint.ts b/src/transport/ws/ws-endpoint.ts index a0947d9..e570727 100644 --- a/src/transport/ws/ws-endpoint.ts +++ b/src/transport/ws/ws-endpoint.ts @@ -1,7 +1,7 @@ import net from 'net'; import querystring from 'querystring'; import { WebSocket, WebSocketServer } from 'ws'; -import Listener from '../../listener/index.js'; +import Listener from '../../listener/listener.js'; import { Logger } from '../../logger.js'; import TunnelService from '../../tunnel/tunnel-service.js'; import { @@ -9,10 +9,11 @@ import { } from '../../utils/errors.js'; import WebSocketTransport from './ws-transport.js'; import TransportEndpoint, { EndpointResult, TransportEndpointOptions } from '../transport-endpoint.js'; -import HttpListener from '../../listener/http-listener.js'; +import HttpListener, { HttpRequestType, HttpUpgradeCallback } from '../../listener/http-listener.js'; import Tunnel from '../../tunnel/tunnel.js'; import { URL } from 'url'; import { IncomingMessage } from 'http'; +import { Duplex } from 'stream'; export type WebSocketEndpointOptions = { enabled: boolean, @@ -48,23 +49,25 @@ export default class WebSocketEndpoint extends TransportEndpoint { private httpListener: HttpListener; private tunnelService: TunnelService; private wss: WebSocketServer; - private _upgradeHandler: any; + private _upgradeHandler: HttpUpgradeCallback; private connections: Array; constructor(opts: _WebSocketEndpointOptions) { super(opts); this.opts = opts; this.logger = Logger("ws-endpoint"); - this.httpListener = Listener.acquire('http', opts.port); + this.httpListener = Listener.acquire(HttpListener, opts.port); this.tunnelService = new TunnelService(); this.wss = new WebSocketServer({ noServer: true }); this.connections = []; - this._upgradeHandler = this.httpListener.use('upgrade', { logger: this.logger }, async (ctx: any, next: any) => { + this._upgradeHandler = async (ctx, next) => { if (!await this.handleUpgrade(ctx.req, ctx.sock, ctx.head)) { return next(); } - }); + }; + + this.httpListener.use(HttpRequestType.upgrade, { logger: this.logger }, this._upgradeHandler); this.httpListener.listen() .then(() => { @@ -92,7 +95,7 @@ export default class WebSocketEndpoint extends TransportEndpoint { } protected async _destroy(): Promise { - this.httpListener.removeHandler('upgrade', this._upgradeHandler); + this.httpListener.removeHandler(HttpRequestType.upgrade, this._upgradeHandler); for (const connection of this.connections) { const {wst, ws} = connection; await wst.destroy(); @@ -102,7 +105,7 @@ export default class WebSocketEndpoint extends TransportEndpoint { this.wss.close(); await Promise.allSettled([ this.tunnelService.destroy(), - Listener.release('http', this.opts.port), + Listener.release(this.opts.port), ]); } @@ -135,7 +138,7 @@ export default class WebSocketEndpoint extends TransportEndpoint { }; } - private _unauthorized(sock: net.Socket, request: IncomingMessage): RawHttpResponse { + private _unauthorized(sock: Duplex, request: IncomingMessage): RawHttpResponse { const response = { status: 401, statusLine: 'Unauthorized' @@ -143,7 +146,7 @@ export default class WebSocketEndpoint extends TransportEndpoint { return this._rawHttpResponse(sock, request, response); }; - private _rawHttpResponse(sock: net.Socket, request: IncomingMessage, response: RawHttpResponse): RawHttpResponse { + private _rawHttpResponse(sock: Duplex, request: IncomingMessage, response: RawHttpResponse): RawHttpResponse { sock.write(`HTTP/${request.httpVersion} ${response.status} ${response.statusLine}\r\n`); sock.write('\r\n'); response.body && sock.write(response.body); @@ -151,12 +154,7 @@ export default class WebSocketEndpoint extends TransportEndpoint { return response; } - async handleUpgrade(req: IncomingMessage, sock: net.Socket, head: Buffer) { - - //if (req.upgrade !== true) { - // this.logger.trace("upgrade called on non-upgrade request"); - // return undefined; - //} + async handleUpgrade(req: IncomingMessage, sock: Duplex, head: Buffer) { const parsed = this._parseRequest(req); if (parsed == undefined) { diff --git a/test/unit/listener/test_http-listener.ts b/test/unit/listener/test_http-listener.ts new file mode 100644 index 0000000..62fd40d --- /dev/null +++ b/test/unit/listener/test_http-listener.ts @@ -0,0 +1,255 @@ +import assert from "assert"; +import HttpListener, { HttpRequestCallback, HttpRequestType, HttpUpgradeCallback } from "../../../src/listener/http-listener.js"; +import Listener from "../../../src/listener/listener.js"; +import http from 'http'; +import { setTimeout } from "timers/promises"; +import { Socket } from "net"; + +describe('HTTP listener', () => { + + it(`can listen on port`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + + const requestHandler: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 201; + ctx.res.end('foo') + }; + + httpListener.use(HttpRequestType.request, requestHandler); + assert(httpListener["callbacks"]["request"].length == 1); + + await httpListener.listen() + + let res = await fetch("http://localhost:8080"); + assert(res.status == 201); + + let data = await res.text(); + assert(data == 'foo'); + + httpListener.removeHandler(HttpRequestType.request, requestHandler); + assert(httpListener["callbacks"]["request"].length == 0); + + + await Listener.release(8080); + assert(httpListener["_destroyed"] == true); + + try { + await fetch("http://localhost:8080"); + assert(false, "listener is still listening"); + } catch (e:any) { + assert(true); + } + }); + + it(`destroy removes installed handlers`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + + const requestHandler: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 201; + ctx.res.end('foo') + }; + + httpListener.use(HttpRequestType.request, requestHandler); + await httpListener.listen() + + await httpListener.close() + + await Listener.release(8080); + + assert(httpListener["_destroyed"] == true); + assert(httpListener["callbacks"]["request"].length == 0, "handler still installed"); + }); + + it(`listener can be acquired multiple times`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + const httpListener2 = Listener.acquire(HttpListener, 8080); + assert(httpListener == httpListener2) + + const requestHandler: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 201; + ctx.res.end('foo') + }; + httpListener.use(HttpRequestType.request, requestHandler); + await httpListener.listen() + + let res = await fetch("http://localhost:8080"); + assert(res.status == 201); + let data = await res.text(); + assert(data == 'foo'); + + const requestHandler2: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 200; + ctx.res.end('bar') + }; + httpListener2.use(HttpRequestType.request, requestHandler2); + await httpListener2.listen() + + res = await fetch("http://localhost:8080"); + assert(res.status == 201); + data = await res.text(); + assert(data == 'foo', `got ${data}`); + + httpListener.removeHandler(HttpRequestType.request, requestHandler); + await httpListener.close(); + + res = await fetch("http://localhost:8080"); + assert(res.status == 200); + data = await res.text(); + assert(data == 'bar'); + + await Listener.release(8080); + assert(httpListener["_destroyed"] == false); + await Listener.release(8080); + assert(httpListener["_destroyed"] == true); + assert(httpListener["callbacks"]["request"].length == 0, "handler still installed"); + }); + + it(`callback can pass request to next handler`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + const httpListener2 = Listener.acquire(HttpListener, 8080); + assert(httpListener == httpListener2) + + await Promise.all([httpListener.listen(), httpListener2.listen()]); + + const requestHandler: HttpRequestCallback = async (ctx, next): Promise => { + next(); + }; + httpListener.use(HttpRequestType.request, requestHandler); + + const requestHandler2: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 200; + ctx.res.end('bar') + }; + httpListener2.use(HttpRequestType.request, requestHandler2); + + let res = await fetch("http://localhost:8080"); + assert(res.status == 200); + let data = await res.text(); + assert(data == 'bar'); + + await Listener.release(8080); + assert(httpListener["_destroyed"] == false); + await Listener.release(8080); + assert(httpListener["_destroyed"] == true); + }); + + it(`listener on different ports return different instances`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + const httpListener2 = Listener.acquire(HttpListener, 9090); + assert(httpListener != httpListener2) + + await Promise.all([httpListener.listen(), httpListener2.listen()]); + + const requestHandler: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 201; + ctx.res.end('foo') + }; + httpListener.use(HttpRequestType.request, requestHandler); + + const requestHandler2: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 200; + ctx.res.end('bar') + }; + httpListener2.use(HttpRequestType.request, requestHandler2); + + let res = await fetch("http://localhost:8080"); + assert(res.status == 201); + let data = await res.text(); + assert(data == 'foo', `got ${data}`); + + httpListener.removeHandler(HttpRequestType.request, requestHandler); + await httpListener.close(); + + res = await fetch("http://localhost:9090"); + assert(res.status == 200); + data = await res.text(); + assert(data == 'bar'); + + await Listener.release(8080); + await Listener.release(9090); + }); + + it(`callback handlers can be added with different priorities`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + const httpListener2 = Listener.acquire(HttpListener, 8080); + assert(httpListener == httpListener2) + + await Promise.all([httpListener.listen(), httpListener2.listen()]); + + const requestHandler: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 201; + ctx.res.end('foo') + }; + httpListener.use(HttpRequestType.request, requestHandler); + + const requestHandler2: HttpRequestCallback = async (ctx, next): Promise => { + ctx.res.statusCode = 200; + ctx.res.end('bar') + }; + httpListener2.use(HttpRequestType.request, {prio: 1}, requestHandler2); + + let res = await fetch("http://localhost:8080"); + assert(res.status == 200); + let data = await res.text(); + assert(data == 'bar'); + + await Listener.release(8080); + assert(httpListener["_destroyed"] == false); + await Listener.release(8080); + assert(httpListener["_destroyed"] == true); + }); + + it(`can install an upgrade handler`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + await httpListener.listen(); + + const upgradeHandler: HttpUpgradeCallback = async (ctx, next): Promise => { + ctx.sock.write(`HTTP/${ctx.req.httpVersion} 101 ${http.STATUS_CODES[101]}\r\n`); + ctx.sock.write('Upgrade: someprotocol\r\n'); + ctx.sock.write('Connection: Upgrade\r\n'); + ctx.sock.write('\r\n'); + + ctx.sock.write("upgraded"); + ctx.sock.end(); + }; + httpListener.use(HttpRequestType.upgrade, upgradeHandler); + + const req = http.request({ + hostname: 'localhost', + port: 8080, + method: 'GET', + path: '/', + headers: { + "Host": "localhost", + "Connection": 'Upgrade', + "Upgrade": 'someprotocol', + "Origin": `http://localhost`, + } + }); + + const done = (resolve: (value: any) => void) => { + req.on('upgrade', (res, socket, head) => { + resolve(head.toString()); + }); + }; + req.end(); + + let data = await new Promise(done); + assert(data == 'upgraded'); + + httpListener.removeHandler(HttpRequestType.upgrade, upgradeHandler); + await httpListener.close(); + await Listener.release(8080); + }); + + it(`request without a handler returns 500`, async () => { + const httpListener = Listener.acquire(HttpListener, 8080); + await httpListener.listen(); + + let res = await fetch("http://localhost:8080"); + assert(res.status == 500); + + await httpListener.close(); + await Listener.release(8080); + }); +}); \ No newline at end of file