diff --git a/src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs b/src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs index 385efc8316..70a811abe5 100644 --- a/src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs +++ b/src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs @@ -52,14 +52,22 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination, } catch (OperationCanceledException) { - await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, null, cancellationToken); + await TryCloseWebSocket( + destination, + WebSocketCloseStatus.EndpointUnavailable, + null, + cancellationToken); return; } catch (WebSocketException e) { if (e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) { - await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, null, cancellationToken); + await TryCloseWebSocket( + destination, + WebSocketCloseStatus.EndpointUnavailable, + null, + cancellationToken); return; } @@ -68,11 +76,18 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination, if (result.MessageType == WebSocketMessageType.Close) { - await destination.CloseOutputAsync(source.CloseStatus.Value, source.CloseStatusDescription, cancellationToken); + await TryCloseWebSocket( + destination, + source.CloseStatus.Value, + source.CloseStatusDescription, + cancellationToken); return; } - await destination.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken); + if (destination.State == WebSocketState.Open) + { + await destination.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken); + } } } @@ -154,5 +169,20 @@ await Task.WhenAll( PumpWebSocket(server, client.ToWebSocket(), DefaultWebSocketBufferSize, context.RequestAborted)); } } + + private static async Task TryCloseWebSocket( + WebSocket webSocket, + WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + if (webSocket.State == WebSocketState.Open || webSocket.State == WebSocketState.CloseReceived) + { + await webSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + return true; + } + + return false; + } } }