Skip to content

Commit

Permalink
Add SessionId to SessionContext, fix missing timeout for starttls (#235)
Browse files Browse the repository at this point in the history
* Add SessionId, fix missing starttls timeout

Switch to .netstandard 2.1
https://learn.microsoft.com/en-us/dotnet/standard/net-standard?tabs=net-standard-2-1

* cleanup code

* code cleanup

* Fix parameter not used

* Add SessionTimeout UnitTests

* Update SmtpServerTests.cs

* Update SmtpServerTests.cs
  • Loading branch information
tinohager authored Sep 23, 2024
1 parent 3b3ca98 commit 8cd4ca3
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 31 deletions.
7 changes: 6 additions & 1 deletion Src/SmtpServer.Tests/RawSmtpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ internal class RawSmtpClient : IDisposable
{
private readonly TcpClient _tcpClient;
private NetworkStream _networkStream;
private readonly string _host;
private readonly int _port;

internal RawSmtpClient(string host, int port)
{
_host = host;
_port = port;

_tcpClient = new TcpClient();
}

Expand All @@ -24,7 +29,7 @@ public void Dispose()

internal async Task<bool> ConnectAsync()
{
await _tcpClient.ConnectAsync(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9025));
await _tcpClient.ConnectAsync(new IPEndPoint(IPAddress.Parse(_host), _port));
_networkStream = _tcpClient.GetStream();

var greetingResponse = await WaitForDataAsync();
Expand Down
72 changes: 72 additions & 0 deletions Src/SmtpServer.Tests/SmtpServerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
Expand Down Expand Up @@ -367,6 +369,76 @@ public void SecuresTheSessionByDefault()
Assert.True(isSecure);
}

public static bool ValidateServerCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
{
return true;
}

[Fact]
public async Task SessionTimeoutIsExceeded_DelayedAuthenticate()
{
var sessionTimeout = TimeSpan.FromSeconds(3);
var server = "localhost";
var port = 9025;

using var disposable = CreateServer(endpoint => endpoint
.SessionTimeout(sessionTimeout)
.IsSecure(true)
.Certificate(CreateCertificate())
);

using var tcpClient = new TcpClient(server, port);
using var sslStream = new SslStream(tcpClient.GetStream(), false, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);

await Task.Delay(sessionTimeout.Add(TimeSpan.FromSeconds(1)));

var exception = await Assert.ThrowsAsync<IOException>(async () =>
{
await sslStream.AuthenticateAsClientAsync(server);
});
}

[Fact]
public async Task SessionTimeoutIsExceeded_NoCommands()
{
var sessionTimeout = TimeSpan.FromSeconds(3);
var server = "localhost";
var port = 9025;

using var disposable = CreateServer(endpoint => endpoint
.SessionTimeout(sessionTimeout)
.IsSecure(true)
.Certificate(CreateCertificate())
);

var stopwatch = new Stopwatch();
stopwatch.Start();

using var tcpClient = new TcpClient(server, port);
using var sslStream = new SslStream(tcpClient.GetStream(), false, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);

await sslStream.AuthenticateAsClientAsync(server);

if (sslStream.IsAuthenticated)
{
var buffer = new byte[1024];

var welcomeByteCount = await sslStream.ReadAsync(buffer, 0, buffer.Length);

var emptyResponseCount = await sslStream.ReadAsync(buffer, 0, buffer.Length);

await Task.Delay(100); //Add a tolerance
stopwatch.Stop();

Assert.True(emptyResponseCount == 0, "Some data received");
Assert.True(stopwatch.Elapsed > sessionTimeout, $"SessionTimout not elapsed {stopwatch.Elapsed}");
}
else
{
Assert.Fail("Smtp Session is not authenticated");
}
}

[Fact]
public void ServerCanBeSecuredAndAuthenticated()
{
Expand Down
22 changes: 19 additions & 3 deletions Src/SmtpServer/IO/SecurableDuplexPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,27 @@ internal SecurableDuplexPipe(Stream stream, Action disposeAction)
/// <returns>A task that asynchronously performs the operation.</returns>
public async Task UpgradeAsync(X509Certificate certificate, SslProtocols protocols, CancellationToken cancellationToken = default)
{
var stream = new SslStream(_stream, true);
var sslStream = new SslStream(_stream, true);

await stream.AuthenticateAsServerAsync(certificate, false, protocols, true).ConfigureAwait(false);
try
{
var sslServerAuthenticationOptions = new SslServerAuthenticationOptions
{
ServerCertificate = certificate,
ClientCertificateRequired = false,
EnabledSslProtocols = protocols,
CertificateRevocationCheckMode = X509RevocationMode.Online
};

_stream = stream;
await sslStream.AuthenticateAsServerAsync(sslServerAuthenticationOptions, cancellationToken);
}
catch
{
sslStream.Dispose();
throw;
}

_stream = sslStream;

Input = PipeReader.Create(_stream);
Output = PipeWriter.Create(_stream);
Expand Down
5 changes: 5 additions & 0 deletions Src/SmtpServer/ISessionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ namespace SmtpServer
/// </summary>
public interface ISessionContext
{
/// <summary>
/// A unique Id for the Session
/// </summary>
public Guid SessionId { get; }

/// <summary>
/// Fired when a command is about to execute.
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion Src/SmtpServer/SmtpServer.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFramework>netstandard2.1</TargetFramework>
<LangVersion>8.0</LangVersion>
<AssemblyName>SmtpServer</AssemblyName>
<RootNamespace>SmtpServer</RootNamespace>
Expand Down
4 changes: 4 additions & 0 deletions Src/SmtpServer/SmtpSessionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ namespace SmtpServer
{
internal sealed class SmtpSessionContext : ISessionContext
{
/// <inheritdoc />
public Guid SessionId { get; private set; }

/// <inheritdoc />
public event EventHandler<SmtpCommandEventArgs> CommandExecuting;

Expand All @@ -27,6 +30,7 @@ internal sealed class SmtpSessionContext : ISessionContext
/// <param name="endpointDefinition">The endpoint definition.</param>
internal SmtpSessionContext(IServiceProvider serviceProvider, ISmtpServerOptions options, IEndpointDefinition endpointDefinition)
{
SessionId = Guid.NewGuid();
ServiceProvider = serviceProvider;
ServerOptions = options;
EndpointDefinition = endpointDefinition;
Expand Down
36 changes: 10 additions & 26 deletions Src/SmtpServer/SmtpSessionManager.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -9,8 +9,7 @@ namespace SmtpServer
internal sealed class SmtpSessionManager
{
readonly SmtpServer _smtpServer;
readonly HashSet<SmtpSessionHandle> _sessions = new HashSet<SmtpSessionHandle>();
readonly object _sessionsLock = new object();
readonly ConcurrentDictionary<Guid, SmtpSessionHandle> _sessions = new ConcurrentDictionary<Guid, SmtpSessionHandle>();

internal SmtpSessionManager(SmtpServer smtpServer)
{
Expand All @@ -20,16 +19,13 @@ internal SmtpSessionManager(SmtpServer smtpServer)
internal void Run(SmtpSessionContext sessionContext, CancellationToken cancellationToken)
{
var handle = new SmtpSessionHandle(new SmtpSession(sessionContext), sessionContext);
Add(handle);

handle.CompletionTask = RunAsync(handle, cancellationToken);
var smtpSessionTask = RunAsync(handle, cancellationToken).ContinueWith(task =>
{
Remove(handle);
});

// ReSharper disable once MethodSupportsCancellation
handle.CompletionTask.ContinueWith(
task =>
{
Remove(handle);
});
handle.CompletionTask = smtpSessionTask;
}

async Task RunAsync(SmtpSessionHandle handle, CancellationToken cancellationToken)
Expand Down Expand Up @@ -79,30 +75,18 @@ async Task UpgradeAsync(SmtpSessionHandle handle, CancellationToken cancellation

internal Task WaitAsync()
{
IReadOnlyList<Task> tasks;

lock (_sessionsLock)
{
tasks = _sessions.Select(session => session.CompletionTask).ToList();
}

var tasks = _sessions.Values.Select(session => session.CompletionTask).ToList().AsReadOnly();
return Task.WhenAll(tasks);
}

void Add(SmtpSessionHandle handle)
{
lock (_sessionsLock)
{
_sessions.Add(handle);
}
_sessions.TryAdd(handle.SessionContext.SessionId, handle);
}

void Remove(SmtpSessionHandle handle)
{
lock (_sessionsLock)
{
_sessions.Remove(handle);
}
_sessions.TryRemove(handle.SessionContext.SessionId, out _);
}

class SmtpSessionHandle
Expand Down

0 comments on commit 8cd4ca3

Please sign in to comment.