diff --git a/Sources/NIOSSL/NIOSSLHandler.swift b/Sources/NIOSSL/NIOSSLHandler.swift index 3a6313a8..46b2ca38 100644 --- a/Sources/NIOSSL/NIOSSLHandler.swift +++ b/Sources/NIOSSL/NIOSSLHandler.swift @@ -351,7 +351,15 @@ public class NIOSSLHandler : ChannelInboundHandler, ChannelOutboundHandler, Remo /// /// This method must not be called once the connection is established. private func doHandshakeStep(context: ChannelHandlerContext) { - let result = connection.doHandshake() + switch self.state { + case .unwrapped, .inputClosed, .outputClosed, .closed: + // We shouldn't be handshaking in any of these state. + return + case .idle, .handshaking, .additionalVerification, .active, .closing, .unwrapping: + () + } + + let result = self.connection.doHandshake() switch result { case .incomplete: @@ -872,9 +880,15 @@ extension NIOSSLHandler { } private func bufferWrite(data: ByteBuffer, promise: EventLoopPromise?) { - if case .outputClosed = self.state { + switch self.state { + case .idle, .handshaking, .additionalVerification, .active, .unwrapping, .closing, .unwrapped, .inputClosed: + () + case .outputClosed: promise?.fail(ChannelError.outputClosed) return + case .closed: + promise?.fail(ChannelError.ioOnClosedChannel) + return } var data = data @@ -914,41 +928,58 @@ extension NIOSSLHandler { // These are some annoying variables we use to persist state across invocations of // our closures. A better version of this code might be able to simplify this somewhat. var promises: [EventLoopPromise] = [] - var didWrite = false do { var invokeCloseOutput = false - bufferedActionsLoop: while bufferedActions.hasMark { - let element = bufferedActions.first! - switch element { - case .write(let bufferedWrite): - var data = bufferedWrite.data - let writeSuccessful = try self._encodeSingleWrite(buf: &data) - if writeSuccessful { - didWrite = true - if let promise = bufferedWrite.promise { promises.append(promise) } - _ = bufferedActions.removeFirst() + var bufferedActionsLoopCount = 0 + bufferedActionsLoop: while self.bufferedActions.hasMark, bufferedActionsLoopCount < 1000 { + bufferedActionsLoopCount += 1 + var didWrite = false + + writeLoop: while self.bufferedActions.hasMark { + let element = self.bufferedActions.first! + switch element { + case .write(let bufferedWrite): + var data = bufferedWrite.data + let writeSuccessful = try self._encodeSingleWrite(buf: &data) + if writeSuccessful { + didWrite = true + if let promise = bufferedWrite.promise { promises.append(promise) } + _ = self.bufferedActions.removeFirst() + } else { + // The write into BoringSSL unsuccessful. Break the write loop so any + // data is written to the network before resuming. + break writeLoop + } + case .closeOutput: + invokeCloseOutput = true + _ = self.bufferedActions.removeFirst() + break writeLoop } - case .closeOutput: - invokeCloseOutput = true - _ = bufferedActions.removeFirst() - break bufferedActionsLoop } - } - // If we got this far and did a write, we should shove the data out to the - // network. - if didWrite { - let ourPromise: EventLoopPromise? = promises.flattenPromises(on: context.eventLoop) - self.writeDataToNetwork(context: context, promise: ourPromise) + // If we got this far and did a write, we should shove the data out to the + // network. + if didWrite { + let ourPromise: EventLoopPromise? = promises.flattenPromises(on: context.eventLoop) + self.writeDataToNetwork(context: context, promise: ourPromise) + } + + // We detected a .closeOutput action in our action buffer. This means we + // close the output after we have written all pending writes. + if invokeCloseOutput { + self.state = .outputClosed + self.doShutdownStep(context: context) + self.discardBufferedActions(reason: ChannelError.outputClosed) + break bufferedActionsLoop + } } - // We detected a .closeOutput action in our action buffer. This means we - // close the output after we have written all pending writes. - if invokeCloseOutput { - self.state = .outputClosed - self.doShutdownStep(context: context) - self.discardBufferedActions(reason: ChannelError.outputClosed) + // We spun the outer loop too many times, something isn't right so let's bail out + // instead of looping any longer. + if bufferedActionsLoopCount >= 1000 { + assertionFailure("\(#function) looped too many times, please file a GitHub issue against swift-nio-ssl.") + throw NIOSSLExtraError.noForwardProgress } } catch { // We encountered an error, it's cleanup time. Close ourselves down. diff --git a/Sources/NIOSSL/SSLErrors.swift b/Sources/NIOSSL/SSLErrors.swift index e5286852..ea594bf4 100644 --- a/Sources/NIOSSL/SSLErrors.swift +++ b/Sources/NIOSSL/SSLErrors.swift @@ -200,6 +200,7 @@ extension NIOSSLExtraError { case cannotUseIPAddressInSNI case invalidSNIHostname case unknownPrivateKeyFileType + case noForwardProgress } } @@ -228,6 +229,12 @@ extension NIOSSLExtraError { /// The private key file for the TLS configuration has an unknown type. public static let unknownPrivateKeyFileType = NIOSSLExtraError(baseError: .unknownPrivateKeyFileType, description: nil) + /// No forward progress is being made. + /// + /// This can happen when the `NIOSSLHandler` is unbuffering actions and gets into a state where + /// it would potentially spin loop indefinitely. + static let noForwardProgress = NIOSSLExtraError(baseError: .noForwardProgress, description: nil) + @inline(never) internal static func failedToValidateHostname(expectedName: String) -> NIOSSLExtraError { let description = "Couldn't find \(expectedName) in certificate from peer" diff --git a/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift b/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift index ae0dc305..8662cdff 100644 --- a/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift +++ b/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift @@ -2826,4 +2826,59 @@ class NIOSSLIntegrationTest: XCTestCase { b2b.client.close(promise: nil) try b2b.interactInMemory() } + + func testDoesNotSpinLoopWhenInactiveAndActiveAreReversed() throws { + // This is a regression test for https://github.com/apple/swift-nio-ssl/issues/467 + // + // If channelInactive occurs before channelActive and a re-entrant write and flush occurred + // in channelActive then 'NIOSSLHandler.doUnbufferActions(context:)' would loop + // indefinitely. + let eventLoop = EmbeddedEventLoop() + let promise = eventLoop.makePromise(of: Void.self) + + final class WriteAndFlushOnActive: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private let promise: EventLoopPromise + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func channelActive(context: ChannelHandlerContext) { + let buffer = context.channel.allocator.buffer(string: "You spin me right 'round") + context.writeAndFlush(self.wrapOutboundOut(buffer), promise: self.promise) + context.fireChannelActive() + } + } + + let context = try self.configuredSSLContext() + let handler = try NIOSSLClientHandler(context: context, serverHostname: nil) + let channel = EmbeddedChannel( + handlers: [handler, WriteAndFlushOnActive(promise: promise)], + loop: eventLoop + ) + + // Close _before_ channel active. This shouldn't (but can https://github.com/apple/swift-nio/issues/2773) + // happen for 'real' channels by synchronously closing the channel when the connect promise + // is succeeded. + channel.pipeline.fireChannelInactive() + channel.pipeline.fireChannelActive() + + // The handshake starts in channelActive (and handlerAdded if the channel is already + // active). If the events are reordered then the handshake shouldn't start and there + // shouldn't be any outbound data. + XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self)) + + // The write promise should fail. + XCTAssertThrowsError(try promise.futureResult.wait()) { error in + XCTAssertEqual(error as? ChannelError, .ioOnClosedChannel) + } + + // Subsequent writes should also fail. + XCTAssertThrowsError(try channel.writeOutbound(ByteBuffer(string: "Like a record, baby, right 'round"))) { error in + XCTAssertEqual(error as? ChannelError, .ioOnClosedChannel) + } + } }