Skip to content

Commit

Permalink
ZLibCompressor/Decompressor (#22)
Browse files Browse the repository at this point in the history
* Use non-copyable zlib compressor/decompressor

* make algorithm non-mutable

* Make DecompressByteBufferSequence a struct

* No need to make DecompressByteBufferSequence.makeAsyncSequence consuming

* Update for CompressNIO Zlib updates

Cleaned up DecompressByteBufferSequence by using a state machine; the decompressor is now allocated on the first call to `next()`.

* swift format

* Include window in state machine, make iterator a struct

* Use compress-nio 1.3.0
  • Loading branch information
adam-fowler authored Nov 18, 2024
1 parent c1643d9 commit 1fc402c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.2.1"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.3.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.32.1"),
],
targets: [
Expand Down
38 changes: 16 additions & 22 deletions Sources/HummingbirdCompression/CompressedBodyWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,29 @@ import Logging
// ResponseBodyWriter that writes a compressed version of the response to a parent writer
final class CompressedBodyWriter<ParentWriter: ResponseBodyWriter & Sendable>: ResponseBodyWriter {
var parentWriter: ParentWriter
let compressor: NIOCompressor
private let compressor: ZlibCompressor
private var window: ByteBuffer
var lastBuffer: ByteBuffer?
let logger: Logger

init(
parent: ParentWriter,
algorithm: CompressionAlgorithm,
algorithm: ZlibAlgorithm,
configuration: ZlibConfiguration,
windowSize: Int,
logger: Logger
) throws {
self.parentWriter = parent
self.compressor = algorithm.compressor
self.compressor.window = ByteBufferAllocator().buffer(capacity: windowSize)
self.compressor = try ZlibCompressor(algorithm: algorithm, configuration: configuration)
self.window = ByteBufferAllocator().buffer(capacity: windowSize)
self.lastBuffer = nil
self.logger = logger
try self.compressor.startStream()
}

deinit {
do {
try self.compressor.finishStream()
} catch {
logger.error("Error finalizing compression stream: \(error) ")
}
}

/// Write response buffer
func write(_ buffer: ByteBuffer) async throws {
var buffer = buffer
try await buffer.compressStream(with: self.compressor, flush: .sync) { buffer in
try await buffer.compressStream(with: self.compressor, window: &self.window, flush: .sync) { buffer in
try await self.parentWriter.write(buffer)
}
// need to store the last buffer so it can be finished once the writer is done
Expand All @@ -59,17 +52,17 @@ final class CompressedBodyWriter<ParentWriter: ResponseBodyWriter & Sendable>: R
/// - Parameter trailingHeaders: Any trailing headers you want to include at the end
consuming func finish(_ trailingHeaders: HTTPFields?) async throws {
// The last buffer must be finished
if var lastBuffer, var window = self.compressor.window {
if var lastBuffer {
// keep finishing stream until we don't get a buffer overflow
while true {
do {
try lastBuffer.compressStream(to: &window, with: self.compressor, flush: .finish)
try await self.parentWriter.write(window)
window.clear()
try lastBuffer.compressStream(to: &self.window, with: self.compressor, flush: .finish)
try await self.parentWriter.write(self.window)
self.window.clear()
break
} catch let error as CompressNIOError where error == .bufferOverflow {
try await self.parentWriter.write(window)
window.clear()
try await self.parentWriter.write(self.window)
self.window.clear()
}
}
}
Expand All @@ -87,10 +80,11 @@ extension ResponseBodyWriter {
/// - logger: Logger used to output compression errors
/// - Returns: new ``HummingbirdCore/ResponseBodyWriter``
public func compressed(
algorithm: CompressionAlgorithm,
algorithm: ZlibAlgorithm,
configuration: ZlibConfiguration,
windowSize: Int,
logger: Logger
) throws -> some ResponseBodyWriter {
try CompressedBodyWriter(parent: self, algorithm: algorithm, windowSize: windowSize, logger: logger)
try CompressedBodyWriter(parent: self, algorithm: algorithm, configuration: configuration, windowSize: windowSize, logger: logger)
}
}
106 changes: 57 additions & 49 deletions Sources/HummingbirdCompression/RequestDecompressionMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ public struct RequestDecompressionMiddleware<Context: RequestContext>: RouterMid
}

/// Determines the decompression algorithm based on the content-encoding header values.
private func algorithm(from contentEncodingHeaders: [String]) -> CompressionAlgorithm? {
private func algorithm(from contentEncodingHeaders: [String]) -> ZlibAlgorithm? {
for encoding in contentEncodingHeaders {
switch encoding {
case "gzip":
return CompressionAlgorithm.gzip()
return .gzip
case "deflate":
return CompressionAlgorithm.zlib()
return .zlib
default:
break
}
Expand All @@ -67,67 +67,75 @@ struct DecompressByteBufferSequence<Base: AsyncSequence & Sendable>: AsyncSequen
typealias Element = ByteBuffer

let base: Base
let algorithm: CompressionAlgorithm
let algorithm: ZlibAlgorithm
let windowSize: Int
let logger: Logger

class AsyncIterator: AsyncIteratorProtocol {
var baseIterator: Base.AsyncIterator
let decompressor: NIODecompressor
var currentBuffer: ByteBuffer?
var window: ByteBuffer
let logger: Logger
init(base: Base, algorithm: ZlibAlgorithm, windowSize: Int, logger: Logger) {
self.base = base
self.algorithm = algorithm
self.windowSize = windowSize
self.logger = logger
}

init(baseIterator: Base.AsyncIterator, algorithm: CompressionAlgorithm, windowSize: Int, logger: Logger) {
self.baseIterator = baseIterator
self.decompressor = algorithm.decompressor
self.window = ByteBufferAllocator().buffer(capacity: windowSize)
self.currentBuffer = nil
self.logger = logger
do {
try self.decompressor.startStream()
} catch {
logger.error("Error initializing decompression stream: \(error) ")
}
struct AsyncIterator: AsyncIteratorProtocol {
enum State {
case uninitialized(ZlibAlgorithm, windowSize: Int)
case decompressing(ZlibDecompressor, buffer: ByteBuffer, window: ByteBuffer)
case done
}

deinit {
do {
try self.decompressor.finishStream()
} catch {
logger.error("Error finalizing decompression stream: \(error) ")
}
var baseIterator: Base.AsyncIterator
var state: State

init(baseIterator: Base.AsyncIterator, algorithm: ZlibAlgorithm, windowSize: Int) {
self.baseIterator = baseIterator
self.state = .uninitialized(algorithm, windowSize: windowSize)
}

func next() async throws -> ByteBuffer? {
do {
if self.currentBuffer == nil {
self.currentBuffer = try await self.baseIterator.next()
mutating func next() async throws -> ByteBuffer? {
switch self.state {
case .uninitialized(let algorithm, let windowSize):
guard let buffer = try await self.baseIterator.next() else {
self.state = .done
return nil
}
self.window.clear()
while var buffer = self.currentBuffer {
do {
try buffer.decompressStream(to: &self.window, with: self.decompressor)
} catch let error as CompressNIOError where error == .bufferOverflow {
self.currentBuffer = buffer
return self.window
} catch let error as CompressNIOError where error == .inputBufferOverflow {
// can ignore CompressNIOError.inputBufferOverflow errors here
}
let decompressor = try ZlibDecompressor(algorithm: algorithm)
self.state = .decompressing(decompressor, buffer: buffer, window: ByteBufferAllocator().buffer(capacity: windowSize))
return try await self.next()

self.currentBuffer = try await self.baseIterator.next()
case .decompressing(let decompressor, var buffer, var window):
do {
window.clear()
while true {
do {
try buffer.decompressStream(to: &window, with: decompressor)
} catch let error as CompressNIOError where error == .bufferOverflow {
self.state = .decompressing(decompressor, buffer: buffer, window: window)
return window
} catch let error as CompressNIOError where error == .inputBufferOverflow {
// can ignore CompressNIOError.inputBufferOverflow errors here
}

guard let nextBuffer = try await self.baseIterator.next() else {
self.state = .done
return window.readableBytes > 0 ? window : nil
}
buffer = nextBuffer
}
} catch let error as CompressNIOError where error == .corruptData {
throw HTTPError(.badRequest, message: "Corrupt compression data.")
} catch {
throw HTTPError(.badRequest, message: "Data decompression failed.")
}
self.currentBuffer = nil
return self.window.readableBytes > 0 ? self.window : nil
} catch let error as CompressNIOError where error == .corruptData {
throw HTTPError(.badRequest, message: "Corrupt compression data.")
} catch {
throw HTTPError(.badRequest, message: "Data decompression failed.")

case .done:
return nil
}
}
}

func makeAsyncIterator() -> AsyncIterator {
.init(baseIterator: self.base.makeAsyncIterator(), algorithm: self.algorithm, windowSize: self.windowSize, logger: self.logger)
.init(baseIterator: self.base.makeAsyncIterator(), algorithm: self.algorithm, windowSize: self.windowSize)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public struct ResponseCompressionMiddleware<Context: RequestContext>: RouterMidd
editedResponse.body = .init { writer in
let compressWriter = try writer.compressed(
algorithm: algorithm,
configuration: self.zlibConfiguration,
windowSize: self.windowSize,
logger: context.logger
)
Expand All @@ -95,7 +96,7 @@ public struct ResponseCompressionMiddleware<Context: RequestContext>: RouterMidd
}

/// Determines the compression algorithm to use for the next response.
private func compressionAlgorithm(from acceptContentHeaders: [some StringProtocol]) -> (compressor: CompressionAlgorithm, name: String)? {
private func compressionAlgorithm(from acceptContentHeaders: [some StringProtocol]) -> (algorithm: ZlibAlgorithm, name: String)? {
var gzipQValue: Float = -1
var deflateQValue: Float = -1
var anyQValue: Float = -1
Expand All @@ -112,15 +113,15 @@ public struct ResponseCompressionMiddleware<Context: RequestContext>: RouterMidd

if gzipQValue > 0 || deflateQValue > 0 {
if gzipQValue > deflateQValue {
return (compressor: CompressionAlgorithm.gzip(configuration: self.zlibConfiguration), name: "gzip")
return (algorithm: .gzip, name: "gzip")
} else {
return (compressor: CompressionAlgorithm.zlib(configuration: self.zlibConfiguration), name: "deflate")
return (algorithm: .zlib, name: "deflate")
}
} else if anyQValue > 0 {
// Though gzip is usually less well compressed than deflate, it has slightly
// wider support because it's unambiguous. We therefore default to that unless
// the client has expressed a preference.
return (compressor: CompressionAlgorithm.gzip(configuration: self.zlibConfiguration), name: "gzip")
return (algorithm: .gzip, name: "gzip")
}

return nil
Expand Down

0 comments on commit 1fc402c

Please sign in to comment.