From 15d83f860c840531f55ded26bcf7cad142449882 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 27 Nov 2024 16:30:05 +0000 Subject: [PATCH] run swift-format --- Package.swift | 65 ++++++++----- .../WSClient/Client/ClientConnection.swift | 20 ++-- Sources/WSClient/Client/Parser.swift | 51 ++++++---- Sources/WSClient/Client/TSTLSOptions.swift | 13 ++- Sources/WSClient/Client/URI.swift | 24 ++--- Sources/WSClient/WebSocketClientChannel.swift | 10 +- .../PerMessageDeflateExtension.swift | 47 ++++----- .../WebSocketExtensionBuilder.swift | 4 +- Sources/WSCore/String+validatingString.swift | 2 +- Sources/WSCore/WebSocketDataFrame.swift | 4 +- Sources/WSCore/WebSocketHandler.swift | 18 +++- Sources/WSCore/WebSocketStateMachine.swift | 32 ++++--- Tests/WebSocketTests/AutobahnTests.swift | 95 +++++++++++++++---- .../WebSocketExtensionNegotiationTests.swift | 23 +++-- .../WebSocketStateMachineTests.swift | 63 +++++++++--- 15 files changed, 309 insertions(+), 162 deletions(-) diff --git a/Package.swift b/Package.swift index 3ffb649..73c1fe9 100644 --- a/Package.swift +++ b/Package.swift @@ -24,32 +24,47 @@ let package = Package( .package(url: "https://github.com/swift-server/swift-service-lifecycle", from: "2.0.0"), ], targets: [ - .target(name: "WSClient", dependencies: [ - .byName(name: "WSCore"), - .product(name: "HTTPTypes", package: "swift-http-types"), - .product(name: "Logging", package: "swift-log"), - .product(name: "NIOCore", package: "swift-nio"), - .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), - .product(name: "NIOPosix", package: "swift-nio"), - .product(name: "NIOSSL", package: "swift-nio-ssl"), - .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), - .product(name: "NIOWebSocket", package: "swift-nio"), - ], swiftSettings: swiftSettings), - .target(name: "WSCore", dependencies: [ - .product(name: "HTTPTypes", package: "swift-http-types"), - .product(name: "NIOCore", package: "swift-nio"), - .product(name: "NIOWebSocket", package: "swift-nio"), - .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), - ], swiftSettings: swiftSettings), - .target(name: "WSCompression", dependencies: [ - .byName(name: "WSCore"), - .product(name: "CompressNIO", package: "compress-nio"), - ], swiftSettings: swiftSettings), + .target( + name: "WSClient", + dependencies: [ + .byName(name: "WSCore"), + .product(name: "HTTPTypes", package: "swift-http-types"), + .product(name: "Logging", package: "swift-log"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), + .product(name: "NIOWebSocket", package: "swift-nio"), + ], + swiftSettings: swiftSettings + ), + .target( + name: "WSCore", + dependencies: [ + .product(name: "HTTPTypes", package: "swift-http-types"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), + ], + swiftSettings: swiftSettings + ), + .target( + name: "WSCompression", + dependencies: [ + .byName(name: "WSCore"), + .product(name: "CompressNIO", package: "compress-nio"), + ], + swiftSettings: swiftSettings + ), - .testTarget(name: "WebSocketTests", dependencies: [ - .byName(name: "WSClient"), - .byName(name: "WSCompression"), - ]), + .testTarget( + name: "WebSocketTests", + dependencies: [ + .byName(name: "WSClient"), + .byName(name: "WSCompression"), + ] + ), ], swiftLanguageVersions: [.v5, .version("6")] ) diff --git a/Sources/WSClient/Client/ClientConnection.swift b/Sources/WSClient/Client/ClientConnection.swift index fb373a3..6fd623f 100644 --- a/Sources/WSClient/Client/ClientConnection.swift +++ b/Sources/WSClient/Client/ClientConnection.swift @@ -15,11 +15,12 @@ import Logging import NIOCore import NIOPosix +import NIOWebSocket + #if canImport(Network) import Network import NIOTransportServices #endif -import NIOWebSocket /// A generic client connection to a server. /// @@ -104,7 +105,9 @@ public struct ClientConnection: Sendable bootstrap = tsBootstrap } else { #if os(iOS) || os(tvOS) - self.logger.warning("Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework") + self.logger.warning( + "Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework" + ) #endif bootstrap = self.createSocketsBootstrap() } @@ -117,13 +120,15 @@ public struct ClientConnection: Sendable do { switch address.value { case .hostname(let host, let port): - result = try await bootstrap + result = + try await bootstrap .connect(host: host, port: port) { channel in clientChannel.setup(channel: channel, logger: self.logger) } self.logger.debug("Client connnected to \(host):\(port)") case .unixDomainSocket(let path): - result = try await bootstrap + result = + try await bootstrap .connect(unixDomainSocketPath: path) { channel in clientChannel.setup(channel: channel, logger: self.logger) } @@ -137,15 +142,16 @@ public struct ClientConnection: Sendable /// create a BSD sockets based bootstrap private func createSocketsBootstrap() -> ClientBootstrap { - return ClientBootstrap(group: self.eventLoopGroup) + ClientBootstrap(group: self.eventLoopGroup) .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) } #if canImport(Network) /// create a NIOTransportServices bootstrap using Network.framework private func createTSBootstrap() -> NIOTSConnectionBootstrap? { - guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup)? - .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) + guard + let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup)? + .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) else { return nil } diff --git a/Sources/WSClient/Client/Parser.swift b/Sources/WSClient/Client/Parser.swift index 206be80..1741753 100644 --- a/Sources/WSClient/Client/Parser.swift +++ b/Sources/WSClient/Client/Parser.swift @@ -51,12 +51,12 @@ struct Parser: Sendable { /// Return contents of parser as a string var count: Int { - return self.range.count + self.range.count } /// Return contents of parser as a string var string: String { - return makeString(self.buffer[self.range]) + makeString(self.buffer[self.range]) } private var buffer: [UInt8] @@ -74,12 +74,12 @@ extension Parser { self.range = range precondition(range.startIndex >= 0 && range.endIndex <= self.buffer.endIndex) - precondition(range.startIndex == self.buffer.endIndex || self.buffer[range.startIndex] & 0xC0 != 0x80) // check we arent in the middle of a UTF8 character + precondition(range.startIndex == self.buffer.endIndex || self.buffer[range.startIndex] & 0xC0 != 0x80) // check we arent in the middle of a UTF8 character } /// initialise a parser that parses a section of the buffer attached to this parser func subParser(_ range: Range) -> Parser { - return Parser(self, range: range) + Parser(self, range: range) } } @@ -99,7 +99,10 @@ extension Parser { mutating func read(_ char: Unicode.Scalar) throws -> Bool { let initialIndex = self.index let c = try character() - guard c == char else { self.index = initialIndex; return false } + guard c == char else { + self.index = initialIndex + return false + } return true } @@ -110,7 +113,10 @@ extension Parser { mutating func read(_ characterSet: Set) throws -> Bool { let initialIndex = self.index let c = try character() - guard characterSet.contains(c) else { self.index = initialIndex; return false } + guard characterSet.contains(c) else { + self.index = initialIndex + return false + } return true } @@ -122,7 +128,10 @@ extension Parser { let initialIndex = self.index guard string.count > 0 else { throw Error.emptyString } let subString = try read(count: string.count) - guard subString.string == string else { self.index = initialIndex; return false } + guard subString.string == string else { + self.index = initialIndex + return false + } return true } @@ -273,7 +282,7 @@ extension Parser { @discardableResult mutating func read(while: Unicode.Scalar) -> Int { var count = 0 while !self.reachedEnd(), - unsafeCurrent() == `while` + unsafeCurrent() == `while` { unsafeAdvance() count += 1 @@ -287,7 +296,7 @@ extension Parser { @discardableResult mutating func read(while characterSet: Set) -> Parser { let startIndex = self.index while !self.reachedEnd(), - characterSet.contains(unsafeCurrent()) + characterSet.contains(unsafeCurrent()) { unsafeAdvance() } @@ -300,7 +309,7 @@ extension Parser { @discardableResult mutating func read(while: (Unicode.Scalar) -> Bool) -> Parser { let startIndex = self.index while !self.reachedEnd(), - `while`(unsafeCurrent()) + `while`(unsafeCurrent()) { unsafeAdvance() } @@ -313,7 +322,7 @@ extension Parser { @discardableResult mutating func read(while keyPath: KeyPath) -> Parser { let startIndex = self.index while !self.reachedEnd(), - unsafeCurrent()[keyPath: keyPath] + unsafeCurrent()[keyPath: keyPath] { unsafeAdvance() } @@ -342,7 +351,7 @@ extension Parser { /// Return whether we have reached the end of the buffer /// - Returns: Have we reached the end func reachedEnd() -> Bool { - return self.index == self.range.endIndex + self.index == self.range.endIndex } } @@ -422,7 +431,7 @@ extension Parser: Sequence { public typealias Element = Unicode.Scalar public func makeIterator() -> Iterator { - return Iterator(self) + Iterator(self) } public struct Iterator: IteratorProtocol { @@ -442,22 +451,22 @@ extension Parser: Sequence { } // internal versions without checks -private extension Parser { - func unsafeCurrent() -> Unicode.Scalar { - return decodeUTF8Character(at: self.index).0 +extension Parser { + fileprivate func unsafeCurrent() -> Unicode.Scalar { + decodeUTF8Character(at: self.index).0 } - mutating func unsafeCurrentAndAdvance() -> Unicode.Scalar { + fileprivate mutating func unsafeCurrentAndAdvance() -> Unicode.Scalar { let (unicodeScalar, index) = decodeUTF8Character(at: self.index) self.index = index return unicodeScalar } - mutating func _setPosition(_ index: Int) { + fileprivate mutating func _setPosition(_ index: Int) { self.index = index } - func makeString(_ bytes: Bytes) -> String where Bytes.Element == UInt8, Bytes.Index == Int { + fileprivate func makeString(_ bytes: Bytes) -> String where Bytes.Element == UInt8, Bytes.Index == Int { if let string = bytes.withContiguousStorageIfAvailable({ String(decoding: $0, as: Unicode.UTF8.self) }) { return string } else { @@ -624,7 +633,7 @@ extension Parser { do { if #available(macOS 11, macCatalyst 14.0, iOS 14.0, tvOS 14.0, *) { return try String(unsafeUninitializedCapacity: range.endIndex - index) { bytes -> Int in - return try _percentDecode(self.buffer[self.index.. Self { - return .init(secIdentity: secIdentity) + .init(secIdentity: secIdentity) } public static func p12(filename: String, password: String) throws -> Self { @@ -101,7 +101,7 @@ public struct TSTLSOptions: Sendable { /// TSTLSOptions holding options public static func options(_ options: NWProtocolTLS.Options) -> Self { - return .init(value: .some(options)) + .init(value: .some(options)) } public static func options( @@ -117,7 +117,9 @@ public struct TSTLSOptions: Sendable { } public static func options( - clientIdentity: Identity, trustRoots: Certificates = .none, serverName: String? = nil + clientIdentity: Identity, + trustRoots: Certificates = .none, + serverName: String? = nil ) -> Self? { let options = NWProtocolTLS.Options() @@ -143,7 +145,8 @@ public struct TSTLSOptions: Sendable { } sec_protocol_verify_complete(result) } - }, Self.tlsDispatchQueue + }, + Self.tlsDispatchQueue ) } return .init(value: .some(options)) @@ -151,7 +154,7 @@ public struct TSTLSOptions: Sendable { /// Empty TSTLSOptions public static var none: Self { - return .init(value: .none) + .init(value: .none) } var options: NWProtocolTLS.Options? { diff --git a/Sources/WSClient/Client/URI.swift b/Sources/WSClient/Client/URI.swift index bd816bd..8076728 100644 --- a/Sources/WSClient/Client/URI.swift +++ b/Sources/WSClient/Client/URI.swift @@ -21,27 +21,27 @@ struct URI: Sendable, CustomStringConvertible, ExpressibleByStringLiteral { self.rawValue = rawValue } - static var http: Self { return .init(rawValue: "http") } - static var https: Self { return .init(rawValue: "https") } - static var unix: Self { return .init(rawValue: "unix") } - static var http_unix: Self { return .init(rawValue: "http_unix") } - static var https_unix: Self { return .init(rawValue: "https_unix") } - static var ws: Self { return .init(rawValue: "ws") } - static var wss: Self { return .init(rawValue: "wss") } + static var http: Self { .init(rawValue: "http") } + static var https: Self { .init(rawValue: "https") } + static var unix: Self { .init(rawValue: "unix") } + static var http_unix: Self { .init(rawValue: "http_unix") } + static var https_unix: Self { .init(rawValue: "https_unix") } + static var ws: Self { .init(rawValue: "ws") } + static var wss: Self { .init(rawValue: "wss") } } let string: String /// URL scheme - var scheme: Scheme? { return self._scheme.map { .init(rawValue: $0.string) } } + var scheme: Scheme? { self._scheme.map { .init(rawValue: $0.string) } } /// URL host - var host: String? { return self._host.map(\.string) } + var host: String? { self._host.map(\.string) } /// URL port - var port: Int? { return self._port.map { Int($0.string) } ?? nil } + var port: Int? { self._port.map { Int($0.string) } ?? nil } /// URL path - var path: String { return self._path.map(\.string) ?? "/" } + var path: String { self._path.map(\.string) ?? "/" } /// URL query - var query: String? { return self._query.map { String($0.string) }} + var query: String? { self._query.map { String($0.string) } } private let _scheme: Parser? private let _host: Parser? diff --git a/Sources/WSClient/WebSocketClientChannel.swift b/Sources/WSClient/WebSocketClientChannel.swift index c53d700..0d5de1c 100644 --- a/Sources/WSClient/WebSocketClientChannel.swift +++ b/Sources/WSClient/WebSocketClientChannel.swift @@ -68,10 +68,12 @@ struct WebSocketClientChannel: ClientConnectionChannel { let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) // add websocket extensions to headers - headers.add(contentsOf: self.configuration.extensions.compactMap { - let requestHeaders = $0.clientRequestHeader() - return requestHeaders != "" ? ("Sec-WebSocket-Extensions", requestHeaders) : nil - }) + headers.add( + contentsOf: self.configuration.extensions.compactMap { + let requestHeaders = $0.clientRequestHeader() + return requestHeaders != "" ? ("Sec-WebSocket-Extensions", requestHeaders) : nil + } + ) let requestHead = HTTPRequestHead( version: .http1_1, diff --git a/Sources/WSCompression/PerMessageDeflateExtension.swift b/Sources/WSCompression/PerMessageDeflateExtension.swift index f3a2d82..1d613dd 100644 --- a/Sources/WSCompression/PerMessageDeflateExtension.swift +++ b/Sources/WSCompression/PerMessageDeflateExtension.swift @@ -105,16 +105,18 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { let clientNoContextTakeoverParam = response.parameters["client_no_context_takeover"] != nil let serverMaxWindowParam = response.parameters["server_max_window_bits"]?.integer let serverNoContextTakeoverParam = response.parameters["server_no_context_takeover"] != nil - return try PerMessageDeflateExtension(configuration: .init( - receiveMaxWindow: serverMaxWindowParam, - receiveNoContextTakeover: serverNoContextTakeoverParam, - sendMaxWindow: clientMaxWindowParam, - sendNoContextTakeover: clientNoContextTakeoverParam, - compressionLevel: self.compressionLevel, - memoryLevel: self.memoryLevel, - maxDecompressedFrameSize: self.maxDecompressedFrameSize, - minFrameSizeToCompress: self.minFrameSizeToCompress - )) + return try PerMessageDeflateExtension( + configuration: .init( + receiveMaxWindow: serverMaxWindowParam, + receiveNoContextTakeover: serverNoContextTakeoverParam, + sendMaxWindow: clientMaxWindowParam, + sendNoContextTakeover: clientNoContextTakeoverParam, + compressionLevel: self.compressionLevel, + memoryLevel: self.memoryLevel, + maxDecompressedFrameSize: self.maxDecompressedFrameSize, + minFrameSizeToCompress: self.minFrameSizeToCompress + ) + ) } private func responseConfiguration(to request: WebSocketExtensionHTTPParameters) -> PerMessageDeflateExtension.Configuration { @@ -123,16 +125,15 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { let requestClientMaxWindow = request.parameters["client_max_window_bits"] let requestClientNoContextTakeover = request.parameters["client_no_context_takeover"] != nil - let receiveMaxWindow: Int? - // calculate client max window. If parameter doesn't exist then server cannot set it, if it does - // exist then the value should be set to minimum of both values, or the value of the other if - // one is nil - = if let requestClientMaxWindow - { - optionalMin(requestClientMaxWindow.integer, self.clientMaxWindow) - } else { - nil - } + // calculate client max window. If parameter doesn't exist then server cannot set it, if it does + // exist then the value should be set to minimum of both values, or the value of the other if + // one is nil + let receiveMaxWindow: Int? = + if let requestClientMaxWindow { + optionalMin(requestClientMaxWindow.integer, self.clientMaxWindow) + } else { + nil + } return PerMessageDeflateExtension.Configuration( receiveMaxWindow: receiveMaxWindow, @@ -280,7 +281,7 @@ struct PerMessageDeflateExtension: WebSocketExtension { func shutdown() async {} func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { - return try await self.decompressor.decompress( + try await self.decompressor.decompress( frame, maxSize: self.configuration.maxDecompressedFrameSize, resetStream: self.configuration.receiveNoContextTakeover, @@ -310,7 +311,7 @@ extension WebSocketExtensionFactory { maxDecompressedFrameSize: Int = 1 << 14, minFrameSizeToCompress: Int = 256 ) -> WebSocketExtensionFactory { - return .init { + .init { PerMessageDeflateExtensionBuilder( clientMaxWindow: maxWindow, clientNoContextTakeover: noContextTakeover, @@ -346,7 +347,7 @@ extension WebSocketExtensionFactory { maxDecompressedFrameSize: Int = 1 << 14, minFrameSizeToCompress: Int = 256 ) -> WebSocketExtensionFactory { - return .init { + .init { PerMessageDeflateExtensionBuilder( clientMaxWindow: clientMaxWindow, clientNoContextTakeover: clientNoContextTakeover, diff --git a/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift b/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift index 1cb6c9e..cf41e40 100644 --- a/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift +++ b/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift @@ -100,7 +100,7 @@ extension WebSocketNonNegotiableExtensionBuilder { public func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() } } -extension Array { +extension [any WebSocketExtensionBuilder] { /// Build client extensions from response from WebSocket server /// - Parameter responseHeaders: Server response headers /// - Returns: Array of client extensions to enable @@ -144,7 +144,7 @@ public struct WebSocketExtensionFactory: Sendable { /// - Parameter build: closure creating extension /// - Returns: WebSocketExtensionFactory public static func nonNegotiatedExtension(_ build: @escaping @Sendable () -> some WebSocketExtension) -> Self { - return .init { + .init { WebSocketNonNegotiableExtensionBuilder(build) } } diff --git a/Sources/WSCore/String+validatingString.swift b/Sources/WSCore/String+validatingString.swift index 15f2677..3b811b4 100644 --- a/Sources/WSCore/String+validatingString.swift +++ b/Sources/WSCore/String+validatingString.swift @@ -29,6 +29,6 @@ extension String { } #else self = .init(buffer: buffer) - #endif // compiler(>=6) + #endif // compiler(>=6) } } diff --git a/Sources/WSCore/WebSocketDataFrame.swift b/Sources/WSCore/WebSocketDataFrame.swift index 5d85db7..4445f02 100644 --- a/Sources/WSCore/WebSocketDataFrame.swift +++ b/Sources/WSCore/WebSocketDataFrame.swift @@ -39,10 +39,10 @@ public struct WebSocketDataFrame: Equatable, Sendable, CustomStringConvertible, } public var description: String { - return "\(self.opcode): \(self.data.description), finished: \(self.fin)" + "\(self.opcode): \(self.data.description), finished: \(self.fin)" } public var debugDescription: String { - return "\(self.opcode): \(self.data.debugDescription), finished: \(self.fin)" + "\(self.opcode): \(self.data.debugDescription), finished: \(self.fin)" } } diff --git a/Sources/WSCore/WebSocketHandler.swift b/Sources/WSCore/WebSocketHandler.swift index b480fe5..c47785e 100644 --- a/Sources/WSCore/WebSocketHandler.swift +++ b/Sources/WSCore/WebSocketHandler.swift @@ -115,7 +115,13 @@ public struct WebSocketCloseFrame: Sendable { let rt = try await asyncChannel.executeThenClose { inbound, outbound in try await withTaskCancellationHandler { try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in - let webSocketHandler = Self(channel: asyncChannel.channel, outbound: outbound, type: type, configuration: configuration, context: context) + let webSocketHandler = Self( + channel: asyncChannel.channel, + outbound: outbound, + type: type, + configuration: configuration, + context: context + ) if case .enabled = configuration.autoPing.value { /// Add task sending ping frames every so often and verifying a pong frame was sent back group.addTask { @@ -123,7 +129,13 @@ public struct WebSocketCloseFrame: Sendable { return .init(closeCode: .goingAway, reason: "Ping timeout") } } - let rt = try await webSocketHandler.handle(type: type, inbound: inbound, outbound: outbound, handler: handler, context: context) + let rt = try await webSocketHandler.handle( + type: type, + inbound: inbound, + outbound: outbound, + handler: handler, + context: context + ) group.cancelAll() return rt } @@ -327,7 +339,7 @@ extension WebSocketErrorCode { case NIOWebSocketError.invalidFrameLength: self = .messageTooLarge case NIOWebSocketError.fragmentedControlFrame, - NIOWebSocketError.multiByteControlFrameLength: + NIOWebSocketError.multiByteControlFrameLength: self = .protocolError case WebSocketHandler.InternalError.close(let error): self = error diff --git a/Sources/WSCore/WebSocketStateMachine.swift b/Sources/WSCore/WebSocketStateMachine.swift index 749529a..d588a75 100644 --- a/Sources/WSCore/WebSocketStateMachine.swift +++ b/Sources/WSCore/WebSocketStateMachine.swift @@ -60,11 +60,12 @@ struct WebSocketStateMachine { // read close code and close reason let closeCode = frameData.readWebSocketErrorCode() let hasReason = frameData.readableBytes > 0 - let reason: String? = if hasReason { - String(buffer: frameData, validateUTF8: validateUTF8) - } else { - nil - } + let reason: String? = + if hasReason { + String(buffer: frameData, validateUTF8: validateUTF8) + } else { + nil + } switch self.state { case .open: @@ -73,18 +74,19 @@ struct WebSocketStateMachine { return .sendClose(.protocolError) } self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) - let code: WebSocketErrorCode = if dataSize == 0 || closeCode != nil { - // codes 3000 - 3999 are reserved for use by libraries, frameworks - // codes 4000 - 4999 are reserved for private use - // both of these are considered valid. - if case .unknown(let code) = closeCode, code < 3000 || code > 4999 { - .protocolError + let code: WebSocketErrorCode = + if dataSize == 0 || closeCode != nil { + // codes 3000 - 3999 are reserved for use by libraries, frameworks + // codes 4000 - 4999 are reserved for private use + // both of these are considered valid. + if case .unknown(let code) = closeCode, code < 3000 || code > 4999 { + .protocolError + } else { + .normalClosure + } } else { - .normalClosure + .protocolError } - } else { - .protocolError - } return .sendClose(code) case .closing: self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) diff --git a/Tests/WebSocketTests/AutobahnTests.swift b/Tests/WebSocketTests/AutobahnTests.swift index 9ab749f..f4998ad 100644 --- a/Tests/WebSocketTests/AutobahnTests.swift +++ b/Tests/WebSocketTests/AutobahnTests.swift @@ -100,7 +100,10 @@ final class AutobahnTests: XCTestCase { XCTAssert(status.behavior == "OK" || status.behavior == "INFORMATIONAL" || status.behavior == "NON-STRICT") } - try await WebSocketClient.connect(url: .init("ws://\(self.autobahnServer):9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in + try await WebSocketClient.connect( + url: .init("ws://\(self.autobahnServer):9001/updateReports?agent=HB"), + logger: logger + ) { inbound, _, _ in for try await _ in inbound {} } } catch let error as NIOConnectionError { @@ -171,27 +174,81 @@ final class AutobahnTests: XCTestCase { func test_13_CompressionDifferentParameters() async throws { if !self.runAllTests { - try await self.autobahnTests(cases: .init([392]), extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([427]), extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([440]), extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([451]), extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([473]), extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([498]), extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) + try await self.autobahnTests( + cases: .init([392]), + extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([427]), + extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([440]), + extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([451]), + extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([473]), + extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([498]), + extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) // case 13.7.x are repeated with different setups - try await self.autobahnTests(cases: .init([509]), extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([517]), extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init([504]), extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) + try await self.autobahnTests( + cases: .init([509]), + extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([517]), + extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init([504]), + extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) } else { - try await self.autobahnTests(cases: .init(392..<410), extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(410..<428), extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(428..<446), extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(446..<464), extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(464..<482), extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(482..<500), extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) + try await self.autobahnTests( + cases: .init(392..<410), + extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(410..<428), + extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(428..<446), + extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(446..<464), + extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(464..<482), + extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(482..<500), + extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) // case 13.7.x are repeated with different setups - try await self.autobahnTests(cases: .init(500..<518), extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(500..<518), extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)]) - try await self.autobahnTests(cases: .init(500..<518), extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)]) + try await self.autobahnTests( + cases: .init(500..<518), + extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(500..<518), + extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] + ) + try await self.autobahnTests( + cases: .init(500..<518), + extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] + ) } } } diff --git a/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift b/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift index 01071c1..1fd346c 100644 --- a/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift +++ b/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift @@ -14,9 +14,10 @@ import HTTPTypes import NIOWebSocket +import XCTest + @testable import WSCompression @testable import WSCore -import XCTest final class WebSocketExtensionNegotiationTests: XCTestCase { func testExtensionHeaderParsing() { @@ -36,7 +37,7 @@ final class WebSocketExtensionNegotiationTests: XCTestCase { func testDeflateServerResponse() { let requestHeaders: [WebSocketExtensionHTTPParameters] = [ - .init("permessage-deflate", parameters: ["client_max_window_bits": .value("10")]), + .init("permessage-deflate", parameters: ["client_max_window_bits": .value("10")]) ] let ext = PerMessageDeflateExtensionBuilder(clientNoContextTakeover: true, serverNoContextTakeover: true) let serverResponse = ext.serverResponseHeader(to: requestHeaders) @@ -48,7 +49,7 @@ final class WebSocketExtensionNegotiationTests: XCTestCase { func testDeflateServerResponseClientMaxWindowBits() { let requestHeaders: [WebSocketExtensionHTTPParameters] = [ - .init("permessage-deflate", parameters: ["client_max_window_bits": .null]), + .init("permessage-deflate", parameters: ["client_max_window_bits": .null]) ] let ext1 = PerMessageDeflateExtensionBuilder(serverNoContextTakeover: true) let serverResponse1 = ext1.serverResponseHeader(to: requestHeaders) @@ -90,18 +91,20 @@ final class WebSocketExtensionNegotiationTests: XCTestCase { var name = "my-extension" func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { - return frame + frame } func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { - return frame + frame } func shutdown() async {} } - let clientExtensionBuilders: [WebSocketExtensionBuilder] = [WebSocketExtensionFactory.nonNegotiatedExtension { - MyExtension() - }.build()] + let clientExtensionBuilders: [WebSocketExtensionBuilder] = [ + WebSocketExtensionFactory.nonNegotiatedExtension { + MyExtension() + }.build() + ] let clientExtensions = try clientExtensionBuilders.buildClientExtensions(from: [:]) XCTAssertEqual(clientExtensions.count, 1) let myExtension = try XCTUnwrap(clientExtensions.first) @@ -113,11 +116,11 @@ final class WebSocketExtensionNegotiationTests: XCTestCase { var name = "my-extension" func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { - return frame + frame } func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { - return frame + frame } func shutdown() async {} diff --git a/Tests/WebSocketTests/WebSocketStateMachineTests.swift b/Tests/WebSocketTests/WebSocketStateMachineTests.swift index 504be58..a85d5c8 100644 --- a/Tests/WebSocketTests/WebSocketStateMachineTests.swift +++ b/Tests/WebSocketTests/WebSocketStateMachineTests.swift @@ -14,9 +14,10 @@ import NIOCore import NIOWebSocket -@testable import WSCore import XCTest +@testable import WSCore + final class WebSocketStateMachineTests: XCTestCase { private func closeFrameData(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) -> ByteBuffer { var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0)) @@ -29,34 +30,70 @@ final class WebSocketStateMachineTests: XCTestCase { func testClose() { var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled) - guard case .sendClose = stateMachine.close() else { XCTFail(); return } - guard case .doNothing = stateMachine.close() else { XCTFail(); return } - guard case .doNothing = stateMachine.receivedClose(frameData: self.closeFrameData(), validateUTF8: false) else { XCTFail(); return } - guard case .closed(let frame) = stateMachine.state else { XCTFail(); return } + guard case .sendClose = stateMachine.close() else { + XCTFail() + return + } + guard case .doNothing = stateMachine.close() else { + XCTFail() + return + } + guard case .doNothing = stateMachine.receivedClose(frameData: self.closeFrameData(), validateUTF8: false) else { + XCTFail() + return + } + guard case .closed(let frame) = stateMachine.state else { + XCTFail() + return + } XCTAssertEqual(frame?.closeCode, .normalClosure) } func testReceivedClose() { var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled) - guard case .sendClose(let error) = stateMachine.receivedClose(frameData: closeFrameData(code: .goingAway), validateUTF8: false) else { XCTFail(); return } + guard case .sendClose(let error) = stateMachine.receivedClose(frameData: closeFrameData(code: .goingAway), validateUTF8: false) else { + XCTFail() + return + } XCTAssertEqual(error, .normalClosure) - guard case .closed(let frame) = stateMachine.state else { XCTFail(); return } + guard case .closed(let frame) = stateMachine.state else { + XCTFail() + return + } XCTAssertEqual(frame?.closeCode, .goingAway) } func testPingLoopNoPong() { var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .seconds(15))) - guard case .sendPing = stateMachine.sendPing() else { XCTFail(); return } - guard case .wait = stateMachine.sendPing() else { XCTFail(); return } + guard case .sendPing = stateMachine.sendPing() else { + XCTFail() + return + } + guard case .wait = stateMachine.sendPing() else { + XCTFail() + return + } } func testPingLoop() { var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .seconds(15))) - guard case .sendPing(let buffer) = stateMachine.sendPing() else { XCTFail(); return } - guard case .wait = stateMachine.sendPing() else { XCTFail(); return } + guard case .sendPing(let buffer) = stateMachine.sendPing() else { + XCTFail() + return + } + guard case .wait = stateMachine.sendPing() else { + XCTFail() + return + } stateMachine.receivedPong(frameData: buffer) - guard case .open(let openState) = stateMachine.state else { XCTFail(); return } + guard case .open(let openState) = stateMachine.state else { + XCTFail() + return + } XCTAssertEqual(openState.lastPingTime, nil) - guard case .sendPing = stateMachine.sendPing() else { XCTFail(); return } + guard case .sendPing = stateMachine.sendPing() else { + XCTFail() + return + } } }