From 37ff94b9ebf668f3fbeb99a18baaa75ebb2d5acf Mon Sep 17 00:00:00 2001 From: Gabriel Roldan Date: Mon, 27 May 2024 19:02:11 -0300 Subject: [PATCH] Fix truncated response body when catching application errors to display standardized error page Throwing the `ResponseStatusException` too late during the request processing causes the response body to be partially committed and not displaying the customized error page. Use a `ServerHttpResponseDecorator` that throws the exception as soon as its `setStatusCode()` method is called. --- .../app/FiltersAutoConfiguration.java | 6 +- .../ApplicationErrorGatewayFilterFactory.java | 113 ++++++++++++++---- gateway/src/main/resources/application.yml | 1 + ...licationErrorGatewayFilterFactoryTest.java | 76 +++++++----- 4 files changed, 140 insertions(+), 56 deletions(-) diff --git a/gateway/src/main/java/org/georchestra/gateway/autoconfigure/app/FiltersAutoConfiguration.java b/gateway/src/main/java/org/georchestra/gateway/autoconfigure/app/FiltersAutoConfiguration.java index 8a1db74c..70c238fc 100644 --- a/gateway/src/main/java/org/georchestra/gateway/autoconfigure/app/FiltersAutoConfiguration.java +++ b/gateway/src/main/java/org/georchestra/gateway/autoconfigure/app/FiltersAutoConfiguration.java @@ -45,7 +45,8 @@ public class FiltersAutoConfiguration { * matched Route's GeorchestraTargetConfig for each HTTP request-response * interaction before other filters are applied. */ - @Bean ResolveTargetGlobalFilter resolveTargetWebFilter(GatewayConfigProperties config) { + @Bean + ResolveTargetGlobalFilter resolveTargetWebFilter(GatewayConfigProperties config) { return new ResolveTargetGlobalFilter(config); } @@ -66,7 +67,8 @@ public class FiltersAutoConfiguration { return new StripBasePathGatewayFilterFactory(); } - @Bean ApplicationErrorGatewayFilterFactory applicationErrorGatewayFilterFactory() { + @Bean + ApplicationErrorGatewayFilterFactory applicationErrorGatewayFilterFactory() { return new ApplicationErrorGatewayFilterFactory(); } } diff --git a/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java b/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java index d35dbeb7..15aa2d0d 100644 --- a/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java +++ b/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java @@ -18,47 +18,110 @@ */ package org.georchestra.gateway.filter.global; +import java.net.URI; + import org.springframework.cloud.gateway.filter.GatewayFilter; import org.springframework.cloud.gateway.filter.GatewayFilterChain; import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory; import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory; +import org.springframework.cloud.gateway.support.HttpStatusHolder; +import org.springframework.cloud.gateway.support.ServerWebExchangeUtils; import org.springframework.core.Ordered; import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.http.server.reactive.ServerHttpResponseDecorator; +import org.springframework.lang.Nullable; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; + +import lombok.extern.slf4j.Slf4j; import reactor.core.publisher.Mono; /** + * Filter to allow custom error pages to be used when an application behind the + * gateways returns an error. + *

* {@link GatewayFilterFactory} providing a {@link GatewayFilter} that throws a * {@link ResponseStatusException} with the proxied response status code if the * target responded with a {@code 400...} or {@code 500...} status code. * + *

+ * Usage: to enable it globally, add this to application.yaml : + * + *

+ * 
+ * spring:
+ *  cloud:
+ *    gateway:
+ *      default-filters:
+ *        - ApplicationError
+ * 
+ * 
+ * + * To enable it only on some routes, add this to concerned routes in + * {@literal routes.yaml}: + * + *
+ * 
+ *        filters:
+ *       - name: ApplicationError
+ * 
+ * 
*/ +@Slf4j public class ApplicationErrorGatewayFilterFactory extends AbstractGatewayFilterFactory { - public ApplicationErrorGatewayFilterFactory() { - super(Object.class); - } - - @Override - public GatewayFilter apply(final Object config) { - return new ServiceErrorGatewayFilter(); - } - - private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered { - - public @Override Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { - return chain.filter(exchange).then(Mono.fromRunnable(() -> { - HttpStatus statusCode = exchange.getResponse().getStatusCode(); - if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) { - throw new ResponseStatusException(statusCode); - } - })); - } - - @Override - public int getOrder() { - return ResolveTargetGlobalFilter.ORDER + 1; - } - } + public ApplicationErrorGatewayFilterFactory() { + super(Object.class); + } + + @Override + public GatewayFilter apply(final Object config) { + return new ServiceErrorGatewayFilter(); + } + + private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered { + + public @Override Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { + + ApplicationErrorConveyorHttpResponse response; + response = new ApplicationErrorConveyorHttpResponse(exchange.getResponse()); + + exchange = exchange.mutate().response(response).build(); + return chain.filter(exchange); + } + + @Override + public int getOrder() { + return ResolveTargetGlobalFilter.ORDER + 1; + } + + } + + /** + * A response decorator that throws a {@link ResponseStatusException} at + * {@link #setStatusCode(HttpStatus)} if the status code is an error code, thus + * letting the gateway render the appropriate custom error page instead of the + * original application response body. + */ + private static class ApplicationErrorConveyorHttpResponse extends ServerHttpResponseDecorator { + + public ApplicationErrorConveyorHttpResponse(ServerHttpResponse delegate) { + super(delegate); + } + + @Override + public boolean setStatusCode(@Nullable HttpStatus status) { + checkStatusCode(status); + return super.setStatusCode(status); + } + + private void checkStatusCode(HttpStatus statusCode) { + log.debug("native status code: {}", statusCode); + if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) { + log.debug("Conveying {} response status", statusCode); + throw new ResponseStatusException(statusCode); + } + } + } } diff --git a/gateway/src/main/resources/application.yml b/gateway/src/main/resources/application.yml index ead19e59..fdb21a7d 100644 --- a/gateway/src/main/resources/application.yml +++ b/gateway/src/main/resources/application.yml @@ -52,6 +52,7 @@ spring: - RemoveSecurityHeaders # AddSecHeaders appends sec-* headers to proxied requests based on the currently authenticated user - AddSecHeaders + - ApplicationError global-filter: websocket-routing: enabled: true diff --git a/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java b/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java index 7d0ef066..a2061c96 100644 --- a/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java +++ b/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java @@ -22,7 +22,10 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR; import java.net.URI; @@ -32,21 +35,24 @@ import org.georchestra.gateway.model.RoleBasedAccessRule; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.cloud.gateway.filter.GatewayFilter; import org.springframework.cloud.gateway.filter.GatewayFilterChain; +import org.springframework.cloud.gateway.handler.FilteringWebHandler; import org.springframework.cloud.gateway.route.Route; import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; class ApplicationErrorGatewayFilterFactoryTest { - - - private GatewayFilterChain chain; - private GatewayFilter filter; + + private GatewayFilter filter; private MockServerWebExchange exchange; final URI matchedURI = URI.create("http://fake.backend.com:8080"); @@ -57,47 +63,59 @@ class ApplicationErrorGatewayFilterFactoryTest { @BeforeEach void setUp() throws Exception { - var factory = new ApplicationErrorGatewayFilterFactory(); - filter = factory.apply(factory.newConfig()); + var factory = new ApplicationErrorGatewayFilterFactory(); + filter = factory.apply(factory.newConfig()); matchedRoute = mock(Route.class); when(matchedRoute.getUri()).thenReturn(matchedURI); - chain = mock(GatewayFilterChain.class); - when(chain.filter(any())).thenReturn(Mono.empty()); MockServerHttpRequest request = MockServerHttpRequest.get("/test").build(); exchange = MockServerWebExchange.from(request); exchange.getAttributes().put(GATEWAY_ROUTE_ATTR, matchedRoute); + exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, matchedURI); + } @Test - void testNotAnErrorResponse() { - exchange.getResponse().setStatusCode(HttpStatus.OK); - Mono result = filter.filter(exchange, chain); - result.block(); - assertThat(exchange.getResponse().getRawStatusCode()).isEqualTo(200); - } + void testNotAnErrorResponse() { + GatewayFilterChain chain = mock(GatewayFilterChain.class); + + filter.filter(exchange, chain); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); + verify(chain).filter(captor.capture()); + + ServerWebExchange mutated = captor.getValue(); + ServerHttpResponse response = mutated.getResponse(); + response.setStatusCode(HttpStatus.CREATED); + + MockServerHttpResponse origResponse = exchange.getResponse(); + assertThat(origResponse.getStatusCode()).isEqualTo(HttpStatus.CREATED); + } @Test void test4xx() { - testApplicationError(HttpStatus.BAD_REQUEST); - testApplicationError(HttpStatus.UNAUTHORIZED); - testApplicationError(HttpStatus.FORBIDDEN); - testApplicationError(HttpStatus.NOT_FOUND); + testApplicationError(HttpStatus.BAD_REQUEST); + testApplicationError(HttpStatus.UNAUTHORIZED); + testApplicationError(HttpStatus.FORBIDDEN); + testApplicationError(HttpStatus.NOT_FOUND); } - @Test void test5xx() { - testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR); - testApplicationError(HttpStatus.SERVICE_UNAVAILABLE); - testApplicationError(HttpStatus.BAD_GATEWAY); + testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR); + testApplicationError(HttpStatus.SERVICE_UNAVAILABLE); + testApplicationError(HttpStatus.BAD_GATEWAY); + } + + private void testApplicationError(HttpStatus status) { + GatewayFilterChain chain = mock(GatewayFilterChain.class); + filter.filter(exchange, chain); + ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); + verify(chain).filter(captor.capture()); + + ServerWebExchange mutated = captor.getValue(); + ServerHttpResponse response = mutated.getResponse(); + assertThrows(ResponseStatusException.class, () -> response.setStatusCode(status)); } - - private void testApplicationError(HttpStatus status) { - exchange.getResponse().setStatusCode(status); - Mono result = filter.filter(exchange, chain); - ResponseStatusException ex = assertThrows(ResponseStatusException.class, ()-> result.block()); - assertThat(ex.getStatus()).isEqualTo(status); - } }