diff --git a/DotNetty.sln b/DotNetty.sln index f5198ca62..b93e4bb9f 100644 --- a/DotNetty.sln +++ b/DotNetty.sln @@ -92,6 +92,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "local-build", "local-build" localRestore.cmd = localRestore.cmd EndProjectSection EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DotNetty.Handlers.Proxy", "src\DotNetty.Handlers.Proxy\DotNetty.Handlers.Proxy.csproj", "{9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DotNetty.Handlers.Proxy.Tests", "test\DotNetty.Handlers.Proxy.Tests\DotNetty.Handlers.Proxy.Tests.csproj", "{8A11F53C-02FD-4537-9BC9-0525489F128B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -520,6 +524,38 @@ Global {920F73C7-7FBE-44BE-8A99-3A394207D4C8}.Release|x64.Build.0 = Release|Any CPU {920F73C7-7FBE-44BE-8A99-3A394207D4C8}.Release|x86.ActiveCfg = Release|Any CPU {920F73C7-7FBE-44BE-8A99-3A394207D4C8}.Release|x86.Build.0 = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|ARM.ActiveCfg = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|ARM.Build.0 = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|x64.ActiveCfg = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|x64.Build.0 = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|x86.ActiveCfg = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Debug|x86.Build.0 = Debug|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|Any CPU.Build.0 = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|ARM.ActiveCfg = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|ARM.Build.0 = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|x64.ActiveCfg = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|x64.Build.0 = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|x86.ActiveCfg = Release|Any CPU + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC}.Release|x86.Build.0 = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|ARM.ActiveCfg = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|ARM.Build.0 = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|x64.ActiveCfg = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|x64.Build.0 = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|x86.ActiveCfg = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Debug|x86.Build.0 = Debug|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|Any CPU.Build.0 = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|ARM.ActiveCfg = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|ARM.Build.0 = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|x64.ActiveCfg = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|x64.Build.0 = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|x86.ActiveCfg = Release|Any CPU + {8A11F53C-02FD-4537-9BC9-0525489F128B}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -553,6 +589,8 @@ Global {E6B102FE-C706-4C40-B4F9-569EFC89B70F} = {01F3CC7E-F996-411E-AFD6-72673A826549} {920F73C7-7FBE-44BE-8A99-3A394207D4C8} = {01F3CC7E-F996-411E-AFD6-72673A826549} {E27C94F8-A148-46D4-A1E0-2CC2B1FBECE9} = {013DFD29-E1DB-4968-A67B-C2342E6F5B6E} + {9A960CAF-E1BB-49F0-8F4F-7FA52F787CFC} = {3D04C4DC-6F8E-4326-9569-92F3E26C6EEB} + {8A11F53C-02FD-4537-9BC9-0525489F128B} = {01F3CC7E-F996-411E-AFD6-72673A826549} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {A659CEFB-DDB3-49BE-AEDD-FF2F1B3297DB} diff --git a/src/DotNetty.Codecs/Base64/Base64.cs b/src/DotNetty.Codecs/Base64/Base64.cs index 71f5f211e..14587ed43 100644 --- a/src/DotNetty.Codecs/Base64/Base64.cs +++ b/src/DotNetty.Codecs/Base64/Base64.cs @@ -39,6 +39,8 @@ public static class Base64 const sbyte EQUALS_SIGN_ENC = -1; // Indicates equals sign in encoding public static IByteBuffer Encode(IByteBuffer src) => Encode(src, Base64Dialect.Standard); + + public static IByteBuffer Encode(IByteBuffer src, bool breakLines) => Encode(src, breakLines, Base64Dialect.Standard); public static IByteBuffer Encode(IByteBuffer src, IBase64Dialect dialect) => Encode(src, src.ReaderIndex, src.ReadableBytes, dialect.BreakLinesByDefault, dialect); diff --git a/src/DotNetty.Handlers.Proxy/DotNetty.Handlers.Proxy.csproj b/src/DotNetty.Handlers.Proxy/DotNetty.Handlers.Proxy.csproj new file mode 100644 index 000000000..e5fe93f95 --- /dev/null +++ b/src/DotNetty.Handlers.Proxy/DotNetty.Handlers.Proxy.csproj @@ -0,0 +1,24 @@ + + + + + $(StandardTfms) + DotNetty.Handlers.Proxy + SpanNetty.Handlers.Proxy + false + + + + SpanNetty.Handlers.Proxy + SpanNetty.Handlers.Proxy + Protobuf Proto3 codec. + socket;tcp;protocol;netty;dotnetty;network;proxy;webproxy;httpproxy;tunnelproxy + + + + + + + + + diff --git a/src/DotNetty.Handlers.Proxy/HttpProxyConnectException.cs b/src/DotNetty.Handlers.Proxy/HttpProxyConnectException.cs new file mode 100644 index 000000000..5bcaccea5 --- /dev/null +++ b/src/DotNetty.Handlers.Proxy/HttpProxyConnectException.cs @@ -0,0 +1,47 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * Copyright (c) 2020 The Dotnetty-Span-Fork Project (cuteant@outlook.com) All rights reserved. + * + * https://github.com/cuteant/dotnetty-span-fork + * + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + */ + +using DotNetty.Codecs.Http; + +namespace DotNetty.Handlers.Proxy +{ + /// + /// Specific case of a connection failure, which may include headers from the proxy. + /// + public sealed class HttpProxyConnectException : ProxyConnectException + { + /// + /// @param message The failure message. + /// @param headers Header associated with the connection failure. May be {@code null}. + /// + public HttpProxyConnectException(string message, HttpHeaders headers) + : base(message) + { + this.Headers = headers; + } + + /// + /// Returns headers, if any. May be {@code null}. + /// + public HttpHeaders Headers { get; } + } +} \ No newline at end of file diff --git a/src/DotNetty.Handlers.Proxy/HttpProxyHandler.cs b/src/DotNetty.Handlers.Proxy/HttpProxyHandler.cs new file mode 100644 index 000000000..101f50e39 --- /dev/null +++ b/src/DotNetty.Handlers.Proxy/HttpProxyHandler.cs @@ -0,0 +1,299 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using DotNetty.Buffers; +using DotNetty.Codecs.Base64; +using DotNetty.Codecs.Http; +using DotNetty.Common; +using DotNetty.Common.Concurrency; +using DotNetty.Common.Utilities; +using DotNetty.Transport.Channels; + +namespace DotNetty.Handlers.Proxy +{ + public class HttpProxyHandler : ProxyHandler + { + static readonly string PROTOCOL = "http"; + static readonly string AuthBasic = "basic"; + + /// + /// Wrapper for the HttpClientCodec to prevent it to be removed by other handlers by mistake (for example the WebSocket*Handshaker). + /// See: + /// - https://github.com/netty/netty/issues/5201 + /// - https://github.com/netty/netty/issues/5070 + /// + private readonly HttpClientCodecWrapper _codecWrapper = new HttpClientCodecWrapper(); + + private readonly string _username; + private readonly string _password; + private readonly ICharSequence _authorization; + private readonly HttpHeaders _outboundHeaders; + private readonly bool _ignoreDefaultPortsInConnectHostHeader; + + HttpResponseStatus _status; + HttpHeaders _inboundHeaders; + + public HttpProxyHandler(EndPoint proxyAddress) + : this(proxyAddress, null) + { + } + + public HttpProxyHandler(EndPoint proxyAddress, HttpHeaders headers) + : this(proxyAddress, headers, false) + { + } + + public HttpProxyHandler(EndPoint proxyAddress, HttpHeaders headers, bool ignoreDefaultPortsInConnectHostHeader) + : base(proxyAddress) + { + _username = null; + _password = null; + _authorization = null; + _outboundHeaders = headers; + _ignoreDefaultPortsInConnectHostHeader = ignoreDefaultPortsInConnectHostHeader; + } + + public HttpProxyHandler(EndPoint proxyAddress, string username, string password) + : this(proxyAddress, username, password, null) + { + } + + public HttpProxyHandler(EndPoint proxyAddress, string username, string password, HttpHeaders headers) + : this(proxyAddress, username, password, headers, false) + { + } + + public HttpProxyHandler( + EndPoint proxyAddress, + string username, + string password, + HttpHeaders headers, + bool ignoreDefaultPortsInConnectHostHeader) + : base(proxyAddress) + { + + if (username is null) + { + throw new ArgumentNullException(nameof(username)); + } + + if (password is null) + { + throw new ArgumentNullException(nameof(password)); + } + + IByteBuffer authz = Unpooled.CopiedBuffer(username + ':' + password, Encoding.UTF8); + + IByteBuffer authzBase64; + try + { + authzBase64 = Base64.Encode(authz, false); + } + finally + { + authz.Release(); + } + + try + { + _authorization = new AsciiString("Basic " + authzBase64.ToString(Encoding.ASCII)); + } + finally + { + authzBase64.Release(); + } + + _outboundHeaders = headers; + _ignoreDefaultPortsInConnectHostHeader = ignoreDefaultPortsInConnectHostHeader; + } + + public override string Protocol => PROTOCOL; + + public override string AuthScheme => _authorization != null ? AuthBasic : AuthNone; + + public string Username => _username; + + public string Password => _password; + + protected override void AddCodec(IChannelHandlerContext ctx) + { + IChannelPipeline p = ctx.Channel.Pipeline; + string name = ctx.Name; + p.AddBefore(name, null, _codecWrapper); + } + + protected override void RemoveEncoder(IChannelHandlerContext ctx) + { + _codecWrapper._codec.RemoveOutboundHandler(); + } + + protected override void RemoveDecoder(IChannelHandlerContext ctx) + { + _codecWrapper._codec.RemoveInboundHandler(); + } + + protected override object NewInitialMessage(IChannelHandlerContext ctx) + { + if (!TryParseEndpoint(DestinationAddress, out string hostnameString, out int port)) + { + throw new NotSupportedException($"Endpoint {DestinationAddress} is not supported as http proxy destination"); + } + + string url = hostnameString + ":" + port; + string hostHeader = _ignoreDefaultPortsInConnectHostHeader && (port == 80 || port == 443) ? hostnameString : url; + + IFullHttpRequest req = new DefaultFullHttpRequest(DotNetty.Codecs.Http.HttpVersion.Http11, HttpMethod.Connect, url, Unpooled.Empty, false); + + req.Headers.Set(HttpHeaderNames.Host, hostHeader); + + if (_authorization != null) + { + req.Headers.Set(HttpHeaderNames.ProxyAuthorization, _authorization); + } + + if (_outboundHeaders != null) + { + req.Headers.Add(_outboundHeaders); + } + + return req; + } + + protected override bool HandleResponse(IChannelHandlerContext ctx, object response) + { + if (response is IHttpResponse) + { + if (_status != null) + { + throw new HttpProxyConnectException(ExceptionMessage("too many responses"), /*headers=*/ null); + } + + IHttpResponse res = (IHttpResponse)response; + _status = res.Status; + _inboundHeaders = res.Headers; + } + + bool finished = response is ILastHttpContent; + if (finished) + { + if (_status == null) + { + throw new HttpProxyConnectException(ExceptionMessage("missing response"), _inboundHeaders); + } + + if (_status.Code != 200) + { + throw new HttpProxyConnectException(ExceptionMessage("status: " + _status), _inboundHeaders); + } + } + + return finished; + } + + /// + /// Formats the host string of an address so it can be used for computing an HTTP component + /// such as a URL or a Host header + /// + /// addr the address + /// + /// + /// the formatted String + static bool TryParseEndpoint(EndPoint addr, out string hostnameString, out int port) + { + hostnameString = null; + port = 0; + + if (addr is DnsEndPoint eDns) + { + hostnameString = eDns.Host; + port = eDns.Port; + return true; + } + else if (addr is IPEndPoint eIp) + { + port = eIp.Port; + switch (addr.AddressFamily) + { + case AddressFamily.InterNetwork: + hostnameString = eIp.Address.ToString(); + return true; + + case AddressFamily.InterNetworkV6: + hostnameString = $"[{eIp.Address}]"; + return true; + + default: + return false; + } + } + else + { + return false; + } + } + + private sealed class HttpClientCodecWrapper : ChannelDuplexHandler + { + internal readonly HttpClientCodec _codec = new HttpClientCodec(); + + public override void HandlerAdded(IChannelHandlerContext context) + => _codec.HandlerAdded(context); + + public override void HandlerRemoved(IChannelHandlerContext context) + => _codec.HandlerRemoved(context); + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) + => _codec.ExceptionCaught(context, exception); + + public override void ChannelRegistered(IChannelHandlerContext context) + => _codec.ChannelRegistered(context); + + public override void ChannelUnregistered(IChannelHandlerContext context) + => _codec.ChannelUnregistered(context); + + public override void ChannelActive(IChannelHandlerContext context) + => _codec.ChannelActive(context); + + public override void ChannelInactive(IChannelHandlerContext context) + => _codec.ChannelInactive(context); + + public override void ChannelRead(IChannelHandlerContext context, object message) + => _codec.ChannelRead(context, message); + + public override void ChannelReadComplete(IChannelHandlerContext context) + => _codec.ChannelReadComplete(context); + + public override void UserEventTriggered(IChannelHandlerContext context, object evt) + => _codec.UserEventTriggered(context, evt); + + public override void ChannelWritabilityChanged(IChannelHandlerContext context) + => _codec.ChannelWritabilityChanged(context); + + public override Task BindAsync(IChannelHandlerContext context, EndPoint localAddress) + => _codec.BindAsync(context, localAddress); + + public override Task ConnectAsync(IChannelHandlerContext context, EndPoint remoteAddress, EndPoint localAddress) + => _codec.ConnectAsync(context, remoteAddress, localAddress); + + public override void Disconnect(IChannelHandlerContext context, IPromise promise) + => _codec.Disconnect(context, promise); + + public override void Close(IChannelHandlerContext context, IPromise promise) + => _codec.Close(context, promise); + + public override void Deregister(IChannelHandlerContext context, IPromise promise) + => _codec.Deregister(context, promise); + + public override void Read(IChannelHandlerContext context) + => _codec.Read(context); + + public override void Write(IChannelHandlerContext context, object message, IPromise promise) + => _codec.Write(context, message, promise); + + public override void Flush(IChannelHandlerContext context) + => _codec.Flush(context); + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Handlers.Proxy/ProxyConnectException.cs b/src/DotNetty.Handlers.Proxy/ProxyConnectException.cs new file mode 100644 index 000000000..232b2ac3a --- /dev/null +++ b/src/DotNetty.Handlers.Proxy/ProxyConnectException.cs @@ -0,0 +1,41 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * Copyright (c) 2020 The Dotnetty-Span-Fork Project (cuteant@outlook.com) All rights reserved. + * + * https://github.com/cuteant/dotnetty-span-fork + * + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + */ + +using System; +using DotNetty.Transport.Channels; + +namespace DotNetty.Handlers.Proxy +{ + public class ProxyConnectException : ConnectException + { + public ProxyConnectException(string msg) : base(msg, null) + { } + + public ProxyConnectException(Exception cause) :base(null, cause) + { + } + + public ProxyConnectException(string message, Exception innerException) : base(message, innerException) + { + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Handlers.Proxy/ProxyConnectionEvent.cs b/src/DotNetty.Handlers.Proxy/ProxyConnectionEvent.cs new file mode 100644 index 000000000..c8b2bb81b --- /dev/null +++ b/src/DotNetty.Handlers.Proxy/ProxyConnectionEvent.cs @@ -0,0 +1,84 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * Copyright (c) 2020 The Dotnetty-Span-Fork Project (cuteant@outlook.com) All rights reserved. + * + * https://github.com/cuteant/dotnetty-span-fork + * + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + */ + +using System; +using System.Net; +using System.Text; + +namespace DotNetty.Handlers.Proxy +{ + /// + /// Creates a new event that indicates a successful connection attempt to the destination address. + /// + public sealed class ProxyConnectionEvent + { + private string _strVal; + + public ProxyConnectionEvent(string protocol, string authScheme, EndPoint proxyAddress, + EndPoint destinationAddress) + { + Protocol = protocol ?? throw new ArgumentNullException(nameof(protocol)); + AuthScheme = authScheme ?? throw new ArgumentNullException(nameof(authScheme)); + ProxyAddress = proxyAddress ?? throw new ArgumentNullException(nameof(proxyAddress)); + DestinationAddress = destinationAddress ?? throw new ArgumentNullException(nameof(destinationAddress)); + } + + /// + ///Returns the name of the proxy protocol in use. + /// + public string Protocol { get; } + + /// + /// Returns the name of the authentication scheme in use. + /// + public string AuthScheme { get; } + + /// + /// Returns the address of the proxy server. + /// + public EndPoint ProxyAddress { get; } + + /// + /// Returns the address of the destination. + /// + public EndPoint DestinationAddress { get; } + + public override string ToString() + { + if (_strVal != null) return _strVal; + + var buf = new StringBuilder(128) + .Append(typeof(ProxyConnectionEvent).Name) + .Append('(') + .Append(Protocol) + .Append(", ") + .Append(AuthScheme) + .Append(", ") + .Append(ProxyAddress) + .Append(" => ") + .Append(DestinationAddress) + .Append(')'); + + return _strVal = buf.ToString(); + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Handlers.Proxy/ProxyHandler.cs b/src/DotNetty.Handlers.Proxy/ProxyHandler.cs new file mode 100644 index 000000000..5ddbf8f3d --- /dev/null +++ b/src/DotNetty.Handlers.Proxy/ProxyHandler.cs @@ -0,0 +1,515 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * Copyright (c) 2020 The Dotnetty-Span-Fork Project (cuteant@outlook.com) All rights reserved. + * + * https://github.com/cuteant/dotnetty-span-fork + * + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + */ + +using System; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using DotNetty.Common.Concurrency; +using DotNetty.Common.Internal.Logging; +using DotNetty.Common.Utilities; +using DotNetty.Transport.Channels; + +namespace DotNetty.Handlers.Proxy +{ + public abstract class ProxyHandler : ChannelDuplexHandler + { + static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + /// + /// The default connect timeout: 10 seconds. + /// + static readonly TimeSpan DefaultConnectTimeout = TimeSpan.FromMilliseconds(10000); + + /// + /// A string that signifies 'no authentication' or 'anonymous'. + /// + protected const string AuthNone = "none"; + + private readonly EndPoint _proxyAddress; + private readonly TaskCompletionSource _connectPromise = new TaskCompletionSource(); + + private volatile EndPoint _destinationAddress; + private TimeSpan _connectTimeout = DefaultConnectTimeout; + + private IChannelHandlerContext _ctx; + private PendingWriteQueue _pendingWrites; + private bool _finished; + private bool _suppressChannelReadComplete; + private bool _flushedPrematurely; + + private IScheduledTask _connectTimeoutFuture; + + protected ProxyHandler(EndPoint proxyAddress) + { + _proxyAddress = proxyAddress ?? throw new ArgumentNullException(nameof(proxyAddress)); + } + + /// + /// Returns the name of the proxy protocol in use. + /// + public abstract string Protocol { get; } + + /// + /// Returns the name of the authentication scheme in use. + /// + public abstract string AuthScheme { get; } + + /// + /// Returns the address of the proxy server. + /// + public EndPoint ProxyAddress => _proxyAddress; + + /// + /// Returns the address of the destination to connect to via the proxy server. + /// + public EndPoint DestinationAddress => _destinationAddress; + + /// + /// Returns {@code true} if and only if the connection to the destination has been established successfully. + /// + public bool Connected => _connectPromise.Task.Status == TaskStatus.RanToCompletion; + + /// + /// Returns a {@link Future} that is notified when the connection to the destination has been established + /// or the connection attempt has failed. + /// + public Task ConnectFuture => _connectPromise.Task; + + /// + /// Connect timeout. If the connection attempt to the destination does not finish within + /// the timeout, the connection attempt will be failed. + /// + public TimeSpan ConnectTimeout + { + get => _connectTimeout; + set + { + if (value <= TimeSpan.Zero) + { + value = TimeSpan.Zero; + } + + _connectTimeout = value; + } + } + + public override void HandlerAdded(IChannelHandlerContext ctx) + { + _ctx = ctx; + + AddCodec(ctx); + + if (ctx.Channel.IsActive) + { + // channelActive() event has been fired already, which means channelActive() will + // not be invoked. We have to initialize here instead. + SendInitialMessage(ctx); + } + else + { + // channelActive() event has not been fired yet. channelOpen() will be invoked + // and initialization will occur there. + } + } + + /// + /// Adds the codec handlers required to communicate with the proxy server. + /// + protected abstract void AddCodec(IChannelHandlerContext ctx); + + /// + /// Removes the encoders added in {@link #addCodec(IChannelHandlerContext)}. + /// + protected abstract void RemoveEncoder(IChannelHandlerContext ctx); + + /// + /// Removes the decoders added in {@link #addCodec(IChannelHandlerContext)}. + /// + protected abstract void RemoveDecoder(IChannelHandlerContext ctx); + + public override Task ConnectAsync(IChannelHandlerContext context, EndPoint remoteAddress, EndPoint localAddress) + { + if (_destinationAddress != null) + { + return TaskUtil.FromException(new ConnectionPendingException()); + } + + _destinationAddress = remoteAddress; + + return _ctx.ConnectAsync(_proxyAddress, localAddress); + } + + public override void ChannelActive(IChannelHandlerContext ctx) + { + SendInitialMessage(ctx); + ctx.FireChannelActive(); + } + + /// + /// Sends the initial message to be sent to the proxy server. This method also starts a timeout task which marks + /// the {@link #connectPromise} as failure if the connection attempt does not success within the timeout. + /// + void SendInitialMessage(IChannelHandlerContext ctx) + { + var connectTimeout = _connectTimeout; + if (connectTimeout > TimeSpan.Zero) + { + _connectTimeoutFuture = ctx.Executor.Schedule(ConnectTimeout, connectTimeout); + } + + object initialMessage = NewInitialMessage(ctx); + if (initialMessage != null) + { + SendToProxyServer(initialMessage); + } + + ReadIfNeeded(ctx); + + void ConnectTimeout() + { + if (!_connectPromise.Task.IsCompleted) + { + SetConnectFailure(new ProxyConnectException(ExceptionMessage("timeout"))); + } + } + } + + /// + /// Returns a new message that is sent at first time when the connection to the proxy server has been established. + /// + /// + /// the initial message, or {@code null} if the proxy server is expected to send the first message instead + protected abstract object NewInitialMessage(IChannelHandlerContext ctx); + + /// + /// Sends the specified message to the proxy server. Use this method to send a response to the proxy server in + /// {@link #handleResponse(IChannelHandlerContext, object)}. + /// + protected void SendToProxyServer(object msg) + { + _ctx.WriteAndFlushAsync(msg).ContinueWith(OnCompleted, TaskContinuationOptions.NotOnRanToCompletion | TaskContinuationOptions.ExecuteSynchronously); + + void OnCompleted(Task future) + { + SetConnectFailure(future.Exception); + } + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + if (_finished) + { + ctx.FireChannelInactive(); + } + else + { + // Disconnected before connected to the destination. + SetConnectFailure(new ProxyConnectException(ExceptionMessage("disconnected"))); + } + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + if (_finished) + { + ctx.FireExceptionCaught(cause); + } + else + { + // Exception was raised before the connection attempt is finished. + SetConnectFailure(cause); + } + } + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + if (_finished) + { + // Received a message after the connection has been established; pass through. + _suppressChannelReadComplete = false; + ctx.FireChannelRead(msg); + } + else + { + _suppressChannelReadComplete = true; + Exception cause = null; + try + { + bool done = HandleResponse(ctx, msg); + if (done) + { + SetConnectSuccess(); + } + } + catch (Exception t) + { + cause = t; + } + finally + { + ReferenceCountUtil.Release(msg); + if (cause != null) + { + SetConnectFailure(cause); + } + } + } + } + + /// + /// expected from the proxy server + /// + /// + /// + /// + /// {@code true} if the connection to the destination has been established, + /// {@code false} if the connection to the destination has not been established and more messages are expected from the proxy server + /// + protected abstract bool HandleResponse(IChannelHandlerContext ctx, object response); + + void SetConnectSuccess() + { + _finished = true; + + CancelConnectTimeoutFuture(); + + if (!_connectPromise.Task.IsCompleted) + { + bool removedCodec = true; + + removedCodec &= SafeRemoveEncoder(); + + _ctx.FireUserEventTriggered( + new ProxyConnectionEvent(Protocol, AuthScheme, _proxyAddress, _destinationAddress)); + + removedCodec &= SafeRemoveDecoder(); + + if (removedCodec) + { + WritePendingWrites(); + + if (_flushedPrematurely) + { + _ctx.Flush(); + } + + _connectPromise.TrySetResult(_ctx.Channel); + } + else + { + // We are at inconsistent state because we failed to remove all codec handlers. + Exception cause = new ProxyConnectException( + "failed to remove all codec handlers added by the proxy handler; bug?"); + FailPendingWritesAndClose(cause); + } + } + } + + bool SafeRemoveDecoder() + { + try + { + RemoveDecoder(_ctx); + return true; + } + catch (Exception e) + { + Logger.Warn("Failed to remove proxy decoders:", e); + } + + return false; + } + + bool SafeRemoveEncoder() + { + try + { + RemoveEncoder(_ctx); + return true; + } + catch (Exception e) + { + Logger.Warn("Failed to remove proxy encoders:", e); + } + + return false; + } + + void SetConnectFailure(Exception cause) + { + _finished = true; + + CancelConnectTimeoutFuture(); + + if (!_connectPromise.Task.IsCompleted) + { + if (!(cause is ProxyConnectException)) + { + cause = new ProxyConnectException(ExceptionMessage(cause.ToString()), cause); + } + + SafeRemoveDecoder(); + SafeRemoveEncoder(); + FailPendingWritesAndClose(cause); + } + } + + void FailPendingWritesAndClose(Exception cause) + { + FailPendingWrites(cause); + + _connectPromise.TrySetException(cause); + + _ctx.FireExceptionCaught(cause); + + _ctx.CloseAsync(); + } + + void CancelConnectTimeoutFuture() + { + if (_connectTimeoutFuture != null) + { + _connectTimeoutFuture.Cancel(); + _connectTimeoutFuture = null; + } + } + + /// + /// Decorates the specified exception message with the common information such as the current protocol, + /// authentication scheme, proxy address, and destination address. + /// + protected string ExceptionMessage(string msg) + { + if (msg == null) + { + msg = ""; + } + + StringBuilder buf = new StringBuilder(128 + msg.Length) + .Append(Protocol) + .Append(", ") + .Append(AuthScheme) + .Append(", ") + .Append(_proxyAddress) + .Append(" => ") + .Append(_destinationAddress); + + if (!string.IsNullOrEmpty(msg)) + { + buf.Append(", ").Append(msg); + } + + return buf.ToString(); + } + + public override void ChannelReadComplete(IChannelHandlerContext ctx) + { + if (_suppressChannelReadComplete) + { + _suppressChannelReadComplete = false; + + ReadIfNeeded(ctx); + } + else + { + ctx.FireChannelReadComplete(); + } + } + + public override void Write(IChannelHandlerContext context, object message, IPromise promise) + { + if (_finished) + { + WritePendingWrites(); + base.Write(context, message, promise); + } + else + { + AddPendingWrite(_ctx, message, promise); + } + } + + public override void Flush(IChannelHandlerContext context) + { + if (_finished) + { + WritePendingWrites(); + _ctx.Flush(); + } + else + { + _flushedPrematurely = true; + } + } + + static void ReadIfNeeded(IChannelHandlerContext ctx) + { + if (!ctx.Channel.Configuration.IsAutoRead) + { + ctx.Read(); + } + } + + void WritePendingWrites() + { + if (_pendingWrites != null) + { + _pendingWrites.RemoveAndWriteAllAsync(); + _pendingWrites = null; + } + } + + void FailPendingWrites(Exception cause) + { + if (_pendingWrites != null) + { + _pendingWrites.RemoveAndFailAll(cause); + _pendingWrites = null; + } + } + + void AddPendingWrite(IChannelHandlerContext ctx, object msg, IPromise promise) + { + PendingWriteQueue pendingWrites = _pendingWrites; + if (pendingWrites == null) + { + _pendingWrites = pendingWrites = new PendingWriteQueue(ctx); + } + + pendingWrites.Add(msg, promise); + } + + protected IEventExecutor Executor + { + get + { + if (_ctx == null) + { + throw new Exception("Should not reach here"); + } + + return _ctx.Executor; + } + } + } +} \ No newline at end of file diff --git a/src/DotNetty.Transport/Bootstrapping/NoopNameResolver.cs b/src/DotNetty.Transport/Bootstrapping/NoopNameResolver.cs new file mode 100644 index 000000000..57ae1a8a3 --- /dev/null +++ b/src/DotNetty.Transport/Bootstrapping/NoopNameResolver.cs @@ -0,0 +1,42 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * Copyright (c) The DotNetty Project (Microsoft). All rights reserved. + * + * https://github.com/azure/dotnetty + * + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + * + * Copyright (c) 2020 The Dotnetty-Span-Fork Project (cuteant@outlook.com) All rights reserved. + * + * https://github.com/cuteant/dotnetty-span-fork + * + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + */ + +using System.Net; +using System.Threading.Tasks; + +namespace DotNetty.Transport.Bootstrapping +{ + public class NoopNameResolver : INameResolver + { + public static readonly NoopNameResolver Instance = new NoopNameResolver(); + + public bool IsResolved(EndPoint address) => true; + + public Task ResolveAsync(EndPoint address) => Task.FromResult(address); + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Proxy.Tests/DotNetty.Handlers.Proxy.Tests.csproj b/test/DotNetty.Handlers.Proxy.Tests/DotNetty.Handlers.Proxy.Tests.csproj new file mode 100644 index 000000000..d7390cef8 --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/DotNetty.Handlers.Proxy.Tests.csproj @@ -0,0 +1,25 @@ + + + + + $(StandardTestTfms) + DotNetty.Handlers.Proxy.Tests + DotNetty.Handlers.Proxy.Tests + false + + + + + + + + + + + + + + + + + diff --git a/test/DotNetty.Handlers.Proxy.Tests/HttpProxyHandlerTest.cs b/test/DotNetty.Handlers.Proxy.Tests/HttpProxyHandlerTest.cs new file mode 100644 index 000000000..b1f564f39 --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/HttpProxyHandlerTest.cs @@ -0,0 +1,277 @@ +using System; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using DotNetty.Codecs.Http; +using DotNetty.Common.Concurrency; +using DotNetty.Common.Utilities; +using DotNetty.Transport.Bootstrapping; +using DotNetty.Transport.Channels; +using DotNetty.Transport.Channels.Embedded; +using DotNetty.Transport.Channels.Local; +using Moq; +using Xunit; +using HttpVersion = DotNetty.Codecs.Http.HttpVersion; + +namespace DotNetty.Handlers.Proxy.Tests +{ + public class HttpProxyHandlerTest + { + [Fact] + public void TestHostname() + { + EndPoint socketAddress = new DnsEndPoint("localhost", 8080); + TestInitialMessage( + socketAddress, + "localhost:8080", + "localhost:8080", + null, + true); + } + + [Fact] + public void TestHostnameUnresolved() + { + EndPoint socketAddress = new DnsEndPoint("localhost", 8080); + TestInitialMessage( + socketAddress, + "localhost:8080", + "localhost:8080", + null, + true); + } + + [Fact] + public void TestHostHeaderWithHttpDefaultPort() + { + EndPoint socketAddress = new DnsEndPoint("localhost", 80); + TestInitialMessage(socketAddress, + "localhost:80", + "localhost:80", null, + false); + } + + [Fact] + public void TestHostHeaderWithHttpDefaultPortIgnored() + { + EndPoint socketAddress = new DnsEndPoint("localhost", 80); + TestInitialMessage( + socketAddress, + "localhost:80", + "localhost", + null, + true); + } + + [Fact] + public void TestHostHeaderWithHttpsDefaultPort() + { + EndPoint socketAddress = new DnsEndPoint("localhost", 443); + TestInitialMessage( + socketAddress, + "localhost:443", + "localhost:443", + null, + false); + } + + [Fact] + public void TestHostHeaderWithHttpsDefaultPortIgnored() + { + EndPoint socketAddress = new DnsEndPoint("localhost", 443); + TestInitialMessage( + socketAddress, + "localhost:443", + "localhost", + null, + true); + } + + [Fact] + public void TestIpv6() + { + EndPoint socketAddress = new IPEndPoint(IPAddress.Parse("::1"), 8080); + TestInitialMessage( + socketAddress, + "[::1]:8080", + "[::1]:8080", + null, + true); + } + + [Fact] + public void TestIpv6Unresolved() + { + EndPoint socketAddress = new DnsEndPoint("foo.bar", 8080); + TestInitialMessage( + socketAddress, + "foo.bar:8080", + "foo.bar:8080", + null, + true); + } + + [Fact] + public void TestIpv4() + { + EndPoint socketAddress = new IPEndPoint(IPAddress.Parse("10.0.0.1"), 8080); + TestInitialMessage(socketAddress, + "10.0.0.1:8080", + "10.0.0.1:8080", + null, + true); + } + + [Fact] + public void TestIpv4Unresolved() + { + EndPoint socketAddress = new DnsEndPoint("10.0.0.1", 8080); + TestInitialMessage( + socketAddress, + "10.0.0.1:8080", + "10.0.0.1:8080", + null, + true); + } + + [Fact] + public void TestCustomHeaders() + { + EndPoint socketAddress = new DnsEndPoint("10.0.0.1", 8080); + TestInitialMessage( + socketAddress, + "10.0.0.1:8080", + "10.0.0.1:8080", + new DefaultHttpHeaders() + .Add(AsciiString.Of("CUSTOM_HEADER"), "CUSTOM_VALUE1") + .Add(AsciiString.Of("CUSTOM_HEADER"), "CUSTOM_VALUE2"), + true); + } + + [Fact] + public void TestExceptionDuringConnect() + { + IEventLoopGroup group = null; + IChannel serverChannel = null; + IChannel clientChannel = null; + try + { + group = new DefaultEventLoopGroup(1); + var addr = new LocalAddress("a"); + var exception = new AtomicReference(); + var sf = + new ServerBootstrap().Channel().Group(group).ChildHandler( + new ActionChannelInitializer(ch => + { + ch.Pipeline.AddFirst(new HttpResponseEncoder()); + var response = new DefaultFullHttpResponse( + HttpVersion.Http11, + HttpResponseStatus.BadGateway); + response.Headers.Add(AsciiString.Of("name"), "value"); + response.Headers.Add(HttpHeaderNames.ContentLength, "0"); + ch.WriteAndFlushAsync(response); + } + )).BindAsync(addr); + serverChannel = sf.Result; + + var cf = new Bootstrap().Channel().Group(group).Handler( + new ActionChannelInitializer(ch => + { + ch.Pipeline.AddFirst(new HttpProxyHandler(addr)); + ch.Pipeline.AddLast(new ErrorCaptureHandler(exception)); + })).ConnectAsync(new DnsEndPoint("localhost", 1234)); + + clientChannel = cf.Result; + clientChannel.CloseAsync().Wait(); + + Assert.True(exception.Value is HttpProxyConnectException); + var actual = (HttpProxyConnectException) exception.Value; + Assert.NotNull(actual.Headers); + Assert.Equal("value", actual.Headers.GetAsString(AsciiString.Of("name"))); + } + finally + { + if (clientChannel != null) clientChannel.CloseAsync(); + if (serverChannel != null) serverChannel.CloseAsync(); + if (group != null) @group.ShutdownGracefullyAsync().Wait(); + } + } + + private static void TestInitialMessage(EndPoint socketAddress, + string expectedUrl, + string expectedHostHeader, + HttpHeaders headers, + bool ignoreDefaultPortsInConnectHostHeader) + { + EndPoint proxyAddress = new IPEndPoint(IPAddress.Loopback, 8080); + + var promise = new TaskCompletionSource(); + + var channel = new Mock(); + + var pipeline = new Mock(); + channel.Setup(c => c.Pipeline).Returns(pipeline.Object); + + var config = new Mock(); + channel.SetupGet(c => c.Configuration).Returns(config.Object); + + var ctx = new Mock(); + ctx.SetupGet(c => c.Channel).Returns(channel.Object); + var executor = new Mock(); + ctx.Setup(c => c.Executor).Returns(executor.Object); + ctx.Setup(c => c.ConnectAsync(proxyAddress, null)).Returns(promise.Task); + + var handler = new HttpProxyHandler( + new IPEndPoint(IPAddress.Loopback, 8080), + headers, + ignoreDefaultPortsInConnectHostHeader); + + handler.HandlerAdded(ctx.Object); + + handler.ConnectAsync(ctx.Object, socketAddress, null); + ctx.Verify(c => c.ConnectAsync(proxyAddress, null), Times.Once); + + handler.ChannelActive(ctx.Object); + ctx.Verify(c => c.WriteAndFlushAsync(It.Is(request => + request.ProtocolVersion.Equals(HttpVersion.Http11) + && request.Uri == expectedUrl + && request.Headers.GetAsString(HttpHeaderNames.Host) == expectedHostHeader + && (headers == null || headers.Names().All(name => string.Join(",", headers.GetAllAsString(name)).Equals(string.Join(",",request.Headers.GetAllAsString(name)))))) + )); + } + + [Fact] + public void TestHttpClientCodecIsInvisible() + { + EmbeddedChannel channel = + new InactiveEmbeddedChannel(new HttpProxyHandler(new IPEndPoint(IPAddress.Loopback, 8080))); + Assert.NotNull(channel.Pipeline.Get()); + Assert.Null(channel.Pipeline.Get()); + } + + class ErrorCaptureHandler : ChannelHandlerAdapter + { + private readonly AtomicReference _exception; + + public ErrorCaptureHandler(AtomicReference exception) + { + _exception = exception; + } + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) + { + _exception.Value = exception; + } + } + + private class InactiveEmbeddedChannel : EmbeddedChannel + { + public InactiveEmbeddedChannel(params IChannelHandler[] handlers) + : base(handlers) + { + } + + public override bool IsActive => false; + } + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Proxy.Tests/HttpProxyServer.cs b/test/DotNetty.Handlers.Proxy.Tests/HttpProxyServer.cs new file mode 100644 index 000000000..198ae0ebc --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/HttpProxyServer.cs @@ -0,0 +1,161 @@ +using System.Net; +using System.Text; +using DotNetty.Buffers; +using DotNetty.Codecs; +using DotNetty.Codecs.Base64; +using DotNetty.Codecs.Http; +using DotNetty.Transport.Channels; +using DotNetty.Transport.Channels.Sockets; +using Xunit; +using HttpVersion = DotNetty.Codecs.Http.HttpVersion; + +namespace DotNetty.Handlers.Proxy.Tests +{ + internal sealed class HttpProxyServer : ProxyServer + { + internal HttpProxyServer(bool useSsl, TestMode testMode, EndPoint destination) + : base(useSsl, testMode, destination) + { + } + + internal HttpProxyServer(bool useSsl, TestMode testMode, EndPoint destination, string username, string password) + : base(useSsl, testMode, destination, username, password) + { + } + + protected override void Configure(ISocketChannel ch) + { + var p = ch.Pipeline; + switch (TestMode) + { + case TestMode.Intermediary: + p.AddLast(new HttpServerCodec()); + p.AddLast(new HttpObjectAggregator(1)); + p.AddLast(new HttpIntermediaryHandler(this)); + break; + case TestMode.Terminal: + p.AddLast(new HttpServerCodec()); + p.AddLast(new HttpObjectAggregator(1)); + p.AddLast(new HttpTerminalHandler(this)); + break; + case TestMode.Unresponsive: + p.AddLast(UnresponsiveHandler.Instance); + break; + } + } + + bool Authenticate(IChannelHandlerContext ctx, IFullHttpRequest req) + { + Assert.Equal(req.Method, HttpMethod.Connect); + + if (TestMode != TestMode.Intermediary) + ctx.Pipeline.AddBefore(ctx.Name, "lineDecoder", new LineBasedFrameDecoder(64, false, true)); + + ctx.Pipeline.Remove(); + ctx.Pipeline.Get().RemoveInboundHandler(); + + var authzSuccess = false; + if (Username != null) + { + if (req.Headers.TryGet(HttpHeaderNames.ProxyAuthorization, out var authz)) + { + var authzParts = authz.ToString().Split(' '); + var authzBuf64 = Unpooled.CopiedBuffer(authzParts[1], Encoding.ASCII); + var authzBuf = Base64.Decode(authzBuf64); + + var expectedAuthz = Username + ':' + Password; + authzSuccess = "Basic".Equals(authzParts[0]) && + expectedAuthz.Equals(authzBuf.ToString(Encoding.ASCII)); + + authzBuf64.Release(); + authzBuf.Release(); + } + } + else + { + authzSuccess = true; + } + + return authzSuccess; + } + + private sealed class HttpIntermediaryHandler : IntermediaryHandler + { + private readonly HttpProxyServer _server; + + public HttpIntermediaryHandler(HttpProxyServer server) + : base(server) + { + _server = server; + } + + protected override EndPoint IntermediaryDestination { get; set; } + + protected override bool HandleProxyProtocol(IChannelHandlerContext ctx, object msg) + { + var req = (IFullHttpRequest) msg; + IFullHttpResponse res; + if (!_server.Authenticate(ctx, req)) + { + res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Unauthorized); + res.Headers.Set(HttpHeaderNames.ContentLength, 0); + } + else + { + res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + var uri = req.Uri; + var lastColonPos = uri.LastIndexOf(':'); + Assert.True(lastColonPos > 0); + IntermediaryDestination = new DnsEndPoint(uri.Substring(0, lastColonPos), + int.Parse(uri.Substring(lastColonPos + 1))); + } + + ctx.WriteAsync(res); + ctx.Pipeline.Get().RemoveOutboundHandler(); + return true; + } + } + + private sealed class HttpTerminalHandler : TerminalHandler + { + private readonly HttpProxyServer _server; + + public HttpTerminalHandler(HttpProxyServer server) + : base(server) + { + _server = server; + } + + protected override bool HandleProxyProtocol(IChannelHandlerContext ctx, object msg) + { + var req = (IFullHttpRequest) msg; + IFullHttpResponse res; + var sendGreeting = false; + + if (!_server.Authenticate(ctx, req)) + { + res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Unauthorized); + res.Headers.Set(HttpHeaderNames.ContentLength, 0); + } + else if (!req.Uri.Equals(((DnsEndPoint) _server.Destination).Host + ':' + + ((DnsEndPoint) _server.Destination).Port)) + { + res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.Forbidden); + res.Headers.Set(HttpHeaderNames.ContentLength, 0); + } + else + { + res = new DefaultFullHttpResponse(HttpVersion.Http11, HttpResponseStatus.OK); + sendGreeting = true; + } + + ctx.WriteAsync(res); + ctx.Pipeline.Get().RemoveOutboundHandler(); + + if (sendGreeting) ctx.WriteAsync(Unpooled.CopiedBuffer("0\n", Encoding.ASCII)); + + return true; + } + } + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Proxy.Tests/ProxyHandlerTest.cs b/test/DotNetty.Handlers.Proxy.Tests/ProxyHandlerTest.cs new file mode 100644 index 000000000..e1de55f40 --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/ProxyHandlerTest.cs @@ -0,0 +1,707 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using DotNetty.Buffers; +using DotNetty.Codecs; +using DotNetty.Common.Internal.Logging; +using DotNetty.Common.Utilities; +using DotNetty.Handlers.Tls; +using DotNetty.Tests.Common; +using DotNetty.Transport.Bootstrapping; +using DotNetty.Transport.Channels; +using DotNetty.Transport.Channels.Sockets; +using Microsoft.Extensions.Logging; +using Xunit; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace DotNetty.Handlers.Proxy.Tests +{ + public class ProxyHandlerTest : TestBase, IClassFixture, IDisposable + { + private class ProxyHandlerTestFixture : IDisposable + { + public void Dispose() + { + StopServers(); + } + } + + private static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + private static readonly EndPoint DESTINATION = new DnsEndPoint("destination.com", 42); + private static readonly EndPoint BAD_DESTINATION = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 5); + private static readonly string USERNAME = "testUser"; + private static readonly string PASSWORD = "testPassword"; + private static readonly string BAD_USERNAME = "badUser"; + private static readonly string BAD_PASSWORD = "badPassword"; + + internal static readonly IEventLoopGroup Group = new DefaultEventLoopGroup(3); + + private static readonly ProxyServer DeadHttpProxy = new HttpProxyServer(false, TestMode.Unresponsive, null); + private static readonly ProxyServer InterHttpProxy = new HttpProxyServer(false, TestMode.Intermediary, null); + private static readonly ProxyServer AnonHttpProxy = new HttpProxyServer(false, TestMode.Terminal, DESTINATION); + + private static readonly ProxyServer HttpProxy = + new HttpProxyServer(false, TestMode.Terminal, DESTINATION, USERNAME, PASSWORD); + + private static readonly ProxyServer DeadHttpsProxy = new HttpProxyServer(true, TestMode.Unresponsive, null); + private static readonly ProxyServer InterHttpsProxy = new HttpProxyServer(true, TestMode.Intermediary, null); + private static readonly ProxyServer AnonHttpsProxy = new HttpProxyServer(true, TestMode.Terminal, DESTINATION); + + private static readonly ProxyServer HttpsProxy = + new HttpProxyServer(true, TestMode.Terminal, DESTINATION, USERNAME, PASSWORD); + + /* + static readonly ProxyServer deadSocks4Proxy = new Socks4ProxyServer(false, TestMode.UNRESPONSIVE, null); + static readonly ProxyServer interSocks4Proxy = new Socks4ProxyServer(false, TestMode.INTERMEDIARY, null); + static readonly ProxyServer anonSocks4Proxy = new Socks4ProxyServer(false, TestMode.TERMINAL, DESTINATION); + static readonly ProxyServer socks4Proxy = new Socks4ProxyServer(false, TestMode.TERMINAL, DESTINATION, USERNAME); + + static readonly ProxyServer deadSocks5Proxy = new Socks5ProxyServer(false, TestMode.UNRESPONSIVE, null); + static readonly ProxyServer interSocks5Proxy = new Socks5ProxyServer(false, TestMode.INTERMEDIARY, null); + static readonly ProxyServer anonSocks5Proxy = new Socks5ProxyServer(false, TestMode.TERMINAL, DESTINATION); + static readonly ProxyServer socks5Proxy = + new Socks5ProxyServer(false, TestMode.TERMINAL, DESTINATION, USERNAME, PASSWORD);*/ + + private static readonly IEnumerable AllProxies = new[] + { + DeadHttpProxy, InterHttpProxy, AnonHttpProxy, HttpProxy, + DeadHttpsProxy, InterHttpsProxy, AnonHttpsProxy, HttpsProxy + //deadSocks4Proxy, interSocks4Proxy, anonSocks4Proxy, socks4Proxy, + //deadSocks5Proxy, interSocks5Proxy, anonSocks5Proxy, socks5Proxy + }; + + // set to non-zero value in case you need predictable shuffling of test cases + // look for "Seed used: *" debug message in test logs + private static readonly int ReproducibleSeed = 0; + + public ProxyHandlerTest(ITestOutputHelper output) : base(output) + { + ClearServerExceptions(); + } + + [Theory] + [MemberData(nameof(CreateTestItems))] + public void Test(TestItem item) + { + item.Test(); + } + + public void Dispose() + { + foreach (var p in AllProxies) p.CheckExceptions(); + } + + public static List CreateTestItems() + { + var items = new List + { + // HTTP ------------------------------------------------------- + + new SuccessTestItem( + "Anonymous HTTP proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + new HttpProxyHandler(AnonHttpProxy.Address)), + + new SuccessTestItem( + "Anonymous HTTP proxy: successful connection, AUTO_READ off", + DESTINATION, + false, + new HttpProxyHandler(AnonHttpProxy.Address)), + + new FailureTestItem( + "Anonymous HTTP proxy: rejected connection", + BAD_DESTINATION, "status: 403", + new HttpProxyHandler(AnonHttpProxy.Address)), + + new FailureTestItem( + "HTTP proxy: rejected anonymous connection", + DESTINATION, "status: 401", + new HttpProxyHandler(HttpProxy.Address)), + + new SuccessTestItem( + "HTTP proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + new HttpProxyHandler(HttpProxy.Address, USERNAME, PASSWORD)), + + new SuccessTestItem( + "HTTP proxy: successful connection, AUTO_READ off", + DESTINATION, + false, + new HttpProxyHandler(HttpProxy.Address, USERNAME, PASSWORD)), + + new FailureTestItem( + "HTTP proxy: rejected connection", + BAD_DESTINATION, "status: 403", + new HttpProxyHandler(HttpProxy.Address, USERNAME, PASSWORD)), + + new FailureTestItem( + "HTTP proxy: authentication failure", + DESTINATION, "status: 401", + new HttpProxyHandler(HttpProxy.Address, BAD_USERNAME, BAD_PASSWORD)), + + new TimeoutTestItem( + "HTTP proxy: timeout", + new HttpProxyHandler(DeadHttpProxy.Address)), + + // HTTPS ------------------------------------------------------ + + new SuccessTestItem( + "Anonymous HTTPS proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + CreateClientTlsHandler(), + new HttpProxyHandler(AnonHttpsProxy.Address)), + + new SuccessTestItem( + "Anonymous HTTPS proxy: successful connection, AUTO_READ off", + DESTINATION, + false, + CreateClientTlsHandler(), + new HttpProxyHandler(AnonHttpsProxy.Address)), + + new FailureTestItem( + "Anonymous HTTPS proxy: rejected connection", + BAD_DESTINATION, "status: 403", + CreateClientTlsHandler(), + new HttpProxyHandler(AnonHttpsProxy.Address)), + + new FailureTestItem( + "HTTPS proxy: rejected anonymous connection", + DESTINATION, "status: 401", + CreateClientTlsHandler(), + new HttpProxyHandler(HttpsProxy.Address)), + + new SuccessTestItem( + "HTTPS proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + CreateClientTlsHandler(), + new HttpProxyHandler(HttpsProxy.Address, USERNAME, PASSWORD)), + + new SuccessTestItem( + "HTTPS proxy: successful connection, AUTO_READ off", + DESTINATION, + false, + CreateClientTlsHandler(), + new HttpProxyHandler(HttpsProxy.Address, USERNAME, PASSWORD)), + + new FailureTestItem( + "HTTPS proxy: rejected connection", + BAD_DESTINATION, "status: 403", + CreateClientTlsHandler(), + new HttpProxyHandler(HttpsProxy.Address, USERNAME, PASSWORD)), + + new FailureTestItem( + "HTTPS proxy: authentication failure", + DESTINATION, "status: 401", + CreateClientTlsHandler(), + new HttpProxyHandler(HttpsProxy.Address, BAD_USERNAME, BAD_PASSWORD)), + + new TimeoutTestItem( + "HTTPS proxy: timeout", + CreateClientTlsHandler(), + new HttpProxyHandler(DeadHttpsProxy.Address)) + +/* + // SOCKS4 ----------------------------------------------------- + + new SuccessTestItem( + "Anonymous SOCKS4: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks4ProxyHandler(anonSocks4Proxy.Address)), + + new SuccessTestItem( + "Anonymous SOCKS4: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks4ProxyHandler(anonSocks4Proxy.Address)), + + new FailureTestItem( + "Anonymous SOCKS4: rejected connection", + BAD_DESTINATION, "status: REJECTED_OR_FAILED", + new Socks4ProxyHandler(anonSocks4Proxy.Address)), + + new FailureTestItem( + "SOCKS4: rejected anonymous connection", + DESTINATION, "status: IDENTD_AUTH_FAILURE", + new Socks4ProxyHandler(socks4Proxy.Address)), + + new SuccessTestItem( + "SOCKS4: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks4ProxyHandler(socks4Proxy.Address, USERNAME)), + + new SuccessTestItem( + "SOCKS4: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks4ProxyHandler(socks4Proxy.Address, USERNAME)), + + new FailureTestItem( + "SOCKS4: rejected connection", + BAD_DESTINATION, "status: REJECTED_OR_FAILED", + new Socks4ProxyHandler(socks4Proxy.Address, USERNAME)), + + new FailureTestItem( + "SOCKS4: authentication failure", + DESTINATION, "status: IDENTD_AUTH_FAILURE", + new Socks4ProxyHandler(socks4Proxy.Address, BAD_USERNAME)), + + new TimeoutTestItem( + "SOCKS4: timeout", + new Socks4ProxyHandler(deadSocks4Proxy.Address)), +*/ + // SOCKS5 ----------------------------------------------------- +/* + new SuccessTestItem( + "Anonymous SOCKS5: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks5ProxyHandler(anonSocks5Proxy.Address)), + + new SuccessTestItem( + "Anonymous SOCKS5: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks5ProxyHandler(anonSocks5Proxy.Address)), + + new FailureTestItem( + "Anonymous SOCKS5: rejected connection", + BAD_DESTINATION, "status: FORBIDDEN", + new Socks5ProxyHandler(anonSocks5Proxy.Address)), + + new FailureTestItem( + "SOCKS5: rejected anonymous connection", + DESTINATION, "unexpected authMethod: PASSWORD", + new Socks5ProxyHandler(socks5Proxy.Address)), + + new SuccessTestItem( + "SOCKS5: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks5ProxyHandler(socks5Proxy.Address, USERNAME, PASSWORD)), + + new SuccessTestItem( + "SOCKS5: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks5ProxyHandler(socks5Proxy.Address, USERNAME, PASSWORD)), + + new FailureTestItem( + "SOCKS5: rejected connection", + BAD_DESTINATION, "status: FORBIDDEN", + new Socks5ProxyHandler(socks5Proxy.Address, USERNAME, PASSWORD)), + + new FailureTestItem( + "SOCKS5: authentication failure", + DESTINATION, "authStatus: FAILURE", + new Socks5ProxyHandler(socks5Proxy.Address, BAD_USERNAME, BAD_PASSWORD)), + + new TimeoutTestItem( + "SOCKS5: timeout", + new Socks5ProxyHandler(deadSocks5Proxy.Address)), + + // HTTP + HTTPS + SOCKS4 + SOCKS5 + + new SuccessTestItem( + "Single-chain: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks5ProxyHandler(interSocks5Proxy.Address), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.Address), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufferAllocator.Default), + new HttpProxyHandler(interHttpsProxy.Address), // HTTPS + new HttpProxyHandler(interHttpProxy.Address), // HTTP + new HttpProxyHandler(anonHttpProxy.Address)), + + new SuccessTestItem( + "Single-chain: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks5ProxyHandler(interSocks5Proxy.Address), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.Address), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufferAllocator.Default), + new HttpProxyHandler(interHttpsProxy.Address), // HTTPS + new HttpProxyHandler(interHttpProxy.Address), // HTTP + new HttpProxyHandler(anonHttpProxy.Address)), + + // (HTTP + HTTPS + SOCKS4 + SOCKS5) * 2 + + new SuccessTestItem( + "Double-chain: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks5ProxyHandler(interSocks5Proxy.Address), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.Address), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufferAllocator.Default), + new HttpProxyHandler(interHttpsProxy.Address), // HTTPS + new HttpProxyHandler(interHttpProxy.Address), // HTTP + new Socks5ProxyHandler(interSocks5Proxy.Address), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.Address), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufferAllocator.Default), + new HttpProxyHandler(interHttpsProxy.Address), // HTTPS + new HttpProxyHandler(interHttpProxy.Address), // HTTP + new HttpProxyHandler(anonHttpProxy.Address)), + + new SuccessTestItem( + "Double-chain: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks5ProxyHandler(interSocks5Proxy.Address), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.Address), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufferAllocator.Default), + new HttpProxyHandler(interHttpsProxy.Address), // HTTPS + new HttpProxyHandler(interHttpProxy.Address), // HTTP + new Socks5ProxyHandler(interSocks5Proxy.Address), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.Address), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufferAllocator.Default), + new HttpProxyHandler(interHttpsProxy.Address), // HTTPS + new HttpProxyHandler(interHttpProxy.Address), // HTTP + new HttpProxyHandler(anonHttpProxy.Address)) + */ + }; + + // Convert the test items to the list of constructor parameters. + var parameters = new List(items.Count); + foreach (var i in items) + { + parameters.Add(new object[] {i}); + } + + // Randomize the execution order to increase the possibility of exposing failure dependencies. + var seed = ReproducibleSeed == 0L ? Environment.TickCount : ReproducibleSeed; + Logger.Debug($"Seed used: {seed}\n"); + var rnd = new Random(seed); + parameters = parameters.OrderBy(_ => rnd.Next()).ToList(); + return parameters; + } + + private static TlsHandler CreateClientTlsHandler() + { + return new(s => new SslStream(s, true, (sender, certificate, chain, errors) => true), + new ClientTlsSettings("foo")); + } + + private static void StopServers() + { + foreach (var p in AllProxies) p.Stop(); + } + + private static void ClearServerExceptions() + { + foreach (var p in AllProxies) p.ClearExceptions(); + } + + private class SuccessTestHandler : SimpleChannelInboundHandler + { + internal readonly Queue Exceptions = new(); + internal readonly Queue Received = new(); + internal volatile int EventCount; + + public override void ChannelActive(IChannelHandlerContext ctx) + { + ctx.WriteAndFlushAsync(Unpooled.CopiedBuffer("A\n", Encoding.ASCII)); + ReadIfNeeded(ctx); + } + + public override void UserEventTriggered(IChannelHandlerContext ctx, object evt) + { + if (evt is ProxyConnectionEvent) + { + EventCount++; + + if ( + EventCount == + 1) // Note that ProxyConnectionEvent can be triggered multiple times when there are multiple + // ProxyHandlers in the pipeline. Therefore, we send the 'B' message only on the first event. + ctx.WriteAndFlushAsync(Unpooled.CopiedBuffer("B\n", Encoding.ASCII)); + ReadIfNeeded(ctx); + } + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + var str = ((IByteBuffer) msg).ToString(Encoding.ASCII); + Received.Enqueue(str); + if ("2".Equals(str)) ctx.WriteAndFlushAsync(Unpooled.CopiedBuffer("C\n", Encoding.ASCII)); + ReadIfNeeded(ctx); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + Exceptions.Enqueue(cause); + ctx.CloseAsync(); + } + + private static void ReadIfNeeded(IChannelHandlerContext ctx) + { + if (!ctx.Channel.Configuration.IsAutoRead) ctx.Read(); + } + } + + private class FailureTestHandler : SimpleChannelInboundHandler + { + internal readonly Queue Exceptions = new(); + + /** + * A latch that counts down when: + * - a pending write attempt in {@link #channelActive(IChannelHandlerContext)} finishes, or + * - the IChannel is closed. + * By waiting until the latch goes down to 0, we can make sure all assertion failures related with all write + * attempts have been recorded. + */ + internal readonly CountdownEvent Latch = new(2); + + public override void ChannelActive(IChannelHandlerContext ctx) + { + ctx.WriteAndFlushAsync(Unpooled.CopiedBuffer("A\n", Encoding.ASCII)).ContinueWith(future => + { + Latch.Signal(); + if (!(future.Exception.InnerException is ProxyConnectException)) + Exceptions.Enqueue(new XunitException( + "Unexpected failure cause for initial write: " + future.Exception)); + }); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + Latch.Signal(); + } + + public override void UserEventTriggered(IChannelHandlerContext ctx, object evt) + { + if (evt is ProxyConnectionEvent) throw new XunitException("Unexpected event: " + evt); + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + throw new XunitException("Unexpected message: " + msg); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + Exceptions.Enqueue(cause); + ctx.CloseAsync(); + } + } + + public abstract class TestItem + { + private readonly string _name; + + protected readonly IChannelHandler[] ClientHandlers; + protected readonly EndPoint Destination; + + protected TestItem(string name, EndPoint destination, params IChannelHandler[] clientHandlers) + { + _name = name; + Destination = destination; + ClientHandlers = clientHandlers; + } + + public abstract void Test(); + + protected void AssertProxyHandlers(bool success) + { + foreach (var h in ClientHandlers) + if (h is ProxyHandler) + { + var ph = (ProxyHandler) h; + var type = ph.GetType().Name; + var f = ph.ConnectFuture; + if (!f.IsCompleted) + { + Logger.Warn($"{type}: not done"); + } + else if (f.IsSuccess()) + { + if (success) + Logger.Debug("{0}: success", type); + else + Logger.Warn("{0}: success", type); + } + else + { + if (success) + Logger.Warn("{0}: failure", type, f.Exception); + else + Logger.Debug("{0}: failure", type, f.Exception); + } + } + + foreach (var h in ClientHandlers) + if (h is ProxyHandler) + { + var ph = (ProxyHandler) h; + Assert.True(ph.ConnectFuture.IsCompleted); + Assert.Equal(success, ph.ConnectFuture.IsSuccess()); + } + } + + public override string ToString() + { + return _name; + } + } + + private class SuccessTestItem : TestItem + { + // Probably we need to be more flexible here and as for the configuration map, + // not a single key. But as far as it works for now, I'm leaving the impl. + // as is, in case we need to cover more cases (like, AUTO_CLOSE, TCP_NODELAY etc) + // feel free to replace this bool with either config or method to setup bootstrap + private readonly bool _autoRead; + private readonly int _expectedEventCount; + + internal SuccessTestItem(string name, + EndPoint destination, + bool autoRead, + params IChannelHandler[] clientHandlers) + : base(name, destination, clientHandlers) + { + var expectedEventCount = 0; + foreach (var h in clientHandlers) + if (h is ProxyHandler) + expectedEventCount++; + + _expectedEventCount = expectedEventCount; + _autoRead = autoRead; + } + + public override void Test() + { + var testHandler = new SuccessTestHandler(); + var b = new Bootstrap() + .Group(Group) + .Channel() + .Option(ChannelOption.AutoRead, _autoRead) + .Resolver(NoopNameResolver.Instance) + .Handler(new ActionChannelInitializer(ch => + { + var p = ch.Pipeline; + p.AddLast(ClientHandlers); + p.AddLast(new LineBasedFrameDecoder(64)); + p.AddLast(testHandler); + })); + + + var channel = b.ConnectAsync(Destination).Result; + var finished = channel.CloseCompletion.Wait(TimeSpan.FromSeconds(10)); + + Logger.Debug("Received messages: {0}", testHandler.Received); + + if (testHandler.Exceptions.Count == 0) + Logger.Debug("No recorded exceptions on the client side."); + else + foreach (var t in testHandler.Exceptions) + Logger.Debug("Recorded exception on the client side: {0}", t); + + AssertProxyHandlers(true); + + Assert.Equal(testHandler.Received, new object[] {"0", "1", "2", "3"}); + Assert.Empty(testHandler.Exceptions); + Assert.Equal(testHandler.EventCount, _expectedEventCount); + Assert.True(finished); + } + } + + private class FailureTestItem : TestItem + { + private readonly string _expectedMessage; + + internal FailureTestItem( + string name, EndPoint destination, string expectedMessage, params IChannelHandler[] clientHandlers) + : base(name, destination, clientHandlers) + { + _expectedMessage = expectedMessage; + } + + public override void Test() + { + var testHandler = new FailureTestHandler(); + var b = new Bootstrap(); + b + .Group(Group) + .Channel() + .Resolver(NoopNameResolver.Instance) + .Handler(new ActionChannelInitializer(ch => + { + var p = ch.Pipeline; + p.AddLast(ClientHandlers); + p.AddLast(new LineBasedFrameDecoder(64)); + p.AddLast(testHandler); + })); + + var finished = b.ConnectAsync(Destination).Result.CloseCompletion.Wait(TimeSpan.FromSeconds(10)); + finished &= testHandler.Latch.Wait(TimeSpan.FromSeconds(10)); + + Logger.Debug("Recorded exceptions: {0}", testHandler.Exceptions); + + AssertProxyHandlers(false); + + Assert.Single(testHandler.Exceptions); + var e = testHandler.Exceptions.Dequeue(); + Assert.IsAssignableFrom(e); + Assert.Contains(_expectedMessage, e.Message); + Assert.True(finished); + } + } + + private class TimeoutTestItem : TestItem + { + internal TimeoutTestItem(string name, params IChannelHandler[] clientHandlers) + : base(name, null, clientHandlers) + { + } + + public override void Test() + { + const long timeout = 2000; + foreach (var h in ClientHandlers) + { + if (h is ProxyHandler handler) + handler.ConnectTimeout = TimeSpan.FromMilliseconds(timeout); + } + + var testHandler = new FailureTestHandler(); + var b = new Bootstrap() + .Group(Group) + .Channel() + .Resolver(NoopNameResolver.Instance) + .Handler(new ActionChannelInitializer(ch => + { + var p = ch.Pipeline; + p.AddLast(ClientHandlers); + p.AddLast(new LineBasedFrameDecoder(64)); + p.AddLast(testHandler); + })); + + var channel = b.ConnectAsync(DESTINATION).Result; + var cf = channel.CloseCompletion; + var finished = cf.Wait(TimeSpan.FromMilliseconds(timeout * 2)); + finished &= testHandler.Latch.Wait(TimeSpan.FromMilliseconds(timeout * 2)); + + Logger.Debug("Recorded exceptions: {0}", testHandler.Exceptions); + + AssertProxyHandlers(false); + + Assert.Single(testHandler.Exceptions); + var e = testHandler.Exceptions.Dequeue(); + Assert.IsType(e); + Assert.Contains("timeout", e.Message); + Assert.True(finished); + } + } + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Proxy.Tests/ProxyServer.cs b/test/DotNetty.Handlers.Proxy.Tests/ProxyServer.cs new file mode 100644 index 000000000..e9a9de37c --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/ProxyServer.cs @@ -0,0 +1,311 @@ +using System; +using System.Collections.Concurrent; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using DotNetty.Buffers; +using DotNetty.Common.Internal.Logging; +using DotNetty.Common.Utilities; +using DotNetty.Handlers.Tls; +using DotNetty.Tests.Common; +using DotNetty.Transport.Bootstrapping; +using DotNetty.Transport.Channels; +using DotNetty.Transport.Channels.Sockets; + +namespace DotNetty.Handlers.Proxy.Tests +{ + internal abstract class ProxyServer + { + protected readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); + + private readonly TcpServerSocketChannel _ch; + private readonly ConcurrentQueue _recordedExceptions = new ConcurrentQueue(); + + protected readonly TestMode TestMode; + protected readonly string Username; + protected readonly string Password; + protected readonly EndPoint Destination; + + /** + * Starts a new proxy server with disabled authentication for testing purpose. + * + * @param useSsl {@code true} if and only if implicit SSL is enabled + * @param testMode the test mode + * @param destination the expected destination. If the client requests proxying to a different destination, this + * server will reject the connection request. + */ + protected ProxyServer(bool useSsl, TestMode testMode, EndPoint destination) + : this(useSsl, testMode, destination, null, null) + { + } + + /** + * Starts a new proxy server with disabled authentication for testing purpose. + * + * @param useSsl {@code true} if and only if implicit SSL is enabled + * @param testMode the test mode + * @param username the expected username. If the client tries to authenticate with a different username, this server + * will fail the authentication request. + * @param password the expected password. If the client tries to authenticate with a different password, this server + * will fail the authentication request. + * @param destination the expected destination. If the client requests proxying to a different destination, this + * server will reject the connection request. + */ + protected ProxyServer(bool useSsl, TestMode testMode, EndPoint destination, string username, string password) + { + TestMode = testMode; + Destination = destination; + Username = username; + Password = password; + + var b = new ServerBootstrap() + .Channel() + .Group(ProxyHandlerTest.Group) + .ChildHandler(new ActionChannelInitializer(ch => + { + var p = ch.Pipeline; + if (useSsl) + { + p.AddLast(TlsHandler.Server(TestResourceHelper.GetTestCertificate())); + } + + Configure(ch); + })); + + _ch = (TcpServerSocketChannel) b.BindAsync(IPAddress.Loopback, 0).Result; + } + + public IPEndPoint Address + => new IPEndPoint(IPAddress.Loopback, ((IPEndPoint) _ch.LocalAddress).Port); + + protected abstract void Configure(ISocketChannel ch); + + private void RecordException(Exception t) + { + Logger.Warn("Unexpected exception from proxy server:", t); + _recordedExceptions.Enqueue(t); + } + + /** + * Clears all recorded exceptions. + */ + public void ClearExceptions() + { + while (_recordedExceptions.TryDequeue(out _)) + { + + } + } + + /** + * Logs all recorded exceptions and raises the last one so that the caller can fail. + */ + public void CheckExceptions() + { + Exception t; + for (;;) + { + if (!_recordedExceptions.TryDequeue(out t)) + { + break; + } + + Logger.Warn("Unexpected exception:", t); + } + + if (t != null) + { + throw t; + } + } + + public void Stop() + { + _ch.CloseAsync(); + } + + protected abstract class IntermediaryHandler : SimpleChannelInboundHandler + { + private readonly ProxyServer _server; + private readonly ConcurrentQueue _received = new ConcurrentQueue(); + + private bool _finished; + private IChannel _backend; + + protected IntermediaryHandler(ProxyServer server) + { + _server = server; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + if (_finished) + { + _received.Enqueue(ReferenceCountUtil.Retain(msg)); + Flush(); + return; + } + + bool finished = HandleProxyProtocol(ctx, msg); + if (finished) + { + _finished = true; + Task f = ConnectToDestination(ctx.Channel.EventLoop, new BackendHandler(_server, ctx)); + f.ContinueWith(future => + { + if (!future.IsSuccess()) + { + _server.RecordException(future.Exception); + ctx.CloseAsync(); + } + else + { + _backend = future.Result; + Flush(); + } + }, TaskContinuationOptions.ExecuteSynchronously); + } + } + + private void Flush() + { + if (_backend != null) + { + for (;;) + { + if (!_received.TryDequeue(out var msg)) + { + break; + } + + _backend.WriteAsync(msg); + _backend.Flush(); + } + } + } + + protected abstract bool HandleProxyProtocol(IChannelHandlerContext ctx, object msg); + + protected abstract EndPoint IntermediaryDestination { get; set; } + + private Task ConnectToDestination(IEventLoop loop, IChannelHandler handler) + { + var b = new Bootstrap() + .Channel() + .Group(loop) + .Handler(handler); + + return b.ConnectAsync(IntermediaryDestination); + } + + public override void ChannelReadComplete(IChannelHandlerContext ctx) + { + ctx.Flush(); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + if (_backend != null) + { + _backend.CloseAsync(); + } + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + _server.RecordException(cause); + ctx.CloseAsync(); + } + + private sealed class BackendHandler : ChannelHandlerAdapter + { + private readonly ProxyServer _server; + private readonly IChannelHandlerContext _frontend; + + internal BackendHandler(ProxyServer server, IChannelHandlerContext frontend) + { + _server = server; + _frontend = frontend; + } + + public override void ChannelRead(IChannelHandlerContext ctx, object msg) + { + _frontend.WriteAsync(msg); + } + + public override void ChannelReadComplete(IChannelHandlerContext ctx) + { + _frontend.Flush(); + } + + public override void ChannelInactive(IChannelHandlerContext ctx) + { + _frontend.CloseAsync(); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + _server.RecordException(cause); + ctx.CloseAsync(); + } + } + } + + protected abstract class TerminalHandler : SimpleChannelInboundHandler + { + private readonly ProxyServer _server; + private bool _finished; + + protected TerminalHandler(ProxyServer server) + { + _server = server; + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + if (_finished) + { + string str = ((IByteBuffer) msg).ToString(Encoding.ASCII); + if ("A\n".Equals(str)) + { + ctx.WriteAsync(Unpooled.CopiedBuffer("1\n", Encoding.ASCII)); + } + else if ("B\n".Equals(str)) + { + ctx.WriteAsync(Unpooled.CopiedBuffer("2\n", Encoding.ASCII)); + } + else if ("C\n".Equals(str)) + { + ctx.WriteAsync(Unpooled.CopiedBuffer("3\n", Encoding.ASCII)) + .ContinueWith(_ => ctx.Channel.CloseAsync(), TaskContinuationOptions.ExecuteSynchronously); + } + else + { + throw new InvalidOperationException("unexpected message: " + str); + } + + return; + } + + bool finished = HandleProxyProtocol(ctx, msg); + if (finished) + { + _finished = true; + } + } + + protected abstract bool HandleProxyProtocol(IChannelHandlerContext ctx, object msg); + + public override void ChannelReadComplete(IChannelHandlerContext ctx) + { + ctx.Flush(); + } + + public override void ExceptionCaught(IChannelHandlerContext ctx, Exception cause) + { + _server.RecordException(cause); + ctx.CloseAsync(); + } + } + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Proxy.Tests/TestMode.cs b/test/DotNetty.Handlers.Proxy.Tests/TestMode.cs new file mode 100644 index 000000000..9a300aab0 --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/TestMode.cs @@ -0,0 +1,9 @@ +namespace DotNetty.Handlers.Proxy.Tests +{ + internal enum TestMode + { + Intermediary, + Terminal, + Unresponsive + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Proxy.Tests/UnresponsiveHandler.cs b/test/DotNetty.Handlers.Proxy.Tests/UnresponsiveHandler.cs new file mode 100644 index 000000000..34681dd53 --- /dev/null +++ b/test/DotNetty.Handlers.Proxy.Tests/UnresponsiveHandler.cs @@ -0,0 +1,36 @@ +using System; +using DotNetty.Transport.Channels; + +namespace DotNetty.Handlers.Proxy.Tests +{ + internal sealed class UnresponsiveHandler : SimpleChannelInboundHandler + { + public static readonly UnresponsiveHandler Instance = new UnresponsiveHandler(); + + private UnresponsiveHandler() + { + } + + public override bool IsSharable => true; + + public override void ChannelActive(IChannelHandlerContext context) + { + base.ChannelActive(context); + } + + public override void ChannelInactive(IChannelHandlerContext context) + { + base.ChannelInactive(context); + } + + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) + { + base.ExceptionCaught(context, exception); + } + + protected override void ChannelRead0(IChannelHandlerContext ctx, object msg) + { + //Ignore + } + } +} \ No newline at end of file