Skip to content

Commit

Permalink
Merge pull request #124 from georchestra/catch_service_errors
Browse files Browse the repository at this point in the history
Fix truncated response body when catching application errors to display standardized error page
  • Loading branch information
groldan authored May 28, 2024
2 parents 1b8f1b1 + 37ff94b commit 04f5aa2
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public class FiltersAutoConfiguration {
* matched Route's GeorchestraTargetConfig for each HTTP request-response
* interaction before other filters are applied.
*/
@Bean ResolveTargetGlobalFilter resolveTargetWebFilter(GatewayConfigProperties config) {
@Bean
ResolveTargetGlobalFilter resolveTargetWebFilter(GatewayConfigProperties config) {
return new ResolveTargetGlobalFilter(config);
}

Expand All @@ -66,7 +67,8 @@ public class FiltersAutoConfiguration {
return new StripBasePathGatewayFilterFactory();
}

@Bean ApplicationErrorGatewayFilterFactory applicationErrorGatewayFilterFactory() {
@Bean
ApplicationErrorGatewayFilterFactory applicationErrorGatewayFilterFactory() {
return new ApplicationErrorGatewayFilterFactory();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,110 @@
*/
package org.georchestra.gateway.filter.global;

import java.net.URI;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory;
import org.springframework.cloud.gateway.support.HttpStatusHolder;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;

import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;

/**
* Filter to allow custom error pages to be used when an application behind the
* gateways returns an error.
* <p>
* {@link GatewayFilterFactory} providing a {@link GatewayFilter} that throws a
* {@link ResponseStatusException} with the proxied response status code if the
* target responded with a {@code 400...} or {@code 500...} status code.
*
* <p>
* Usage: to enable it globally, add this to application.yaml :
*
* <pre>
* <code>
* spring:
* cloud:
* gateway:
* default-filters:
* - ApplicationError
* </code>
* </pre>
*
* To enable it only on some routes, add this to concerned routes in
* {@literal routes.yaml}:
*
* <pre>
* <code>
* filters:
* - name: ApplicationError
* </code>
* </pre>
*/
@Slf4j
public class ApplicationErrorGatewayFilterFactory extends AbstractGatewayFilterFactory<Object> {

public ApplicationErrorGatewayFilterFactory() {
super(Object.class);
}

@Override
public GatewayFilter apply(final Object config) {
return new ServiceErrorGatewayFilter();
}

private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered {

public @Override Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
return chain.filter(exchange).then(Mono.fromRunnable(() -> {
HttpStatus statusCode = exchange.getResponse().getStatusCode();
if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) {
throw new ResponseStatusException(statusCode);
}
}));
}

@Override
public int getOrder() {
return ResolveTargetGlobalFilter.ORDER + 1;
}
}
public ApplicationErrorGatewayFilterFactory() {
super(Object.class);
}

@Override
public GatewayFilter apply(final Object config) {
return new ServiceErrorGatewayFilter();
}

private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered {

public @Override Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {

ApplicationErrorConveyorHttpResponse response;
response = new ApplicationErrorConveyorHttpResponse(exchange.getResponse());

exchange = exchange.mutate().response(response).build();
return chain.filter(exchange);
}

@Override
public int getOrder() {
return ResolveTargetGlobalFilter.ORDER + 1;
}

}

/**
* A response decorator that throws a {@link ResponseStatusException} at
* {@link #setStatusCode(HttpStatus)} if the status code is an error code, thus
* letting the gateway render the appropriate custom error page instead of the
* original application response body.
*/
private static class ApplicationErrorConveyorHttpResponse extends ServerHttpResponseDecorator {

public ApplicationErrorConveyorHttpResponse(ServerHttpResponse delegate) {
super(delegate);
}

@Override
public boolean setStatusCode(@Nullable HttpStatus status) {
checkStatusCode(status);
return super.setStatusCode(status);
}

private void checkStatusCode(HttpStatus statusCode) {
log.debug("native status code: {}", statusCode);
if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) {
log.debug("Conveying {} response status", statusCode);
throw new ResponseStatusException(statusCode);
}
}
}
}
1 change: 1 addition & 0 deletions gateway/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ spring:
- RemoveSecurityHeaders
# AddSecHeaders appends sec-* headers to proxied requests based on the currently authenticated user
- AddSecHeaders
- ApplicationError
global-filter:
websocket-routing:
enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR;

import java.net.URI;
Expand All @@ -32,21 +35,24 @@
import org.georchestra.gateway.model.RoleBasedAccessRule;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.handler.FilteringWebHandler;
import org.springframework.cloud.gateway.route.Route;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;

import reactor.core.publisher.Mono;

class ApplicationErrorGatewayFilterFactoryTest {


private GatewayFilterChain chain;
private GatewayFilter filter;

private GatewayFilter filter;
private MockServerWebExchange exchange;

final URI matchedURI = URI.create("http://fake.backend.com:8080");
Expand All @@ -57,47 +63,59 @@ class ApplicationErrorGatewayFilterFactoryTest {

@BeforeEach
void setUp() throws Exception {
var factory = new ApplicationErrorGatewayFilterFactory();
filter = factory.apply(factory.newConfig());
var factory = new ApplicationErrorGatewayFilterFactory();
filter = factory.apply(factory.newConfig());

matchedRoute = mock(Route.class);
when(matchedRoute.getUri()).thenReturn(matchedURI);

chain = mock(GatewayFilterChain.class);
when(chain.filter(any())).thenReturn(Mono.empty());
MockServerHttpRequest request = MockServerHttpRequest.get("/test").build();
exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(GATEWAY_ROUTE_ATTR, matchedRoute);
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, matchedURI);

}

@Test
void testNotAnErrorResponse() {
exchange.getResponse().setStatusCode(HttpStatus.OK);
Mono<Void> result = filter.filter(exchange, chain);
result.block();
assertThat(exchange.getResponse().getRawStatusCode()).isEqualTo(200);
}
void testNotAnErrorResponse() {
GatewayFilterChain chain = mock(GatewayFilterChain.class);

filter.filter(exchange, chain);

ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(chain).filter(captor.capture());

ServerWebExchange mutated = captor.getValue();
ServerHttpResponse response = mutated.getResponse();
response.setStatusCode(HttpStatus.CREATED);

MockServerHttpResponse origResponse = exchange.getResponse();
assertThat(origResponse.getStatusCode()).isEqualTo(HttpStatus.CREATED);
}

@Test
void test4xx() {
testApplicationError(HttpStatus.BAD_REQUEST);
testApplicationError(HttpStatus.UNAUTHORIZED);
testApplicationError(HttpStatus.FORBIDDEN);
testApplicationError(HttpStatus.NOT_FOUND);
testApplicationError(HttpStatus.BAD_REQUEST);
testApplicationError(HttpStatus.UNAUTHORIZED);
testApplicationError(HttpStatus.FORBIDDEN);
testApplicationError(HttpStatus.NOT_FOUND);
}


@Test
void test5xx() {
testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR);
testApplicationError(HttpStatus.SERVICE_UNAVAILABLE);
testApplicationError(HttpStatus.BAD_GATEWAY);
testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR);
testApplicationError(HttpStatus.SERVICE_UNAVAILABLE);
testApplicationError(HttpStatus.BAD_GATEWAY);
}

private void testApplicationError(HttpStatus status) {
GatewayFilterChain chain = mock(GatewayFilterChain.class);
filter.filter(exchange, chain);
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(chain).filter(captor.capture());

ServerWebExchange mutated = captor.getValue();
ServerHttpResponse response = mutated.getResponse();
assertThrows(ResponseStatusException.class, () -> response.setStatusCode(status));
}

private void testApplicationError(HttpStatus status) {
exchange.getResponse().setStatusCode(status);
Mono<Void> result = filter.filter(exchange, chain);
ResponseStatusException ex = assertThrows(ResponseStatusException.class, ()-> result.block());
assertThat(ex.getStatus()).isEqualTo(status);
}
}

0 comments on commit 04f5aa2

Please sign in to comment.