From b22fb914cf5da29cb1454f1a631c4f6c5e195b92 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 28 Nov 2024 08:47:22 +0000 Subject: [PATCH] Record what reserved bits are valid Throw protocol error on receiving a frame with a reserved bit set that isnt used. --- .../PerMessageDeflateExtension.swift | 3 +++ .../WSCore/Extensions/WebSocketExtension.swift | 7 +++++++ Sources/WSCore/WebSocketHandler.swift | 15 +++++++++++++++ Sources/WSCore/WebSocketInboundStream.swift | 3 +++ 4 files changed, 28 insertions(+) diff --git a/Sources/WSCompression/PerMessageDeflateExtension.swift b/Sources/WSCompression/PerMessageDeflateExtension.swift index 1d613dd..8f93304 100644 --- a/Sources/WSCompression/PerMessageDeflateExtension.swift +++ b/Sources/WSCompression/PerMessageDeflateExtension.swift @@ -296,6 +296,9 @@ struct PerMessageDeflateExtension: WebSocketExtension { } return frame } + + /// Reserved bits extension uses + var reservedBits: WebSocketFrame.ReservedBits { .rsv1 } } extension WebSocketExtensionFactory { diff --git a/Sources/WSCore/Extensions/WebSocketExtension.swift b/Sources/WSCore/Extensions/WebSocketExtension.swift index c333df0..98c4cc2 100644 --- a/Sources/WSCore/Extensions/WebSocketExtension.swift +++ b/Sources/WSCore/Extensions/WebSocketExtension.swift @@ -35,6 +35,13 @@ public protocol WebSocketExtension: Sendable { func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame /// Process frame about to be sent to websocket func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame + /// Reserved bits extension uses + var reservedBits: WebSocketFrame.ReservedBits { get } /// shutdown extension func shutdown() async } + +extension WebSocketExtension { + /// Reserved bits extension uses (default is none) + public var reservedBits: WebSocketFrame.ReservedBits { .init() } +} diff --git a/Sources/WSCore/WebSocketHandler.swift b/Sources/WSCore/WebSocketHandler.swift index c47785e..9668e11 100644 --- a/Sources/WSCore/WebSocketHandler.swift +++ b/Sources/WSCore/WebSocketHandler.swift @@ -71,11 +71,16 @@ public struct WebSocketCloseFrame: Sendable { let extensions: [any WebSocketExtension] let autoPing: AutoPingSetup let validateUTF8: Bool + let reservedBits: WebSocketFrame.ReservedBits @_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) { self.extensions = extensions self.autoPing = autoPing self.validateUTF8 = validateUTF8 + // store reserved bits used by this handler + self.reservedBits = extensions.reduce(.init()) { partialResult, `extension` in + partialResult.union(`extension`.reservedBits) + } } } @@ -301,6 +306,16 @@ public struct WebSocketCloseFrame: Sendable { } func receivedClose(_ frame: WebSocketFrame) async throws { + guard frame.reservedBits.isEmpty else { + try await self.sendClose(code: .protocolError, reason: nil) + // Only server should initiate a connection close. Clients should wait for the + // server to close the connection when it receives the WebSocket close packet + // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 + if self.type == .server { + self.outbound.finish() + } + return + } switch self.stateMachine.receivedClose(frameData: frame.unmaskedData, validateUTF8: self.configuration.validateUTF8) { case .sendClose(let errorCode): try await self.sendClose(code: errorCode, reason: nil) diff --git a/Sources/WSCore/WebSocketInboundStream.swift b/Sources/WSCore/WebSocketInboundStream.swift index ad34519..40aa3c8 100644 --- a/Sources/WSCore/WebSocketInboundStream.swift +++ b/Sources/WSCore/WebSocketInboundStream.swift @@ -67,6 +67,9 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable { case .pong: try await self.handler.onPong(frame) case .text, .binary, .continuation: + guard self.handler.configuration.reservedBits.contains(frame.reservedBits) else { + throw WebSocketHandler.InternalError.close(.protocolError) + } // apply extensions var frame = frame for ext in self.handler.configuration.extensions.reversed() {