diff --git a/src/Tunnelite.Client/Tunnelite.Client.csproj b/src/Tunnelite.Client/Tunnelite.Client.csproj
index 2999b2e..e92482f 100644
--- a/src/Tunnelite.Client/Tunnelite.Client.csproj
+++ b/src/Tunnelite.Client/Tunnelite.Client.csproj
@@ -16,7 +16,7 @@
Tool for tunneling URLs
https://github.com/cristipufu/ws-tunnel-signalr
https://github.com/cristipufu/ws-tunnel-signalr
- 1.1.1
+ 1.1.2
@@ -27,8 +27,8 @@
-
-
+
+
diff --git a/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs b/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs
index 05b86af..f4448d0 100644
--- a/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs
+++ b/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs
@@ -284,12 +284,19 @@ static async Task TunnelRequestAsync(HttpContext context, IHubContext PendingRequests = new();
+ private readonly ConcurrentDictionary PendingRequests = new();
public virtual Task WaitForCompletionAsync(Guid requestId, HttpContext context, TimeSpan? timeout = null, CancellationToken cancellationToken = default)
{
diff --git a/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs b/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs
index 0816914..b346b94 100644
--- a/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs
+++ b/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs
@@ -22,7 +22,9 @@ public override Task OnConnectedAsync()
public async IAsyncEnumerable<(ReadOnlyMemory, WebSocketMessageType)> StreamIncomingWsAsync(WsConnection wsConnection)
{
- var webSocket = _wsRequestsQueue.GetWebSocket(wsConnection.RequestId);
+ var clientId = GetClientId(Context);
+
+ var webSocket = _wsRequestsQueue.GetWebSocket(clientId, wsConnection.RequestId);
if (webSocket == null)
{
@@ -32,19 +34,26 @@ public override Task OnConnectedAsync()
const int chunkSize = 32 * 1024;
byte[] buffer = ArrayPool.Shared.Rent(chunkSize);
- WebSocketReceiveResult result;
+ WebSocketReceiveResult? result = null;
try
{
do
{
- result = await webSocket.ReceiveAsync(new ArraySegment(buffer), Context.ConnectionAborted);
+ try
+ {
+ result = await webSocket.ReceiveAsync(new ArraySegment(buffer), Context.ConnectionAborted);
+ }
+ catch (WebSocketException)
+ {
+ break;
+ }
yield return (new ReadOnlyMemory(buffer, 0, result.Count), result.MessageType);
}
while (!result.CloseStatus.HasValue && !Context.ConnectionAborted.IsCancellationRequested);
- if (result.MessageType == WebSocketMessageType.Close)
+ if (result?.MessageType == WebSocketMessageType.Close)
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, result.CloseStatusDescription, CancellationToken.None);
}
@@ -62,7 +71,9 @@ public override Task OnConnectedAsync()
public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumerable<(ReadOnlyMemory Data, WebSocketMessageType Type)> stream)
{
- var webSocket = _wsRequestsQueue.GetWebSocket(wsConnection.RequestId);
+ var clientId = GetClientId(Context);
+
+ var webSocket = _wsRequestsQueue.GetWebSocket(clientId, wsConnection.RequestId);
if (webSocket == null)
{
@@ -83,6 +94,14 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera
}
}
}
+ catch (OperationCanceledException)
+ {
+ // ignore
+ }
+ catch (Exception ex) when (ex.Message == "Stream canceled by client.")
+ {
+ // ignore
+ }
catch (Exception ex)
{
_logger.LogError(ex, "An unexpected error occurred while streaming outgoing data for {RequestId}", wsConnection.RequestId);
@@ -93,7 +112,7 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera
}
}
- public override Task OnDisconnectedAsync(Exception? exception)
+ public override async Task OnDisconnectedAsync(Exception? exception)
{
var clientId = GetClientId(Context);
@@ -104,9 +123,9 @@ public override Task OnDisconnectedAsync(Exception? exception)
_httpTunnelStore.Clients.Remove(clientId, out _);
}
- // todo close and dispose all websockets for clientId
+ await _wsRequestsQueue.CompleteAsync(clientId);
- return base.OnDisconnectedAsync(exception);
+ await base.OnDisconnectedAsync(exception);
}
private static Guid GetClientId(HubCallerContext context)
diff --git a/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs b/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs
index c4f866d..b5ae9e4 100644
--- a/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs
+++ b/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs
@@ -7,13 +7,13 @@ namespace Tunnelite.Server.TcpTunnel;
public class TcpClientStore
{
// client, [requestId, TcpClient]
- private readonly ConcurrentDictionary> _clientStore = new();
+ private readonly ConcurrentDictionary> PendingRequests = new();
// clientId, TcpListener
- public ConcurrentDictionary _listenerStore = new();
+ private readonly ConcurrentDictionary Listeners = new();
public void AddTcpClient(Guid clientId, Guid requestId, TcpClient tcpClient)
{
- _clientStore.AddOrUpdate(
+ PendingRequests.AddOrUpdate(
clientId,
_ => new ConcurrentDictionary { [requestId] = tcpClient },
(_, tcpClients) =>
@@ -25,7 +25,7 @@ public void AddTcpClient(Guid clientId, Guid requestId, TcpClient tcpClient)
public TcpClient GetTcpClient(Guid clientId, Guid requestId)
{
- if (!_clientStore.TryGetValue(clientId, out var tcpClients))
+ if (!PendingRequests.TryGetValue(clientId, out var tcpClients))
{
return null;
}
@@ -37,7 +37,7 @@ public TcpClient GetTcpClient(Guid clientId, Guid requestId)
public void DisposeTcpClient(Guid clientId, Guid requestId)
{
- if (!_clientStore.TryGetValue(clientId, out var tcpClients))
+ if (!PendingRequests.TryGetValue(clientId, out var tcpClients))
{
return;
}
@@ -52,12 +52,12 @@ public void DisposeTcpClient(Guid clientId, Guid requestId)
public void AddTcpListener(Guid clientId, TcpListenerContext tcpListener)
{
- _listenerStore.AddOrUpdate(clientId, tcpListener, (key, oldValue) => tcpListener);
+ Listeners.AddOrUpdate(clientId, tcpListener, (key, oldValue) => tcpListener);
}
public void DisposeTcpListener(Guid clientId)
{
- if (_clientStore.TryRemove(clientId, out var tcpClients))
+ if (PendingRequests.TryRemove(clientId, out var tcpClients))
{
foreach (var client in tcpClients.Values)
{
@@ -65,7 +65,7 @@ public void DisposeTcpListener(Guid clientId)
}
}
- if (_listenerStore.TryRemove(clientId, out var listener))
+ if (Listeners.TryRemove(clientId, out var listener))
{
listener?.Dispose();
}
diff --git a/src/Tunnelite.Server/Tunnelite.Server.csproj b/src/Tunnelite.Server/Tunnelite.Server.csproj
index c7954a2..988b8ee 100644
--- a/src/Tunnelite.Server/Tunnelite.Server.csproj
+++ b/src/Tunnelite.Server/Tunnelite.Server.csproj
@@ -7,8 +7,8 @@
-
-
+
+
diff --git a/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs b/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs
index 1248c8e..fbcc575 100644
--- a/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs
+++ b/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs
@@ -5,9 +5,10 @@ namespace Tunnelite.Server.WsTunnel;
public class WsRequestsQueue
{
- public ConcurrentDictionary PendingRequests = new();
+ // client, [requestId, WsDefferedRequest]
+ private readonly ConcurrentDictionary> PendingRequests = new();
- public virtual Task WaitForCompletionAsync(Guid requestId, WebSocket webSocket)
+ public virtual Task WaitForCompletionAsync(Guid clientId, Guid requestId, WebSocket webSocket)
{
WsDefferedRequest request = new()
{
@@ -16,24 +17,38 @@ public virtual Task WaitForCompletionAsync(Guid requestId, WebSocket webSocket)
TaskCompletionSource = new TaskCompletionSource(),
};
- PendingRequests.TryAdd(request.RequestId, request);
+ PendingRequests.AddOrUpdate(
+ clientId,
+ _ => new ConcurrentDictionary { [requestId] = request },
+ (_, requests) =>
+ {
+ requests[requestId] = request;
+ return requests;
+ });
return request.TaskCompletionSource.Task;
}
- public virtual WebSocket? GetWebSocket(Guid requestId)
+ public virtual WebSocket? GetWebSocket(Guid clientId, Guid requestId)
{
- if (!PendingRequests.TryGetValue(requestId, out var request))
+ if (!PendingRequests.TryGetValue(clientId, out var requests))
{
return null;
}
- return request.WebSocket;
+ requests.TryGetValue(requestId, out var request);
+
+ return request?.WebSocket;
}
- public virtual Task CompleteAsync(Guid requestId)
+ public virtual Task CompleteAsync(Guid clientId, Guid requestId)
{
- if (!PendingRequests.TryRemove(requestId, out var request))
+ if (!PendingRequests.TryGetValue(clientId, out var requests))
+ {
+ return Task.CompletedTask;
+ }
+
+ if (!requests.TryRemove(requestId, out var request))
{
return Task.CompletedTask;
}
@@ -53,4 +68,17 @@ public virtual Task CompleteAsync(Guid requestId)
return Task.CompletedTask;
}
+
+ public virtual async Task CompleteAsync(Guid clientId)
+ {
+ if (!PendingRequests.TryRemove(clientId, out var requests))
+ {
+ return;
+ }
+
+ foreach (var request in requests)
+ {
+ await CompleteAsync(clientId, request.Key);
+ }
+ }
}
diff --git a/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs b/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs
index 4973dfe..fe34f8b 100644
--- a/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs
+++ b/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs
@@ -65,7 +65,7 @@ public async Task InvokeAsync(HttpContext context)
_logger.LogInformation("WebSocket connection accepted: {requestId}", requestId);
- var completionTask = _requestsQueue.WaitForCompletionAsync(requestId, webSocket);
+ var completionTask = _requestsQueue.WaitForCompletionAsync(tunnel!.ClientId, requestId, webSocket);
await _hubContext.Clients.Client(connectionId!).SendAsync("NewWsConnection", new WsConnection
{