From 8cd4ca30aced7764ed7f73e0422396a2670c86f3 Mon Sep 17 00:00:00 2001 From: Tino Hager Date: Mon, 23 Sep 2024 14:20:32 +0200 Subject: [PATCH 1/2] Add SessionId to SessionContext, fix missing timeout for starttls (#235) * Add SessionId, fix missing starttls timeout Switch to .netstandard 2.1 https://learn.microsoft.com/en-us/dotnet/standard/net-standard?tabs=net-standard-2-1 * cleanup code * code cleanup * Fix parameter not used * Add SessionTimeout UnitTests * Update SmtpServerTests.cs * Update SmtpServerTests.cs --- Src/SmtpServer.Tests/RawSmtpClient.cs | 7 ++- Src/SmtpServer.Tests/SmtpServerTests.cs | 72 ++++++++++++++++++++++++ Src/SmtpServer/IO/SecurableDuplexPipe.cs | 22 +++++++- Src/SmtpServer/ISessionContext.cs | 5 ++ Src/SmtpServer/SmtpServer.csproj | 2 +- Src/SmtpServer/SmtpSessionContext.cs | 4 ++ Src/SmtpServer/SmtpSessionManager.cs | 36 ++++-------- 7 files changed, 117 insertions(+), 31 deletions(-) diff --git a/Src/SmtpServer.Tests/RawSmtpClient.cs b/Src/SmtpServer.Tests/RawSmtpClient.cs index 4e3211d..fca6381 100644 --- a/Src/SmtpServer.Tests/RawSmtpClient.cs +++ b/Src/SmtpServer.Tests/RawSmtpClient.cs @@ -10,9 +10,14 @@ internal class RawSmtpClient : IDisposable { private readonly TcpClient _tcpClient; private NetworkStream _networkStream; + private readonly string _host; + private readonly int _port; internal RawSmtpClient(string host, int port) { + _host = host; + _port = port; + _tcpClient = new TcpClient(); } @@ -24,7 +29,7 @@ public void Dispose() internal async Task ConnectAsync() { - await _tcpClient.ConnectAsync(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9025)); + await _tcpClient.ConnectAsync(new IPEndPoint(IPAddress.Parse(_host), _port)); _networkStream = _tcpClient.GetStream(); var greetingResponse = await WaitForDataAsync(); diff --git a/Src/SmtpServer.Tests/SmtpServerTests.cs b/Src/SmtpServer.Tests/SmtpServerTests.cs index db558bb..9c69b88 100644 --- a/Src/SmtpServer.Tests/SmtpServerTests.cs +++ b/Src/SmtpServer.Tests/SmtpServerTests.cs @@ -10,6 +10,8 @@ using System.Diagnostics; using System.IO; using System.Net; +using System.Net.Security; +using System.Net.Sockets; using System.Security.Authentication; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; @@ -367,6 +369,76 @@ public void SecuresTheSessionByDefault() Assert.True(isSecure); } + public static bool ValidateServerCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) + { + return true; + } + + [Fact] + public async Task SessionTimeoutIsExceeded_DelayedAuthenticate() + { + var sessionTimeout = TimeSpan.FromSeconds(3); + var server = "localhost"; + var port = 9025; + + using var disposable = CreateServer(endpoint => endpoint + .SessionTimeout(sessionTimeout) + .IsSecure(true) + .Certificate(CreateCertificate()) + ); + + using var tcpClient = new TcpClient(server, port); + using var sslStream = new SslStream(tcpClient.GetStream(), false, new RemoteCertificateValidationCallback(ValidateServerCertificate), null); + + await Task.Delay(sessionTimeout.Add(TimeSpan.FromSeconds(1))); + + var exception = await Assert.ThrowsAsync(async () => + { + await sslStream.AuthenticateAsClientAsync(server); + }); + } + + [Fact] + public async Task SessionTimeoutIsExceeded_NoCommands() + { + var sessionTimeout = TimeSpan.FromSeconds(3); + var server = "localhost"; + var port = 9025; + + using var disposable = CreateServer(endpoint => endpoint + .SessionTimeout(sessionTimeout) + .IsSecure(true) + .Certificate(CreateCertificate()) + ); + + var stopwatch = new Stopwatch(); + stopwatch.Start(); + + using var tcpClient = new TcpClient(server, port); + using var sslStream = new SslStream(tcpClient.GetStream(), false, new RemoteCertificateValidationCallback(ValidateServerCertificate), null); + + await sslStream.AuthenticateAsClientAsync(server); + + if (sslStream.IsAuthenticated) + { + var buffer = new byte[1024]; + + var welcomeByteCount = await sslStream.ReadAsync(buffer, 0, buffer.Length); + + var emptyResponseCount = await sslStream.ReadAsync(buffer, 0, buffer.Length); + + await Task.Delay(100); //Add a tolerance + stopwatch.Stop(); + + Assert.True(emptyResponseCount == 0, "Some data received"); + Assert.True(stopwatch.Elapsed > sessionTimeout, $"SessionTimout not elapsed {stopwatch.Elapsed}"); + } + else + { + Assert.Fail("Smtp Session is not authenticated"); + } + } + [Fact] public void ServerCanBeSecuredAndAuthenticated() { diff --git a/Src/SmtpServer/IO/SecurableDuplexPipe.cs b/Src/SmtpServer/IO/SecurableDuplexPipe.cs index f2207a4..9031c71 100644 --- a/Src/SmtpServer/IO/SecurableDuplexPipe.cs +++ b/Src/SmtpServer/IO/SecurableDuplexPipe.cs @@ -38,11 +38,27 @@ internal SecurableDuplexPipe(Stream stream, Action disposeAction) /// A task that asynchronously performs the operation. public async Task UpgradeAsync(X509Certificate certificate, SslProtocols protocols, CancellationToken cancellationToken = default) { - var stream = new SslStream(_stream, true); + var sslStream = new SslStream(_stream, true); - await stream.AuthenticateAsServerAsync(certificate, false, protocols, true).ConfigureAwait(false); + try + { + var sslServerAuthenticationOptions = new SslServerAuthenticationOptions + { + ServerCertificate = certificate, + ClientCertificateRequired = false, + EnabledSslProtocols = protocols, + CertificateRevocationCheckMode = X509RevocationMode.Online + }; - _stream = stream; + await sslStream.AuthenticateAsServerAsync(sslServerAuthenticationOptions, cancellationToken); + } + catch + { + sslStream.Dispose(); + throw; + } + + _stream = sslStream; Input = PipeReader.Create(_stream); Output = PipeWriter.Create(_stream); diff --git a/Src/SmtpServer/ISessionContext.cs b/Src/SmtpServer/ISessionContext.cs index 590dba2..78f8b31 100644 --- a/Src/SmtpServer/ISessionContext.cs +++ b/Src/SmtpServer/ISessionContext.cs @@ -9,6 +9,11 @@ namespace SmtpServer /// public interface ISessionContext { + /// + /// A unique Id for the Session + /// + public Guid SessionId { get; } + /// /// Fired when a command is about to execute. /// diff --git a/Src/SmtpServer/SmtpServer.csproj b/Src/SmtpServer/SmtpServer.csproj index f70e2f9..008cd07 100644 --- a/Src/SmtpServer/SmtpServer.csproj +++ b/Src/SmtpServer/SmtpServer.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + netstandard2.1 8.0 SmtpServer SmtpServer diff --git a/Src/SmtpServer/SmtpSessionContext.cs b/Src/SmtpServer/SmtpSessionContext.cs index 8b52912..630c5c7 100644 --- a/Src/SmtpServer/SmtpSessionContext.cs +++ b/Src/SmtpServer/SmtpSessionContext.cs @@ -7,6 +7,9 @@ namespace SmtpServer { internal sealed class SmtpSessionContext : ISessionContext { + /// + public Guid SessionId { get; private set; } + /// public event EventHandler CommandExecuting; @@ -27,6 +30,7 @@ internal sealed class SmtpSessionContext : ISessionContext /// The endpoint definition. internal SmtpSessionContext(IServiceProvider serviceProvider, ISmtpServerOptions options, IEndpointDefinition endpointDefinition) { + SessionId = Guid.NewGuid(); ServiceProvider = serviceProvider; ServerOptions = options; EndpointDefinition = endpointDefinition; diff --git a/Src/SmtpServer/SmtpSessionManager.cs b/Src/SmtpServer/SmtpSessionManager.cs index 7e9db55..910a514 100644 --- a/Src/SmtpServer/SmtpSessionManager.cs +++ b/Src/SmtpServer/SmtpSessionManager.cs @@ -1,5 +1,5 @@ using System; -using System.Collections.Generic; +using System.Collections.Concurrent; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -9,8 +9,7 @@ namespace SmtpServer internal sealed class SmtpSessionManager { readonly SmtpServer _smtpServer; - readonly HashSet _sessions = new HashSet(); - readonly object _sessionsLock = new object(); + readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); internal SmtpSessionManager(SmtpServer smtpServer) { @@ -20,16 +19,13 @@ internal SmtpSessionManager(SmtpServer smtpServer) internal void Run(SmtpSessionContext sessionContext, CancellationToken cancellationToken) { var handle = new SmtpSessionHandle(new SmtpSession(sessionContext), sessionContext); - Add(handle); - handle.CompletionTask = RunAsync(handle, cancellationToken); + var smtpSessionTask = RunAsync(handle, cancellationToken).ContinueWith(task => + { + Remove(handle); + }); - // ReSharper disable once MethodSupportsCancellation - handle.CompletionTask.ContinueWith( - task => - { - Remove(handle); - }); + handle.CompletionTask = smtpSessionTask; } async Task RunAsync(SmtpSessionHandle handle, CancellationToken cancellationToken) @@ -79,30 +75,18 @@ async Task UpgradeAsync(SmtpSessionHandle handle, CancellationToken cancellation internal Task WaitAsync() { - IReadOnlyList tasks; - - lock (_sessionsLock) - { - tasks = _sessions.Select(session => session.CompletionTask).ToList(); - } - + var tasks = _sessions.Values.Select(session => session.CompletionTask).ToList().AsReadOnly(); return Task.WhenAll(tasks); } void Add(SmtpSessionHandle handle) { - lock (_sessionsLock) - { - _sessions.Add(handle); - } + _sessions.TryAdd(handle.SessionContext.SessionId, handle); } void Remove(SmtpSessionHandle handle) { - lock (_sessionsLock) - { - _sessions.Remove(handle); - } + _sessions.TryRemove(handle.SessionContext.SessionId, out _); } class SmtpSessionHandle From 9424a3f50dd2cee6b35cd675ae24ed2946f07acf Mon Sep 17 00:00:00 2001 From: cosullivan Date: Mon, 23 Sep 2024 21:41:43 +0800 Subject: [PATCH 2/2] minor cleanup and fixed missing Add from session manager --- Src/SmtpServer.Tests/SmtpServer.Tests.csproj | 4 ++-- Src/SmtpServer.Tests/SmtpServerTests.cs | 8 ++++---- Src/SmtpServer/IO/SecurableDuplexPipe.cs | 18 +++++++++--------- Src/SmtpServer/Net/EndpointListener.cs | 9 ++------- Src/SmtpServer/Net/EndpointListenerFactory.cs | 2 +- Src/SmtpServer/SmtpSessionManager.cs | 10 +++++----- 6 files changed, 23 insertions(+), 28 deletions(-) diff --git a/Src/SmtpServer.Tests/SmtpServer.Tests.csproj b/Src/SmtpServer.Tests/SmtpServer.Tests.csproj index 38bb095..f8732be 100644 --- a/Src/SmtpServer.Tests/SmtpServer.Tests.csproj +++ b/Src/SmtpServer.Tests/SmtpServer.Tests.csproj @@ -6,8 +6,8 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/Src/SmtpServer.Tests/SmtpServerTests.cs b/Src/SmtpServer.Tests/SmtpServerTests.cs index 9c69b88..300bf03 100644 --- a/Src/SmtpServer.Tests/SmtpServerTests.cs +++ b/Src/SmtpServer.Tests/SmtpServerTests.cs @@ -382,10 +382,10 @@ public async Task SessionTimeoutIsExceeded_DelayedAuthenticate() var port = 9025; using var disposable = CreateServer(endpoint => endpoint - .SessionTimeout(sessionTimeout) - .IsSecure(true) - .Certificate(CreateCertificate()) - ); + .SessionTimeout(sessionTimeout) + .IsSecure(true) + .Certificate(CreateCertificate()) + ); using var tcpClient = new TcpClient(server, port); using var sslStream = new SslStream(tcpClient.GetStream(), false, new RemoteCertificateValidationCallback(ValidateServerCertificate), null); diff --git a/Src/SmtpServer/IO/SecurableDuplexPipe.cs b/Src/SmtpServer/IO/SecurableDuplexPipe.cs index 9031c71..70ede62 100644 --- a/Src/SmtpServer/IO/SecurableDuplexPipe.cs +++ b/Src/SmtpServer/IO/SecurableDuplexPipe.cs @@ -42,15 +42,15 @@ public async Task UpgradeAsync(X509Certificate certificate, SslProtocols protoco try { - var sslServerAuthenticationOptions = new SslServerAuthenticationOptions - { - ServerCertificate = certificate, - ClientCertificateRequired = false, - EnabledSslProtocols = protocols, - CertificateRevocationCheckMode = X509RevocationMode.Online - }; - - await sslStream.AuthenticateAsServerAsync(sslServerAuthenticationOptions, cancellationToken); + await sslStream.AuthenticateAsServerAsync( + new SslServerAuthenticationOptions + { + ServerCertificate = certificate, + ClientCertificateRequired = false, + EnabledSslProtocols = protocols, + CertificateRevocationCheckMode = X509RevocationMode.Online + }, + cancellationToken); } catch { diff --git a/Src/SmtpServer/Net/EndpointListener.cs b/Src/SmtpServer/Net/EndpointListener.cs index 160b021..c7d2946 100644 --- a/Src/SmtpServer/Net/EndpointListener.cs +++ b/Src/SmtpServer/Net/EndpointListener.cs @@ -21,19 +21,16 @@ public sealed class EndpointListener : IEndpointListener /// public const string RemoteEndPointKey = "EndpointListener:RemoteEndPoint"; - readonly IEndpointDefinition _endpointDefinition; readonly TcpListener _tcpListener; readonly Action _disposeAction; /// /// Constructor. /// - /// The endpoint definition to create the listener for. /// The TCP listener for the endpoint. /// The action to execute when the listener has been disposed. - internal EndpointListener(IEndpointDefinition endpointDefinition, TcpListener tcpListener, Action disposeAction) + internal EndpointListener(TcpListener tcpListener, Action disposeAction) { - _endpointDefinition = endpointDefinition; _tcpListener = tcpListener; _disposeAction = disposeAction; } @@ -61,9 +58,7 @@ public async Task GetPipeAsync(ISessionContext context, Ca tcpClient.Close(); tcpClient.Dispose(); } - catch (Exception) - { - } + catch { } }); } diff --git a/Src/SmtpServer/Net/EndpointListenerFactory.cs b/Src/SmtpServer/Net/EndpointListenerFactory.cs index 097efca..48f5386 100644 --- a/Src/SmtpServer/Net/EndpointListenerFactory.cs +++ b/Src/SmtpServer/Net/EndpointListenerFactory.cs @@ -32,7 +32,7 @@ public virtual IEndpointListener CreateListener(IEndpointDefinition endpointDefi var endpointEventArgs = new EndpointEventArgs(endpointDefinition, tcpListener.LocalEndpoint); OnEndpointStarted(endpointEventArgs); - return new EndpointListener(endpointDefinition, tcpListener, () => OnEndpointStopped(endpointEventArgs)); + return new EndpointListener(tcpListener, () => OnEndpointStopped(endpointEventArgs)); } /// diff --git a/Src/SmtpServer/SmtpSessionManager.cs b/Src/SmtpServer/SmtpSessionManager.cs index 910a514..e02ab8d 100644 --- a/Src/SmtpServer/SmtpSessionManager.cs +++ b/Src/SmtpServer/SmtpSessionManager.cs @@ -19,19 +19,19 @@ internal SmtpSessionManager(SmtpServer smtpServer) internal void Run(SmtpSessionContext sessionContext, CancellationToken cancellationToken) { var handle = new SmtpSessionHandle(new SmtpSession(sessionContext), sessionContext); + Add(handle); - var smtpSessionTask = RunAsync(handle, cancellationToken).ContinueWith(task => + handle.CompletionTask = RunAsync(handle, cancellationToken).ContinueWith(task => { Remove(handle); }); - - handle.CompletionTask = smtpSessionTask; } async Task RunAsync(SmtpSessionHandle handle, CancellationToken cancellationToken) { - using var sessionReadTimeoutCancellationTokenSource = new CancellationTokenSource(handle.SessionContext.EndpointDefinition.SessionTimeout); - using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, sessionReadTimeoutCancellationTokenSource.Token); + using var sessionTimeoutCancellationTokenSource = new CancellationTokenSource(handle.SessionContext.EndpointDefinition.SessionTimeout); + + using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, sessionTimeoutCancellationTokenSource.Token); try {