Skip to content

Commit

Permalink
Add UTF8 validation (#2)
Browse files Browse the repository at this point in the history
* Add UTF8 validation

* Add UTF8 validation

* Wrap validation code in compiler(>=6)

* Disable some tests for swift 5.10

* Duplicate API being submitted to SwiftNIO

* Update docs to use WSCore

* Update Sources/WSCore/ByteBuffer+validatingString.swift

Co-authored-by: Joannis Orlandos <[email protected]>

* Minor updates

* Fix CI

---------

Co-authored-by: Joannis Orlandos <[email protected]>
  • Loading branch information
adam-fowler and Joannis authored Nov 15, 2024
1 parent 7c53f09 commit fee8d4f
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 29 deletions.
3 changes: 1 addition & 2 deletions Sources/WSClient/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ import WSCore
/// }
/// ```
public struct WebSocketClient {
/// Basic context implementation of ``/HummingbirdWSCore/WebSocketContext``.
/// Used by non-router web socket handle function
/// Client implementation of ``/WSCore/WebSocketContext``.
public struct Context: WebSocketContext {
public let logger: Logger

Expand Down
3 changes: 2 additions & 1 deletion Sources/WSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ struct WebSocketClientChannel: ClientConnectionChannel {
type: .client,
configuration: .init(
extensions: extensions,
autoPing: self.configuration.autoPing
autoPing: self.configuration.autoPing,
validateUTF8: self.configuration.validateUTF8
),
asyncChannel: webSocketChannel,
context: WebSocketClient.Context(logger: logger),
Expand Down
6 changes: 5 additions & 1 deletion Sources/WSClient/WebSocketClientConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public struct WebSocketClientConfiguration: Sendable {
public var extensions: [any WebSocketExtensionBuilder]
/// Automatic ping setup
public var autoPing: AutoPingSetup
/// Should text be validated to be UTF8
public var validateUTF8: Bool

/// Initialize WebSocketClient configuration
/// - Parameters:
Expand All @@ -34,11 +36,13 @@ public struct WebSocketClientConfiguration: Sendable {
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init(),
extensions: [WebSocketExtensionFactory] = [],
autoPing: AutoPingSetup = .disabled
autoPing: AutoPingSetup = .disabled,
validateUTF8: Bool = false
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
self.extensions = extensions.map { $0.build() }
self.autoPing = autoPing
self.validateUTF8 = validateUTF8
}
}
128 changes: 128 additions & 0 deletions Sources/WSCore/ByteBuffer+validatingString.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore

#if compiler(>=6)
extension ByteBuffer {
/// Decode `length` bytes starting at `index` as a UTF-8 `String`, validating the encoding. The reader index is not moved.
/// If the selected bytes are not all readable, `nil` is returned.
///
/// Use this in place of `ByteBuffer.getString(at:length:)` when the result must be valid UTF8: if the selected
/// bytes are not valid UTF8 a `ReadUTF8ValidationError` is thrown.
///
/// - Parameters:
///   - index: The starting index into `ByteBuffer` containing the string of interest.
///   - length: The number of bytes making up the string.
/// - Returns: A `String` holding the UTF-8 decoded selected bytes, or `nil` if the requested bytes are not readable.
/// - Throws: `ReadUTF8ValidationError.invalidUTF8` if the selected bytes are not valid UTF8.
@inlinable
@available(macOS 15, iOS 18, tvOS 18, watchOS 11, *)
public func getUTF8ValidatedString(at index: Int, length: Int) throws -> String? {
    // The range is expressed relative to the start of the readable bytes.
    guard let byteRange = self.rangeWithinReadableBytes(index: index, length: length) else {
        return nil
    }
    return try self.withUnsafeReadableBytes { readable in
        assert(byteRange.lowerBound >= 0 && (byteRange.upperBound - byteRange.lowerBound) <= readable.count)
        let selected = UnsafeRawBufferPointer(fastRebase: readable[byteRange])
        // The validating initializer returns nil on invalid input; surface that as an error.
        if let validated = String(validating: selected, as: Unicode.UTF8.self) {
            return validated
        }
        throw ReadUTF8ValidationError.invalidUTF8
    }
}

/// Consume `length` bytes from this `ByteBuffer` and decode them as a validated UTF-8 `String`. On success the
/// reader index moves forward by `length`.
///
/// Use this in place of `ByteBuffer.readString(length:)` when the result must be valid UTF8. If the bytes are not
/// valid UTF8 a `ReadUTF8ValidationError` is thrown and the reader index is left where it was.
///
/// - Parameters:
///   - length: The number of bytes making up the string.
/// - Returns: The decoded `String`, or `nil` if there aren't at least `length` bytes readable.
/// - Throws: `ReadUTF8ValidationError.invalidUTF8` if the bytes are not valid UTF8.
@inlinable
@available(macOS 15, iOS 18, tvOS 18, watchOS 11, *)
public mutating func readUTF8ValidatedString(length: Int) throws -> String? {
    let validated = try self.getUTF8ValidatedString(at: self.readerIndex, length: length)
    if validated != nil {
        // Only consume the bytes once they have been decoded successfully.
        self.moveReaderIndex(forwardBy: length)
    }
    return validated
}

/// Errors thrown when calling `readUTF8ValidatedString` or `getUTF8ValidatedString`.
///
/// The error is a struct wrapping a private enum; values are only obtainable through the public static
/// members below. Equality is the synthesized comparison of the wrapped case.
// NOTE(review): presumably wrapped this way so new error kinds can be added without exposing a public enum — confirm.
public struct ReadUTF8ValidationError: Error, Equatable {
// Private storage enumerating the possible failures; currently there is only one.
private enum BaseError: Hashable {
case invalidUTF8
}

// The failure this value represents.
private var baseError: BaseError

/// The bytes being decoded were not valid UTF8.
public static let invalidUTF8: ReadUTF8ValidationError = .init(baseError: .invalidUTF8)
}

/// Translate an absolute `index`/`length` pair into a range relative to the start of the readable bytes.
///
/// - Parameters:
///   - index: Absolute index into the buffer; must not be below the reader index.
///   - length: Number of bytes wanted; must not be negative.
/// - Returns: The range, offset from the reader index, or `nil` if `length` is negative, `index` precedes the
///   reader index, or the requested bytes extend past the readable bytes.
@inlinable
func rangeWithinReadableBytes(index: Int, length: Int) -> Range<Int>? {
    guard length >= 0, index >= self.readerIndex else {
        return nil
    }

    // Wrapping subtraction cannot underflow here: index >= readerIndex was checked above.
    let offset = index &- self.readerIndex
    assert(offset >= 0)

    // Likewise safe: readableBytes and length are both >= 0.
    let largestValidOffset = self.readableBytes &- length
    guard offset <= largestValidOffset else {
        return nil
    }

    // offset + length <= readableBytes was established above, so this cannot overflow,
    // and length >= 0 guarantees lower <= upper, which makes uncheckedBounds safe.
    return Range<Int>(uncheckedBounds: (lower: offset, upper: offset &+ length))
}
}

extension UnsafeRawBufferPointer {
    /// Build a buffer pointer covering exactly the bytes of `slice`.
    ///
    /// The start address is the slice's base address advanced by the slice's start index (or `nil` when the
    /// base has no address) and the count is the distance between the slice's indices.
    ///
    /// - Parameter slice: A slice of an `UnsafeRawBufferPointer`.
    @inlinable
    init(fastRebase slice: Slice<UnsafeRawBufferPointer>) {
        let byteCount = slice.endIndex &- slice.startIndex
        let start = slice.base.baseAddress?.advanced(by: slice.startIndex)
        self.init(start: start, count: byteCount)
    }
}

#endif // compiler(>=6)

extension String {
    /// Create a `String` from all the readable bytes of `buffer`, optionally requiring them to be valid UTF8.
    ///
    /// Validation is only performed when `validateUTF8` is `true`, the code is built with a Swift 6 (or later)
    /// compiler and it is running on macOS 15 / iOS 18 / tvOS 18 / watchOS 11 or later. In every other case the
    /// string is created with `String(buffer:)` and no validation takes place.
    ///
    /// - Parameters:
    ///   - buffer: Buffer whose readable bytes make up the string. It is not modified.
    ///   - validateUTF8: Whether the bytes should be checked to be valid UTF8.
    /// - Returns: `nil` if validation was performed and failed.
    init?(buffer: ByteBuffer, validateUTF8: Bool) {
        #if compiler(>=6)
        if #available(macOS 15, iOS 18, tvOS 18, watchOS 11, *), validateUTF8 {
            // The non-mutating getter avoids a mutable copy of the buffer. `try?` folds both the
            // validation error and the (unexpected) nil result into a failed initialization, so no
            // force unwrap is needed.
            guard
                let validated = try? buffer.getUTF8ValidatedString(
                    at: buffer.readerIndex,
                    length: buffer.readableBytes
                )
            else {
                return nil
            }
            self = validated
        } else {
            self = .init(buffer: buffer)
        }
        #else
        self = .init(buffer: buffer)
        #endif // compiler(>=6)
    }
}
4 changes: 2 additions & 2 deletions Sources/WSCore/WebSocketFrameSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ struct WebSocketFrameSequence {
}
}

var message: WebSocketMessage {
.init(frame: self.collated)!
func getMessage(validateUTF8: Bool) -> WebSocketMessage? {
.init(frame: self.collated, validate: validateUTF8)
}

var collated: WebSocketDataFrame {
Expand Down
6 changes: 4 additions & 2 deletions Sources/WSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ public struct WebSocketCloseFrame: Sendable {
@_spi(WSInternal) public struct Configuration: Sendable {
let extensions: [any WebSocketExtension]
let autoPing: AutoPingSetup
let validateUTF8: Bool

@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup) {
@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) {
self.extensions = extensions
self.autoPing = autoPing
self.validateUTF8 = validateUTF8
}
}

Expand Down Expand Up @@ -287,7 +289,7 @@ public struct WebSocketCloseFrame: Sendable {
}

func receivedClose(_ frame: WebSocketFrame) async throws {
switch self.stateMachine.receivedClose(frameData: frame.unmaskedData) {
switch self.stateMachine.receivedClose(frameData: frame.unmaskedData, validateUTF8: self.configuration.validateUTF8) {
case .sendClose(let errorCode):
try await self.sendClose(code: errorCode, reason: nil)
// Only server should initiate a connection close. Clients should wait for the
Expand Down
19 changes: 11 additions & 8 deletions Sources/WSCore/WebSocketInboundStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,28 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
case .text, .binary:
frameSequence = .init(frame: frame)
if frame.fin {
return frameSequence.message
guard let message = frameSequence.getMessage(validateUTF8: self.handler.configuration.validateUTF8) else {
throw WebSocketHandler.InternalError.close(.dataInconsistentWithMessage)
}
return message
}
default:
try await self.handler.close(code: .protocolError)
return nil
throw WebSocketHandler.InternalError.close(.protocolError)
}
// parse continuation frames until we get a frame with a FIN flag
while let frame = try await self.next() {
guard frame.opcode == .continuation else {
try await self.handler.close(code: .protocolError)
return nil
throw WebSocketHandler.InternalError.close(.protocolError)
}
guard frameSequence.size + frame.data.readableBytes <= maxSize else {
try await self.handler.close(code: .messageTooLarge)
return nil
throw WebSocketHandler.InternalError.close(.messageTooLarge)
}
frameSequence.append(frame)
if frame.fin {
return frameSequence.message
guard let message = frameSequence.getMessage(validateUTF8: self.handler.configuration.validateUTF8) else {
throw WebSocketHandler.InternalError.close(.dataInconsistentWithMessage)
}
return message
}
}
return nil
Expand Down
7 changes: 5 additions & 2 deletions Sources/WSCore/WebSocketMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ public enum WebSocketMessage: Equatable, Sendable, CustomStringConvertible, Cust
case text(String)
case binary(ByteBuffer)

init?(frame: WebSocketDataFrame) {
init?(frame: WebSocketDataFrame, validate: Bool) {
switch frame.opcode {
case .text:
self = .text(String(buffer: frame.data))
guard let string = String(buffer: frame.data, validateUTF8: validate) else {
return nil
}
self = .text(string)
case .binary:
self = .binary(frame.data)
default:
Expand Down
15 changes: 11 additions & 4 deletions Sources/WSCore/WebSocketStateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,24 @@ struct WebSocketStateMachine {

// we received a connection close.
// send a close back if it hasn't already been sent and exit
mutating func receivedClose(frameData: ByteBuffer) -> ReceivedCloseResult {
mutating func receivedClose(frameData: ByteBuffer, validateUTF8: Bool) -> ReceivedCloseResult {
var frameData = frameData
let dataSize = frameData.readableBytes
// read close code and close reason
let closeCode = frameData.readWebSocketErrorCode()
let reason = frameData.readableBytes > 0
? frameData.readString(length: frameData.readableBytes)
: nil
let hasReason = frameData.readableBytes > 0
let reason: String? = if hasReason {
String(buffer: frameData, validateUTF8: validateUTF8)
} else {
nil
}

switch self.state {
case .open:
if hasReason, reason == nil {
self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) })
return .sendClose(.protocolError)
}
self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) })
let code: WebSocketErrorCode = if dataSize == 0 || closeCode != nil {
// codes 3000 - 3999 are reserved for use by libraries, frameworks
Expand Down
29 changes: 24 additions & 5 deletions Tests/WebSocketTests/AutobahnTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ import WSClient
import WSCompression
import XCTest

/// The Autobahn|Testsuite provides a fully automated test suite to verify client and server
/// implementations of The WebSocket Protocol for specification conformance and implementation robustness.
/// You can find out more at https://github.com/crossbario/autobahn-testsuite
final class AutobahnTests: XCTestCase {
/// To run all the autobahn tests takes a long time. By default we only run a selection.
/// To run all the autobahn compression tests takes a long time. By default we only run a selection.
/// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them.
var runAllTests: Bool { ProcessInfo.processInfo.environment["AUTOBAHN_ALL_TESTS"] == "true" }
var autobahnServer: String { ProcessInfo.processInfo.environment["FUZZING_SERVER"] ?? "localhost" }
Expand All @@ -30,6 +33,7 @@ final class AutobahnTests: XCTestCase {
let result: NIOLockedValueBox<T?> = .init(nil)
try await WebSocketClient.connect(
url: .init("ws://\(self.autobahnServer):9001/\(path)"),
configuration: .init(validateUTF8: true),
logger: Logger(label: "Autobahn")
) { inbound, _, _ in
var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator()
Expand All @@ -49,6 +53,7 @@ final class AutobahnTests: XCTestCase {
return try result.withLockedValue { try XCTUnwrap($0) }
}

/// Run a number of autobahn tests
func autobahnTests(
cases: Set<Int>,
extensions: [WebSocketExtensionFactory] = [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)]
Expand All @@ -73,7 +78,11 @@ final class AutobahnTests: XCTestCase {
// run case
try await WebSocketClient.connect(
url: .init("ws://\(self.autobahnServer):9001/runCase?case=\(index)&agent=swift-websocket"),
configuration: .init(maxFrameSize: 16_777_216, extensions: extensions),
configuration: .init(
maxFrameSize: 16_777_216,
extensions: extensions,
validateUTF8: true
),
logger: logger
) { inbound, outbound, _ in
for try await msg in inbound.messages(maxSize: .max) {
Expand All @@ -88,7 +97,11 @@ final class AutobahnTests: XCTestCase {

// get case status
let status = try await getValue("getCaseStatus?case=\(index)&agent=swift-websocket", as: CaseStatus.self)
XCTAssert(status.behavior == "OK" || status.behavior == "INFORMATIONAL")
XCTAssert(status.behavior == "OK" || status.behavior == "INFORMATIONAL" || status.behavior == "NON-STRICT")
}

try await WebSocketClient.connect(url: .init("ws://\(self.autobahnServer):9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in
for try await _ in inbound {}
}
} catch let error as NIOConnectionError {
logger.error("Autobahn tests require a running Autobahn fuzzing server. Run ./scripts/autobahn-server.sh")
Expand Down Expand Up @@ -119,15 +132,21 @@ final class AutobahnTests: XCTestCase {
}

func test_6_UTF8Handling() async throws {
// UTF8 validation fails
// UTF8 validation is not available on swift 5.10 or earlier
#if compiler(<6)
try XCTSkipIf(true)
#endif
try await self.autobahnTests(cases: .init(65..<210))
}

func test_7_CloseHandling() async throws {
// UTF8 validation is not available on swift 5.10 or earlier
#if compiler(<6)
try await self.autobahnTests(cases: .init(210..<222))
// UTF8 validation fails so skip 222
try await self.autobahnTests(cases: .init(223..<247))
#else
try await self.autobahnTests(cases: .init(210..<247))
#endif
}

func test_9_Performance() async throws {
Expand Down
4 changes: 2 additions & 2 deletions Tests/WebSocketTests/WebSocketStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ final class WebSocketStateMachineTests: XCTestCase {
var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled)
guard case .sendClose = stateMachine.close() else { XCTFail(); return }
guard case .doNothing = stateMachine.close() else { XCTFail(); return }
guard case .doNothing = stateMachine.receivedClose(frameData: self.closeFrameData()) else { XCTFail(); return }
guard case .doNothing = stateMachine.receivedClose(frameData: self.closeFrameData(), validateUTF8: false) else { XCTFail(); return }
guard case .closed(let frame) = stateMachine.state else { XCTFail(); return }
XCTAssertEqual(frame?.closeCode, .normalClosure)
}

func testReceivedClose() {
var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled)
guard case .sendClose(let error) = stateMachine.receivedClose(frameData: closeFrameData(code: .goingAway)) else { XCTFail(); return }
guard case .sendClose(let error) = stateMachine.receivedClose(frameData: closeFrameData(code: .goingAway), validateUTF8: false) else { XCTFail(); return }
XCTAssertEqual(error, .normalClosure)
guard case .closed(let frame) = stateMachine.state else { XCTFail(); return }
XCTAssertEqual(frame?.closeCode, .goingAway)
Expand Down

0 comments on commit fee8d4f

Please sign in to comment.