Skip to content

Commit

Permalink
fix: Fix closing WebSocket in case of 401 and other exception (#3271)
Browse files Browse the repository at this point in the history
  • Loading branch information
pj892031 authored Jan 18, 2024
1 parent ad45ec5 commit aa8b316
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
package org.zowe.apiml.gateway.ws;

import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.TimeoutException;

import org.apache.http.HttpStatus;
import org.eclipse.jetty.websocket.api.CloseException;
import org.eclipse.jetty.websocket.api.UpgradeException;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.util.concurrent.TimeoutException;

/**
* Copies data from the client to the server session.
*/
Expand All @@ -43,15 +44,48 @@ public void afterConnectionClosed(WebSocketSession session, CloseStatus status)
webSocketServerSession.close(status);
}

static CloseStatus getCloseStatusByResponseStatus(int responseStatus, String defaultMessage) {
if (responseStatus >= 1000) {
return new CloseStatus(responseStatus, defaultMessage);
}

if (responseStatus >= 500) {
return CloseStatus.SERVER_ERROR.withReason(defaultMessage);
}

switch (responseStatus) {
case HttpStatus.SC_UNAUTHORIZED:
return CloseStatus.NOT_ACCEPTABLE.withReason("Invalid login credentials");
default:
return CloseStatus.NOT_ACCEPTABLE.withReason(defaultMessage);
}
}

static CloseStatus getCloseStatusByError(Throwable exception) {
if (exception instanceof CloseException) {
CloseException closeException = (CloseException) exception;
if (closeException.getCause() instanceof TimeoutException) {
return CloseStatus.NORMAL;
}
return getCloseStatusByResponseStatus(closeException.getStatusCode(), String.valueOf(exception.getCause()));
}

if (exception instanceof UpgradeException) {
UpgradeException upgradeException = (UpgradeException) exception;
return getCloseStatusByResponseStatus(upgradeException.getResponseStatusCode(), String.valueOf(exception));
}

return CloseStatus.SERVER_ERROR.withReason(String.valueOf(exception));
}

@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
log.warn("WebSocket transport error in session {}: {}", session.getId(), exception.getMessage());
if (exception instanceof CloseException && exception.getCause() instanceof TimeoutException) {
// Idle timeout
webSocketServerSession.close(CloseStatus.NORMAL);
} else if (exception instanceof CloseException) {
webSocketServerSession.close(new CloseStatus(((CloseException) exception).getStatusCode(), exception.getMessage()));

if (webSocketServerSession.isOpen()) {
webSocketServerSession.close(getCloseStatusByError(exception));
}

super.handleTransportError(session, exception);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,8 @@

package org.zowe.apiml.gateway.ws;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.concurrent.TimeoutException;

import org.eclipse.jetty.websocket.api.CloseException;
import org.eclipse.jetty.websocket.api.UpgradeException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
Expand All @@ -26,6 +21,11 @@
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketSession;

import java.util.concurrent.TimeoutException;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
public class WebSocketProxyClientHandlerTest {

Expand Down Expand Up @@ -56,6 +56,11 @@ void thenCloseServer() throws Exception {
@Nested
class AndConnectionTransportError {

@BeforeEach
void setUp() {
doReturn(true).when(serverSession).isOpen();
}

@Test
void andTimeout_thenCloseNormal() throws Exception {
webSocketProxyClientHandler.handleTransportError(mock(WebSocketSession.class), new CloseException(0, new TimeoutException("null")));
Expand All @@ -72,4 +77,101 @@ void andCloseException_thenForwardError() throws Exception {

}

@Nested
class CloseStatuses {

@Test
void whenResponseCode5xx_thenServerErrorWithMessage() {
CloseStatus closeStatus = WebSocketProxyClientHandler.getCloseStatusByResponseStatus(500, "An error message");
assertEquals(CloseStatus.SERVER_ERROR.getCode(), closeStatus.getCode());
assertEquals("An error message", closeStatus.getReason());
}

@Test
void whenResponseCode401_thenNotAcceptableWithMessage() {
CloseStatus closeStatus = WebSocketProxyClientHandler.getCloseStatusByResponseStatus(401, "A default message");
assertEquals(CloseStatus.NOT_ACCEPTABLE.getCode(), closeStatus.getCode());
assertEquals("Invalid login credentials", closeStatus.getReason());
}

@Test
void whenUnknownResponseCode_thenNotAcceptableWithMessage() {
CloseStatus closeStatus = WebSocketProxyClientHandler.getCloseStatusByResponseStatus(405, "A default message");
assertEquals(CloseStatus.NOT_ACCEPTABLE.getCode(), closeStatus.getCode());
assertEquals("A default message", closeStatus.getReason());
}

@Test
void whenCloseExceptionWithTimeout_thenNormal() {
assertEquals(CloseStatus.NORMAL, WebSocketProxyClientHandler.getCloseStatusByError(
new CloseException(500, mock(TimeoutException.class)))
);
}

@Test
void whenCloseExceptionWith401_thenNotAcceptable() {
assertEquals(
CloseStatus.NOT_ACCEPTABLE.withReason("Invalid login credentials"),
WebSocketProxyClientHandler.getCloseStatusByError(new CloseException(401, "unauthMsg"))
);
}

@Test
void whenCloseExceptionWith500_thenServerError() {
assertEquals(
CloseStatus.SERVER_ERROR.withReason("null"),
WebSocketProxyClientHandler.getCloseStatusByError(new CloseException(502, "errorMsg"))
);
}

@Test
void whenUpgradeExceptionWith401_thenNotAcceptable() {
assertEquals(
CloseStatus.NOT_ACCEPTABLE.withReason("Invalid login credentials"),
WebSocketProxyClientHandler.getCloseStatusByError(new UpgradeException(null, 401, "unauthMsg"))
);
}

@Test
void whenUpgradeExceptionWith500_thenServerError() {
assertEquals(
CloseStatus.SERVER_ERROR.withReason("org.eclipse.jetty.websocket.api.UpgradeException: errorMsg"),
WebSocketProxyClientHandler.getCloseStatusByError(new UpgradeException(null, 503, "errorMsg"))
);
}

@Test
void whenUnknownException_thenServerError() {
assertEquals(
CloseStatus.SERVER_ERROR.withReason("java.lang.RuntimeException: errorMsg"),
WebSocketProxyClientHandler.getCloseStatusByError(new RuntimeException("errorMsg"))
);
}

}

@Nested
class ClosingWebSocket {

private final WebSocketSession session = mock(WebSocketSession.class);
private final Throwable exception = new Exception("reason");

@Test
void whenClosed_thenDontClose() throws Exception {
WebSocketSession webSocketSession = mock(WebSocketSession.class);
new WebSocketProxyClientHandler(webSocketSession).handleTransportError(session, exception);
verify(webSocketSession, never()).close(any());
}

@Test
void whenOpened_thenClose() throws Exception {
WebSocketSession session = mock(WebSocketSession.class);
doReturn(true).when(session).isOpen();
new WebSocketProxyClientHandler(session).handleTransportError(session, exception);
verify(session).close(any());

}

}

}
1 change: 1 addition & 0 deletions gradle/sonar.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ sonar {
property "sonar.language", "java"
property "sonar.links.scm", "https://github.com/zowe/api-layer"
property "sonar.links.ci", System.getenv()['BUILD_URL'] ?: null
property "sonar.scanner.force-deprecated-java-version", true
if (pullRequest != null) {
property "sonar.pullrequest.key", System.getenv()['CHANGE_ID'] ?: null
property "sonar.pullrequest.branch", System.getenv()['CHANGE_BRANCH'] ?: null
Expand Down

0 comments on commit aa8b316

Please sign in to comment.