diff --git a/src/Tunnelite.Client/Tunnelite.Client.csproj b/src/Tunnelite.Client/Tunnelite.Client.csproj index 2999b2e..e92482f 100644 --- a/src/Tunnelite.Client/Tunnelite.Client.csproj +++ b/src/Tunnelite.Client/Tunnelite.Client.csproj @@ -16,7 +16,7 @@ Tool for tunneling URLs https://github.com/cristipufu/ws-tunnel-signalr https://github.com/cristipufu/ws-tunnel-signalr - 1.1.1 + 1.1.2 @@ -27,8 +27,8 @@ - - + + diff --git a/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs b/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs index 05b86af..f4448d0 100644 --- a/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs +++ b/src/Tunnelite.Server/HttpTunnel/HttpAppExtensions.cs @@ -284,12 +284,19 @@ static async Task TunnelRequestAsync(HttpContext context, IHubContext PendingRequests = new(); + private readonly ConcurrentDictionary PendingRequests = new(); public virtual Task WaitForCompletionAsync(Guid requestId, HttpContext context, TimeSpan? timeout = null, CancellationToken cancellationToken = default) { diff --git a/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs b/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs index 0816914..b346b94 100644 --- a/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs +++ b/src/Tunnelite.Server/HttpTunnel/HttpTunnelHub.cs @@ -22,7 +22,9 @@ public override Task OnConnectedAsync() public async IAsyncEnumerable<(ReadOnlyMemory, WebSocketMessageType)> StreamIncomingWsAsync(WsConnection wsConnection) { - var webSocket = _wsRequestsQueue.GetWebSocket(wsConnection.RequestId); + var clientId = GetClientId(Context); + + var webSocket = _wsRequestsQueue.GetWebSocket(clientId, wsConnection.RequestId); if (webSocket == null) { @@ -32,19 +34,26 @@ public override Task OnConnectedAsync() const int chunkSize = 32 * 1024; byte[] buffer = ArrayPool.Shared.Rent(chunkSize); - WebSocketReceiveResult result; + WebSocketReceiveResult? result = null; try { do { - result = await webSocket.ReceiveAsync(new ArraySegment(buffer), Context.ConnectionAborted); + try + { + result = await webSocket.ReceiveAsync(new ArraySegment(buffer), Context.ConnectionAborted); + } + catch (WebSocketException) + { + break; + } yield return (new ReadOnlyMemory(buffer, 0, result.Count), result.MessageType); } while (!result.CloseStatus.HasValue && !Context.ConnectionAborted.IsCancellationRequested); - if (result.MessageType == WebSocketMessageType.Close) + if (result?.MessageType == WebSocketMessageType.Close) { await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, result.CloseStatusDescription, CancellationToken.None); } @@ -62,7 +71,9 @@ public override Task OnConnectedAsync() public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumerable<(ReadOnlyMemory Data, WebSocketMessageType Type)> stream) { - var webSocket = _wsRequestsQueue.GetWebSocket(wsConnection.RequestId); + var clientId = GetClientId(Context); + + var webSocket = _wsRequestsQueue.GetWebSocket(clientId, wsConnection.RequestId); if (webSocket == null) { @@ -83,6 +94,14 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera } } } + catch (OperationCanceledException) + { + // ignore + } + catch (Exception ex) when (ex.Message == "Stream canceled by client.") + { + // ignore + } catch (Exception ex) { _logger.LogError(ex, "An unexpected error occurred while streaming outgoing data for {RequestId}", wsConnection.RequestId); @@ -93,7 +112,7 @@ public async Task StreamOutgoingWsAsync(WsConnection wsConnection, IAsyncEnumera } } - public override Task OnDisconnectedAsync(Exception? exception) + public override async Task OnDisconnectedAsync(Exception? exception) { var clientId = GetClientId(Context); @@ -104,9 +123,9 @@ public override Task OnDisconnectedAsync(Exception? exception) _httpTunnelStore.Clients.Remove(clientId, out _); } - // todo close and dispose all websockets for clientId + await _wsRequestsQueue.CompleteAsync(clientId); - return base.OnDisconnectedAsync(exception); + await base.OnDisconnectedAsync(exception); } private static Guid GetClientId(HubCallerContext context) diff --git a/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs b/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs index c4f866d..b5ae9e4 100644 --- a/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs +++ b/src/Tunnelite.Server/TcpTunnel/TcpClientStore.cs @@ -7,13 +7,13 @@ namespace Tunnelite.Server.TcpTunnel; public class TcpClientStore { // client, [requestId, TcpClient] - private readonly ConcurrentDictionary> _clientStore = new(); + private readonly ConcurrentDictionary> PendingRequests = new(); // clientId, TcpListener - public ConcurrentDictionary _listenerStore = new(); + private readonly ConcurrentDictionary Listeners = new(); public void AddTcpClient(Guid clientId, Guid requestId, TcpClient tcpClient) { - _clientStore.AddOrUpdate( + PendingRequests.AddOrUpdate( clientId, _ => new ConcurrentDictionary { [requestId] = tcpClient }, (_, tcpClients) => @@ -25,7 +25,7 @@ public void AddTcpClient(Guid clientId, Guid requestId, TcpClient tcpClient) public TcpClient GetTcpClient(Guid clientId, Guid requestId) { - if (!_clientStore.TryGetValue(clientId, out var tcpClients)) + if (!PendingRequests.TryGetValue(clientId, out var tcpClients)) { return null; } @@ -37,7 +37,7 @@ public TcpClient GetTcpClient(Guid clientId, Guid requestId) public void DisposeTcpClient(Guid clientId, Guid requestId) { - if (!_clientStore.TryGetValue(clientId, out var tcpClients)) + if (!PendingRequests.TryGetValue(clientId, out var tcpClients)) { return; } @@ -52,12 +52,12 @@ public void DisposeTcpClient(Guid clientId, Guid requestId) public void AddTcpListener(Guid clientId, TcpListenerContext tcpListener) { - _listenerStore.AddOrUpdate(clientId, tcpListener, (key, oldValue) => tcpListener); + Listeners.AddOrUpdate(clientId, tcpListener, (key, oldValue) => tcpListener); } public void DisposeTcpListener(Guid clientId) { - if (_clientStore.TryRemove(clientId, out var tcpClients)) + if (PendingRequests.TryRemove(clientId, out var tcpClients)) { foreach (var client in tcpClients.Values) { @@ -65,7 +65,7 @@ public void DisposeTcpListener(Guid clientId) } } - if (_listenerStore.TryRemove(clientId, out var listener)) + if (Listeners.TryRemove(clientId, out var listener)) { listener?.Dispose(); } diff --git a/src/Tunnelite.Server/Tunnelite.Server.csproj b/src/Tunnelite.Server/Tunnelite.Server.csproj index c7954a2..988b8ee 100644 --- a/src/Tunnelite.Server/Tunnelite.Server.csproj +++ b/src/Tunnelite.Server/Tunnelite.Server.csproj @@ -7,8 +7,8 @@ - - + + diff --git a/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs b/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs index 1248c8e..fbcc575 100644 --- a/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs +++ b/src/Tunnelite.Server/WsTunnel/WsRequestsQueue.cs @@ -5,9 +5,10 @@ namespace Tunnelite.Server.WsTunnel; public class WsRequestsQueue { - public ConcurrentDictionary PendingRequests = new(); + // client, [requestId, WsDefferedRequest] + private readonly ConcurrentDictionary> PendingRequests = new(); - public virtual Task WaitForCompletionAsync(Guid requestId, WebSocket webSocket) + public virtual Task WaitForCompletionAsync(Guid clientId, Guid requestId, WebSocket webSocket) { WsDefferedRequest request = new() { @@ -16,24 +17,38 @@ public virtual Task WaitForCompletionAsync(Guid requestId, WebSocket webSocket) TaskCompletionSource = new TaskCompletionSource(), }; - PendingRequests.TryAdd(request.RequestId, request); + PendingRequests.AddOrUpdate( + clientId, + _ => new ConcurrentDictionary { [requestId] = request }, + (_, requests) => + { + requests[requestId] = request; + return requests; + }); return request.TaskCompletionSource.Task; } - public virtual WebSocket? GetWebSocket(Guid requestId) + public virtual WebSocket? GetWebSocket(Guid clientId, Guid requestId) { - if (!PendingRequests.TryGetValue(requestId, out var request)) + if (!PendingRequests.TryGetValue(clientId, out var requests)) { return null; } - return request.WebSocket; + requests.TryGetValue(requestId, out var request); + + return request?.WebSocket; } - public virtual Task CompleteAsync(Guid requestId) + public virtual Task CompleteAsync(Guid clientId, Guid requestId) { - if (!PendingRequests.TryRemove(requestId, out var request)) + if (!PendingRequests.TryGetValue(clientId, out var requests)) + { + return Task.CompletedTask; + } + + if (!requests.TryRemove(requestId, out var request)) { return Task.CompletedTask; } @@ -53,4 +68,17 @@ public virtual Task CompleteAsync(Guid requestId) return Task.CompletedTask; } + + public virtual async Task CompleteAsync(Guid clientId) + { + if (!PendingRequests.TryRemove(clientId, out var requests)) + { + return; + } + + foreach (var request in requests) + { + await CompleteAsync(clientId, request.Key); + } + } } diff --git a/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs b/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs index 4973dfe..fe34f8b 100644 --- a/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs +++ b/src/Tunnelite.Server/WsTunnel/WsTunnelMiddleware.cs @@ -65,7 +65,7 @@ public async Task InvokeAsync(HttpContext context) _logger.LogInformation("WebSocket connection accepted: {requestId}", requestId); - var completionTask = _requestsQueue.WaitForCompletionAsync(requestId, webSocket); + var completionTask = _requestsQueue.WaitForCompletionAsync(tunnel!.ClientId, requestId, webSocket); await _hubContext.Clients.Client(connectionId!).SendAsync("NewWsConnection", new WsConnection {