Skip to content

Commit

Permalink
Close websockets on client disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
cristipufu committed Aug 24, 2024
1 parent a5b10d8 commit 3afc537
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 34 deletions.
6 changes: 3 additions & 3 deletions src/Tunnelite.Client/Tunnelite.Client.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<Description>Tool for tunneling URLs</Description>
<PackageProjectUrl>https://github.com/cristipufu/ws-tunnel-signalr</PackageProjectUrl>
<RepositoryUrl>https://github.com/cristipufu/ws-tunnel-signalr</RepositoryUrl>
<Version>1.1.1</Version>
<Version>1.1.2</Version>
</PropertyGroup>

<ItemGroup>
Expand All @@ -27,8 +27,8 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.SignalR.Client" Version="8.0.6" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.7" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Client" Version="8.0.8" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.8" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
<PackageReference Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
</ItemGroup>
Expand Down
13 changes: 10 additions & 3 deletions src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,19 @@ static async Task TunnelRequestAsync(HttpContext context, IHubContext<HttpTunnel

await completionTask;
}
catch (TaskCanceledException)
{
// ignore
}
catch (Exception ex)
{
logger.LogError(ex, "Error processing request tunnel: {Message}", ex.Message);
if (!context.Response.HasStarted)
{
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
await context.Response.WriteAsync("An error occurred while processing the tunnel.");
}

context.Response.StatusCode = StatusCodes.Status500InternalServerError;
await context.Response.WriteAsync("An error occurred while processing the tunnel.");
logger.LogError(ex, "Error processing request tunnel: {Message}", ex.Message);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Tunnelite.Server/HttpTunnel/HttpRequestsQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace Tunnelite.Server.HttpTunnel;

public class HttpRequestsQueue
{
public ConcurrentDictionary<Guid, HttpDefferedRequest> PendingRequests = new();
private readonly ConcurrentDictionary<Guid, HttpDefferedRequest> PendingRequests = new();

public virtual Task WaitForCompletionAsync(Guid requestId, HttpContext context, TimeSpan? timeout = null, CancellationToken cancellationToken = default)
{
Expand Down
35 changes: 27 additions & 8 deletions src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ public override Task OnConnectedAsync()

public async IAsyncEnumerable<(ReadOnlyMemory<byte>, WebSocketMessageType)> StreamIncomingWsAsync(WsConnection wsConnection)
{
var webSocket = _wsRequestsQueue.GetWebSocket(wsConnection.RequestId);
var clientId = GetClientId(Context);

var webSocket = _wsRequestsQueue.GetWebSocket(clientId, wsConnection.RequestId);

if (webSocket == null)
{
Expand All @@ -32,19 +34,26 @@ public override Task OnConnectedAsync()
const int chunkSize = 32 * 1024;

byte[] buffer = ArrayPool<byte>.Shared.Rent(chunkSize);
WebSocketReceiveResult result;
WebSocketReceiveResult? result = null;

try
{
do
{
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), Context.ConnectionAborted);
try
{
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), Context.ConnectionAborted);
}
catch (WebSocketException)
{
break;
}

yield return (new ReadOnlyMemory<byte>(buffer, 0, result.Count), result.MessageType);
}
while (!result.CloseStatus.HasValue && !Context.ConnectionAborted.IsCancellationRequested);

if (result.MessageType == WebSocketMessageType.Close)
if (result?.MessageType == WebSocketMessageType.Close)
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, result.CloseStatusDescription, CancellationToken.None);
}
Expand All @@ -62,7 +71,9 @@ public override Task OnConnectedAsync()

public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumerable<(ReadOnlyMemory<byte> Data, WebSocketMessageType Type)> stream)
{
var webSocket = _wsRequestsQueue.GetWebSocket(wsConnection.RequestId);
var clientId = GetClientId(Context);

var webSocket = _wsRequestsQueue.GetWebSocket(clientId, wsConnection.RequestId);

if (webSocket == null)
{
Expand All @@ -83,6 +94,14 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera
}
}
}
catch (OperationCanceledException)
{
// ignore
}
catch (Exception ex) when (ex.Message == "Stream canceled by client.")
{
// ignore
}
catch (Exception ex)
{
_logger.LogError(ex, "An unexpected error occurred while streaming outgoing data for {RequestId}", wsConnection.RequestId);
Expand All @@ -93,7 +112,7 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera
}
}

public override Task OnDisconnectedAsync(Exception? exception)
public override async Task OnDisconnectedAsync(Exception? exception)
{
var clientId = GetClientId(Context);

Expand All @@ -104,9 +123,9 @@ public override Task OnDisconnectedAsync(Exception? exception)
_httpTunnelStore.Clients.Remove(clientId, out _);
}

// todo close and dispose all websockets for clientId
await _wsRequestsQueue.CompleteAsync(clientId);

return base.OnDisconnectedAsync(exception);
await base.OnDisconnectedAsync(exception);
}

private static Guid GetClientId(HubCallerContext context)
Expand Down
16 changes: 8 additions & 8 deletions src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ namespace Tunnelite.Server.TcpTunnel;
public class TcpClientStore
{
// client, [requestId, TcpClient]
private readonly ConcurrentDictionary<Guid, ConcurrentDictionary<Guid, TcpClient>> _clientStore = new();
private readonly ConcurrentDictionary<Guid, ConcurrentDictionary<Guid, TcpClient>> PendingRequests = new();
// clientId, TcpListener
public ConcurrentDictionary<Guid, TcpListenerContext> _listenerStore = new();
private readonly ConcurrentDictionary<Guid, TcpListenerContext> Listeners = new();

public void AddTcpClient(Guid clientId, Guid requestId, TcpClient tcpClient)
{
_clientStore.AddOrUpdate(
PendingRequests.AddOrUpdate(
clientId,
_ => new ConcurrentDictionary<Guid, TcpClient> { [requestId] = tcpClient },
(_, tcpClients) =>
Expand All @@ -25,7 +25,7 @@ public void AddTcpClient(Guid clientId, Guid requestId, TcpClient tcpClient)

public TcpClient GetTcpClient(Guid clientId, Guid requestId)
{
if (!_clientStore.TryGetValue(clientId, out var tcpClients))
if (!PendingRequests.TryGetValue(clientId, out var tcpClients))
{
return null;
}
Expand All @@ -37,7 +37,7 @@ public TcpClient GetTcpClient(Guid clientId, Guid requestId)

public void DisposeTcpClient(Guid clientId, Guid requestId)
{
if (!_clientStore.TryGetValue(clientId, out var tcpClients))
if (!PendingRequests.TryGetValue(clientId, out var tcpClients))
{
return;
}
Expand All @@ -52,20 +52,20 @@ public void DisposeTcpClient(Guid clientId, Guid requestId)

public void AddTcpListener(Guid clientId, TcpListenerContext tcpListener)
{
_listenerStore.AddOrUpdate(clientId, tcpListener, (key, oldValue) => tcpListener);
Listeners.AddOrUpdate(clientId, tcpListener, (key, oldValue) => tcpListener);
}

public void DisposeTcpListener(Guid clientId)
{
if (_clientStore.TryRemove(clientId, out var tcpClients))
if (PendingRequests.TryRemove(clientId, out var tcpClients))
{
foreach (var client in tcpClients.Values)
{
client?.Dispose();
}
}

if (_listenerStore.TryRemove(clientId, out var listener))
if (Listeners.TryRemove(clientId, out var listener))
{
listener?.Dispose();
}
Expand Down
4 changes: 2 additions & 2 deletions src/Tunnelite.Server/Tunnelite.Server.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.7" />
<PackageReference Include="Microsoft.Azure.SignalR" Version="1.26.0" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.8" />
<PackageReference Include="Microsoft.Azure.SignalR" Version="1.26.1" />
</ItemGroup>

<ItemGroup>
Expand Down
44 changes: 36 additions & 8 deletions src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ namespace Tunnelite.Server.WsTunnel;

public class WsRequestsQueue
{
public ConcurrentDictionary<Guid, WsDefferedRequest> PendingRequests = new();
// client, [requestId, WsDefferedRequest]
private readonly ConcurrentDictionary<Guid, ConcurrentDictionary<Guid, WsDefferedRequest>> PendingRequests = new();

public virtual Task WaitForCompletionAsync(Guid requestId, WebSocket webSocket)
public virtual Task WaitForCompletionAsync(Guid clientId, Guid requestId, WebSocket webSocket)
{
WsDefferedRequest request = new()
{
Expand All @@ -16,24 +17,38 @@ public virtual Task WaitForCompletionAsync(Guid requestId, WebSocket webSocket)
TaskCompletionSource = new TaskCompletionSource(),
};

PendingRequests.TryAdd(request.RequestId, request);
PendingRequests.AddOrUpdate(
clientId,
_ => new ConcurrentDictionary<Guid, WsDefferedRequest> { [requestId] = request },
(_, requests) =>
{
requests[requestId] = request;
return requests;
});

return request.TaskCompletionSource.Task;
}

public virtual WebSocket? GetWebSocket(Guid requestId)
public virtual WebSocket? GetWebSocket(Guid clientId, Guid requestId)
{
if (!PendingRequests.TryGetValue(requestId, out var request))
if (!PendingRequests.TryGetValue(clientId, out var requests))
{
return null;
}

return request.WebSocket;
requests.TryGetValue(requestId, out var request);

return request?.WebSocket;
}

public virtual Task CompleteAsync(Guid requestId)
public virtual Task CompleteAsync(Guid clientId, Guid requestId)
{
if (!PendingRequests.TryRemove(requestId, out var request))
if (!PendingRequests.TryGetValue(clientId, out var requests))
{
return Task.CompletedTask;
}

if (!requests.TryRemove(requestId, out var request))
{
return Task.CompletedTask;
}
Expand All @@ -53,4 +68,17 @@ public virtual Task CompleteAsync(Guid requestId)

return Task.CompletedTask;
}

public virtual async Task CompleteAsync(Guid clientId)
{
if (!PendingRequests.TryRemove(clientId, out var requests))
{
return;
}

foreach (var request in requests)
{
await CompleteAsync(clientId, request.Key);
}
}
}
2 changes: 1 addition & 1 deletion src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public async Task InvokeAsync(HttpContext context)

_logger.LogInformation("WebSocket connection accepted: {requestId}", requestId);

var completionTask = _requestsQueue.WaitForCompletionAsync(requestId, webSocket);
var completionTask = _requestsQueue.WaitForCompletionAsync(tunnel!.ClientId, requestId, webSocket);

await _hubContext.Clients.Client(connectionId!).SendAsync("NewWsConnection", new WsConnection
{
Expand Down

0 comments on commit 3afc537

Please sign in to comment.