Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix truncated response body when catching application errors to display standardized error page #124

Merged
merged 1 commit into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public class FiltersAutoConfiguration {
* matched Route's GeorchestraTargetConfig for each HTTP request-response
* interaction before other filters are applied.
*/
@Bean ResolveTargetGlobalFilter resolveTargetWebFilter(GatewayConfigProperties config) {
@Bean
ResolveTargetGlobalFilter resolveTargetWebFilter(GatewayConfigProperties config) {
return new ResolveTargetGlobalFilter(config);
}

Expand All @@ -66,7 +67,8 @@ public class FiltersAutoConfiguration {
return new StripBasePathGatewayFilterFactory();
}

@Bean ApplicationErrorGatewayFilterFactory applicationErrorGatewayFilterFactory() {
@Bean
ApplicationErrorGatewayFilterFactory applicationErrorGatewayFilterFactory() {
return new ApplicationErrorGatewayFilterFactory();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,110 @@
*/
package org.georchestra.gateway.filter.global;

import java.net.URI;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory;
import org.springframework.cloud.gateway.support.HttpStatusHolder;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;

import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;

/**
* Filter to allow custom error pages to be used when an application behind the
* gateways returns an error.
* <p>
* {@link GatewayFilterFactory} providing a {@link GatewayFilter} that throws a
* {@link ResponseStatusException} with the proxied response status code if the
* target responded with a {@code 400...} or {@code 500...} status code.
*
* <p>
* Usage: to enable it globally, add this to application.yaml :
*
* <pre>
* <code>
* spring:
* cloud:
* gateway:
* default-filters:
* - ApplicationError
* </code>
* </pre>
*
* To enable it only on some routes, add this to concerned routes in
* {@literal routes.yaml}:
*
* <pre>
* <code>
* filters:
* - name: ApplicationError
* </code>
* </pre>
*/
@Slf4j
public class ApplicationErrorGatewayFilterFactory extends AbstractGatewayFilterFactory<Object> {

public ApplicationErrorGatewayFilterFactory() {
super(Object.class);
}

@Override
public GatewayFilter apply(final Object config) {
return new ServiceErrorGatewayFilter();
}

private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered {

public @Override Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
return chain.filter(exchange).then(Mono.fromRunnable(() -> {
HttpStatus statusCode = exchange.getResponse().getStatusCode();
if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) {
throw new ResponseStatusException(statusCode);
}
}));
}

@Override
public int getOrder() {
return ResolveTargetGlobalFilter.ORDER + 1;
}
}
public ApplicationErrorGatewayFilterFactory() {
super(Object.class);
}

@Override
public GatewayFilter apply(final Object config) {
return new ServiceErrorGatewayFilter();
}

private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered {

public @Override Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {

ApplicationErrorConveyorHttpResponse response;
response = new ApplicationErrorConveyorHttpResponse(exchange.getResponse());

exchange = exchange.mutate().response(response).build();
return chain.filter(exchange);
}

@Override
public int getOrder() {
return ResolveTargetGlobalFilter.ORDER + 1;
}

}

/**
* A response decorator that throws a {@link ResponseStatusException} at
* {@link #setStatusCode(HttpStatus)} if the status code is an error code, thus
* letting the gateway render the appropriate custom error page instead of the
* original application response body.
*/
private static class ApplicationErrorConveyorHttpResponse extends ServerHttpResponseDecorator {

public ApplicationErrorConveyorHttpResponse(ServerHttpResponse delegate) {
super(delegate);
}

@Override
public boolean setStatusCode(@Nullable HttpStatus status) {
checkStatusCode(status);
return super.setStatusCode(status);
}

private void checkStatusCode(HttpStatus statusCode) {
log.debug("native status code: {}", statusCode);
if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) {
log.debug("Conveying {} response status", statusCode);
throw new ResponseStatusException(statusCode);
}
}
}
}
1 change: 1 addition & 0 deletions gateway/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ spring:
- RemoveSecurityHeaders
# AddSecHeaders appends sec-* headers to proxied requests based on the currently authenticated user
- AddSecHeaders
- ApplicationError
global-filter:
websocket-routing:
enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR;

import java.net.URI;
Expand All @@ -32,21 +35,24 @@
import org.georchestra.gateway.model.RoleBasedAccessRule;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.handler.FilteringWebHandler;
import org.springframework.cloud.gateway.route.Route;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;

import reactor.core.publisher.Mono;

class ApplicationErrorGatewayFilterFactoryTest {


private GatewayFilterChain chain;
private GatewayFilter filter;

private GatewayFilter filter;
private MockServerWebExchange exchange;

final URI matchedURI = URI.create("http://fake.backend.com:8080");
Expand All @@ -57,47 +63,59 @@ class ApplicationErrorGatewayFilterFactoryTest {

@BeforeEach
void setUp() throws Exception {
var factory = new ApplicationErrorGatewayFilterFactory();
filter = factory.apply(factory.newConfig());
var factory = new ApplicationErrorGatewayFilterFactory();
filter = factory.apply(factory.newConfig());

matchedRoute = mock(Route.class);
when(matchedRoute.getUri()).thenReturn(matchedURI);

chain = mock(GatewayFilterChain.class);
when(chain.filter(any())).thenReturn(Mono.empty());
MockServerHttpRequest request = MockServerHttpRequest.get("/test").build();
exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(GATEWAY_ROUTE_ATTR, matchedRoute);
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, matchedURI);

}

@Test
void testNotAnErrorResponse() {
exchange.getResponse().setStatusCode(HttpStatus.OK);
Mono<Void> result = filter.filter(exchange, chain);
result.block();
assertThat(exchange.getResponse().getRawStatusCode()).isEqualTo(200);
}
void testNotAnErrorResponse() {
GatewayFilterChain chain = mock(GatewayFilterChain.class);

filter.filter(exchange, chain);

ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(chain).filter(captor.capture());

ServerWebExchange mutated = captor.getValue();
ServerHttpResponse response = mutated.getResponse();
response.setStatusCode(HttpStatus.CREATED);

MockServerHttpResponse origResponse = exchange.getResponse();
assertThat(origResponse.getStatusCode()).isEqualTo(HttpStatus.CREATED);
}

@Test
void test4xx() {
testApplicationError(HttpStatus.BAD_REQUEST);
testApplicationError(HttpStatus.UNAUTHORIZED);
testApplicationError(HttpStatus.FORBIDDEN);
testApplicationError(HttpStatus.NOT_FOUND);
testApplicationError(HttpStatus.BAD_REQUEST);
testApplicationError(HttpStatus.UNAUTHORIZED);
testApplicationError(HttpStatus.FORBIDDEN);
testApplicationError(HttpStatus.NOT_FOUND);
}


@Test
void test5xx() {
testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR);
testApplicationError(HttpStatus.SERVICE_UNAVAILABLE);
testApplicationError(HttpStatus.BAD_GATEWAY);
testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR);
testApplicationError(HttpStatus.SERVICE_UNAVAILABLE);
testApplicationError(HttpStatus.BAD_GATEWAY);
}

private void testApplicationError(HttpStatus status) {
GatewayFilterChain chain = mock(GatewayFilterChain.class);
filter.filter(exchange, chain);
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(chain).filter(captor.capture());

ServerWebExchange mutated = captor.getValue();
ServerHttpResponse response = mutated.getResponse();
assertThrows(ResponseStatusException.class, () -> response.setStatusCode(status));
}

private void testApplicationError(HttpStatus status) {
exchange.getResponse().setStatusCode(status);
Mono<Void> result = filter.filter(exchange, chain);
ResponseStatusException ex = assertThrows(ResponseStatusException.class, ()-> result.block());
assertThat(ex.getStatus()).isEqualTo(status);
}
}
Loading