diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactory.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactory.java index a2e2c92a71..9d17f6ef32 100644 --- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactory.java +++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactory.java @@ -29,6 +29,7 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; import static org.springframework.cloud.gateway.support.GatewayToStringStyler.filterToStringCreator; import static org.springframework.util.CollectionUtils.unmodifiableMultiValueMap; @@ -57,14 +58,19 @@ public Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { MultiValueMap queryParams = new LinkedMultiValueMap<>(request.getQueryParams()); queryParams.remove(config.getName()); - URI newUri = UriComponentsBuilder.fromUri(request.getURI()) - .replaceQueryParams(unmodifiableMultiValueMap(queryParams)) - .build() - .toUri(); + try { + MultiValueMap encodedQueryParams = UriUtils.encodeQueryParams(queryParams); + URI newUri = UriComponentsBuilder.fromUri(request.getURI()) + .replaceQueryParams(unmodifiableMultiValueMap(encodedQueryParams)) + .build(true) + .toUri(); - ServerHttpRequest updatedRequest = exchange.getRequest().mutate().uri(newUri).build(); - - return chain.filter(exchange.mutate().request(updatedRequest).build()); + ServerHttpRequest updatedRequest = exchange.getRequest().mutate().uri(newUri).build(); + return chain.filter(exchange.mutate().request(updatedRequest).build()); + } + catch (IllegalArgumentException ex) { + throw new IllegalStateException("Invalid URI query: \"" + queryParams + "\""); + } } @Override diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactory.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactory.java index ecffdf154d..02bccdaf69 100644 --- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactory.java +++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactory.java @@ -26,10 +26,14 @@ import org.springframework.cloud.gateway.filter.GatewayFilterChain; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; import static org.springframework.cloud.gateway.support.GatewayToStringStyler.filterToStringCreator; +import static org.springframework.util.CollectionUtils.unmodifiableMultiValueMap; /** * @author Fredrich Ombico @@ -59,14 +63,25 @@ public Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { ServerHttpRequest req = exchange.getRequest(); UriComponentsBuilder uriComponentsBuilder = UriComponentsBuilder.fromUri(req.getURI()); - if (req.getQueryParams().containsKey(config.getName())) { - uriComponentsBuilder.replaceQueryParam(config.getName(), config.getReplacement()); + + MultiValueMap queryParams = new LinkedMultiValueMap<>(req.getQueryParams()); + if (queryParams.containsKey(config.getName())) { + queryParams.remove(config.getName()); + queryParams.add(config.getName(), config.getReplacement()); } - URI uri = uriComponentsBuilder.build().toUri(); - ServerHttpRequest request = req.mutate().uri(uri).build(); + try { + MultiValueMap encodedQueryParams = UriUtils.encodeQueryParams(queryParams); + URI uri = uriComponentsBuilder.replaceQueryParams(unmodifiableMultiValueMap(encodedQueryParams)) + .build(true) + .toUri(); - return chain.filter(exchange.mutate().request(request).build()); + ServerHttpRequest request = req.mutate().uri(uri).build(); + return chain.filter(exchange.mutate().request(request).build()); + } + catch (IllegalArgumentException ex) { + throw new IllegalStateException("Invalid URI query: \"" + queryParams + "\""); + } } @Override diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactoryTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactoryTests.java index a1fdf471ef..a96aef3864 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactoryTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RemoveRequestParameterGatewayFilterFactoryTests.java @@ -37,7 +37,7 @@ /** * @author Thirunavukkarasu Ravichandran */ -public class RemoveRequestParameterGatewayFilterFactoryTests { +class RemoveRequestParameterGatewayFilterFactoryTests { private ServerWebExchange exchange; @@ -46,7 +46,7 @@ public class RemoveRequestParameterGatewayFilterFactoryTests { private ArgumentCaptor captor; @BeforeEach - public void setUp() { + void setUp() { filterChain = mock(GatewayFilterChain.class); captor = ArgumentCaptor.forClass(ServerWebExchange.class); when(filterChain.filter(captor.capture())).thenReturn(Mono.empty()); @@ -54,7 +54,7 @@ public void setUp() { } @Test - public void removeRequestParameterFilterWorks() { + void removeRequestParameterFilterWorks() { MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost") .queryParam("foo", singletonList("bar")) .build(); @@ -70,7 +70,7 @@ public void removeRequestParameterFilterWorks() { } @Test - public void removeRequestParameterFilterWorksWhenParamIsNotPresentInRequest() { + void removeRequestParameterFilterWorksWhenParamIsNotPresentInRequest() { MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost").build(); exchange = MockServerWebExchange.from(request); NameConfig config = new NameConfig(); @@ -84,7 +84,7 @@ public void removeRequestParameterFilterWorksWhenParamIsNotPresentInRequest() { } @Test - public void removeRequestParameterFilterShouldOnlyRemoveSpecifiedParam() { + void removeRequestParameterFilterShouldOnlyRemoveSpecifiedParam() { MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost") .queryParam("foo", "bar") .queryParam("abc", "xyz") @@ -102,7 +102,7 @@ public void removeRequestParameterFilterShouldOnlyRemoveSpecifiedParam() { } @Test - public void removeRequestParameterFilterShouldHandleRemainingParamsWhichRequiringEncoding() { + void removeRequestParameterFilterShouldHandleRemainingParamsWhichRequiringEncoding() { MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost") .queryParam("foo", "bar") .queryParam("aaa", "abc xyz") @@ -123,4 +123,40 @@ public void removeRequestParameterFilterShouldHandleRemainingParamsWhichRequirin assertThat(actualRequest.getQueryParams()).containsEntry("ccc", singletonList(",xyz")); } + @Test + void removeRequestParameterFilterShouldHandleEncodedParameterName() { + MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost") + .queryParam("foo", "bar") + .queryParam("baz[]", "qux") + .build(); + exchange = MockServerWebExchange.from(request); + NameConfig config = new NameConfig(); + config.setName("baz[]"); + GatewayFilter filter = new RemoveRequestParameterGatewayFilterFactory().apply(config); + + filter.filter(exchange, filterChain); + + ServerHttpRequest actualRequest = captor.getValue().getRequest(); + assertThat(actualRequest.getQueryParams()).doesNotContainKey("baz[]"); + assertThat(actualRequest.getQueryParams()).containsEntry("foo", singletonList("bar")); + } + + @Test + void removeRequestParameterFilterShouldMaintainEncodedParameters() { + MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost") + .queryParam("foo", "bar") + .queryParam("baz[]", "qux") + .build(); + exchange = MockServerWebExchange.from(request); + NameConfig config = new NameConfig(); + config.setName("foo"); + GatewayFilter filter = new RemoveRequestParameterGatewayFilterFactory().apply(config); + + filter.filter(exchange, filterChain); + + ServerHttpRequest actualRequest = captor.getValue().getRequest(); + assertThat(actualRequest.getQueryParams()).doesNotContainKey("foo"); + assertThat(actualRequest.getQueryParams()).containsEntry("baz[]", singletonList("qux")); + } + } diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactoryTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactoryTests.java index 9df80c61f3..f780386152 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactoryTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/RewriteRequestParameterGatewayFilterFactoryTests.java @@ -71,11 +71,23 @@ void rewriteRequestParameterFilterDoesNotAddParamIfNameNotFound() { } @Test - void rewriteRequestParameterFilterWorksWithSpecialCharacters() { + void rewriteRequestParameterFilterWithSpecialCharactersInParameterValue() { testRewriteRequestParameterFilter("campaign", "black friday~(1.A-B_C!)", "campaign=old&color=green", Map.of("campaign", List.of("black friday~(1.A-B_C!)"), "color", List.of("green"))); } + @Test + void rewriteRequestParameterFilterWithSpecialCharactersInParameterName() { + testRewriteRequestParameterFilter("campaign[]", "red", "campaign%5B%5D=blue&color=green", + Map.of("campaign[]", List.of("red"), "color", List.of("green"))); + } + + @Test + void rewriteRequestParameterFilterKeepsOtherParamsEncoded() { + testRewriteRequestParameterFilter("color", "white", "campaign%5B%5D=blue&color=green", + Map.of("campaign[]", List.of("blue"), "color", List.of("white"))); + } + private void testRewriteRequestParameterFilter(String name, String replacement, String query, Map> expectedQueryParams) { GatewayFilter filter = new RewriteRequestParameterGatewayFilterFactory()