Skip to content

Commit

Permalink
Added mvc method, added servlet aware filter
Browse files Browse the repository at this point in the history
  • Loading branch information
jzheaux committed Sep 18, 2023
1 parent 19cea77 commit a1ee14d
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 201 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod;
Expand All @@ -40,6 +43,8 @@ public class RequestMatchersBuilder {

private static final String HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME = "mvcHandlerMappingIntrospector";

private static final Log logger = LogFactory.getLog(RequestMatchersBuilder.class);

private final ApplicationContext context;

private final RequestMatcherBuilder builder;
Expand All @@ -65,21 +70,46 @@ private static RequestMatcherBuilder requestMatcherBuilder(ApplicationContext co
if (!hasIntrospector) {
return new AntPathRequestMatcherBuilder(servletPath);
}
if (!hasDispatcherServlet(registrations)) {
return new AntPathRequestMatcherBuilder(servletPath);
}
if (registrations.isEmpty()) {
if (registrations == null || registrations.isEmpty()) {
return new MvcRequestMatcherBuilder(context, servletPath);
}
Collection<ServletRegistration> dispatcherServlets = dispatcherServlets(registrations);
if (dispatcherServlets.isEmpty()) {
return new AntPathRequestMatcherBuilder(servletPath);
}
if (registrations.size() == 1) {
ServletRegistration registration = registrations.iterator().next();
if (servletPath == null) {
servletPath = deduceServletPath(registration);
if (servletPath != null) {
return new MvcRequestMatcherBuilder(context, servletPath);
}
return isDispatcherServlet(registration) ? new MvcRequestMatcherBuilder(context, servletPath)
: new AntPathRequestMatcherBuilder(servletPath);
Collection<String> mappings = registration.getMappings();
if (mappings.size() != 1) {
return null;
}
String mapping = mappings.iterator().next();
if ("/".equals(mapping)) {
return new MvcRequestMatcherBuilder(context, null);
}
return new MvcRequestMatcherBuilder(context, mapping);
}
return null;
if (dispatcherServlets.size() > 1) {
return null;
}
Collection<String> mappings = dispatcherServlets.iterator().next().getMappings();
if (mappings.size() != 1) {
return null;
}
logger.warn(computeErrorMessage("Your configuration has multiple path-based servlets. As such, you should "
+ "declare your authorization rules using a RequestMatchersBuilder bean, specifying the servlet path "
+ "in each pattern, as follows: " + "\n" + "\n\thttp "
+ "\n\t\t.authorizeHttpRequests((authorize) -> authorize"
+ "\n\t\t\t.requestMatchers(requestMatchersBuilder.servletPath(\"/\").matchers(\"/my/**\", \"/endpoints/**\")).hasAuthority(...) "
+ "\n\n" + "As an alternative, you can remove any unneeded servlets from your application. "
+ "For your reference, your the servlet paths in your configuration are as follows: %s",
registrations));
return new ServletPathAwareRequestMatcherBuilder(
new MvcRequestMatcherBuilder(context, mappings.iterator().next()),
new AntPathRequestMatcherBuilder(null));
}

private static Collection<ServletRegistration> registrations(ApplicationContext context, String servletPath) {
Expand All @@ -101,59 +131,52 @@ private static Collection<ServletRegistration> registrations(ApplicationContext
continue;
}
if (servletPath == null) {
for (String mapping : mappings) {
if (mapping.equals("/") || mapping.endsWith("/*")) {
filtered.add(registration);
break;
}
}
continue;
filtered.add(registration);
}
if (mappings.contains(servletPath) || mappings.contains(servletPath + "/*")) {
else if (mappings.contains(servletPath) || mappings.contains(servletPath + "/*")) {
filtered.add(registration);
}
}
return filtered;
}

private static boolean hasDispatcherServlet(Collection<ServletRegistration> registrations) {
for (ServletRegistration registration : registrations) {
if (isDispatcherServlet(registration)) {
return true;
}
if (servletPath == null) {
return filtered;
}
if (filtered.isEmpty()) {
throw new IllegalArgumentException(computeErrorMessage(
"The servlet path you specified does not seem to match any " + "configured servlets: %s",
registrations.values()));
}
return false;
return filtered;
}

private static boolean isDispatcherServlet(ServletRegistration registration) {
private static Collection<ServletRegistration> dispatcherServlets(Collection<ServletRegistration> registrations) {
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
Collection<ServletRegistration> dispatcherServlets = new ArrayList<>();
for (ServletRegistration registration : registrations) {
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
dispatcherServlets.add(registration);
}
}
catch (ClassNotFoundException ignored) {
// ignore
}
}
catch (ClassNotFoundException ex) {
return false;
}
return false;
return dispatcherServlets;
}

private static String deduceServletPath(ServletRegistration registration) {
Collection<String> mappings = registration.getMappings();
if (mappings.size() > 1) {
return null;
}
String mapping = mappings.iterator().next();
if (mapping.endsWith("/*")) {
return mapping.substring(0, mapping.length() - 2);
private static String computeErrorMessage(String template,
Collection<? extends ServletRegistration> registrations) {
Map<String, Collection<String>> mappings = new LinkedHashMap<>();
for (ServletRegistration registration : registrations) {
mappings.put(registration.getClassName(), registration.getMappings());
}
return null;
return String.format(template, mappings);
}

public RequestMatcher matcher() {
Assert.notNull(this.servletPath, computeErrorMessage());
Assert.notNull(this.servletPath, "To use `#matcher`, you must also specify a servlet path");
return new AntPathRequestMatcher(this.servletPath);
}

Expand All @@ -175,29 +198,39 @@ public RequestMatcher[] matchers(String... patterns) {
return matchers;
}

public RequestMatchersBuilder mvc() {
Collection<ServletRegistration> dispatcherServlets = dispatcherServlets(this.registrations);
if (dispatcherServlets.isEmpty()) {
throw new IllegalArgumentException(
"Spring MVC does not appear to be configured for this application; please either configure Spring MVC or use `#servletPath` instead.");
}
if (dispatcherServlets.size() > 1) {
throw new IllegalArgumentException(
"There appears to be more than one dispatcher servlet configured. As such, you will need to use `#servletPath` instead in order to specify which path these matchers are for.");
}
if (dispatcherServlets.iterator().next().getMappings().size() > 1) {
throw new IllegalArgumentException(
"There apppears to be more than one mapping for this dispatcher servlet. As such, you will need to use `#servletPath` instead in order to specify which path these matchers are for.");
}
return servletPath(dispatcherServlets.iterator().next().getMappings().iterator().next());
}

public RequestMatchersBuilder servletPath(String path) {
return new RequestMatchersBuilder(this.context, path);
}

private void checkServletPath() {
if (this.builder == null) {
throw new IllegalArgumentException(computeErrorMessage());
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+ "You will need to specify the servlet path for each endpoint to assist with disambiguation. "
+ "\n\nFor your reference, these are the servlets that have potentially ambiguous paths: %s"
+ "\n\nTo do this, you can use the RequestMatchersBuilder bean in conjunction with requestMatchers like so: "
+ "\n\n\thttp" + "\n\t\t.authorizeHttpRequests((authorize) -> authorize"
+ "\n\t\t\t.requestMatchers(builder.servletPath(\"/\").matchers(\"/my\", \"/controller\", \"endpoints\")).";
throw new IllegalArgumentException(computeErrorMessage(template, this.registrations));
}
}

private String computeErrorMessage() {
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+ "You will need to specify the servlet path for each endpoint to assist with disambiguation. "
+ "\n\nFor your reference, these are the servlets that have potentially ambiguous paths: %s"
+ "\n\nTo do this, you can use the RequestMatchersBuilder bean in conjunction with requestMatchers like so: "
+ "\n\n\t.requestMatchers(builder.servletPath(\"/\").matchers(\"/my\", \"/controller\", \"endpoints\")).";
Map<String, Collection<String>> mappings = new LinkedHashMap<>();
for (ServletRegistration registration : this.registrations) {
mappings.put(registration.getClassName(), registration.getMappings());
}
return String.format(template, mappings);
}

private interface RequestMatcherBuilder {

RequestMatcher matcher(String pattern);
Expand All @@ -215,11 +248,16 @@ private static final class MvcRequestMatcherBuilder implements RequestMatcherBui
private MvcRequestMatcherBuilder(ApplicationContext context, String servletPath) {
this.introspector = context.getBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME,
HandlerMappingIntrospector.class);
this.servletPath = servletPath;
if (servletPath != null && servletPath.endsWith("/*")) {
this.servletPath = servletPath.substring(0, servletPath.length() - 2);
}
else {
this.servletPath = servletPath;
}
}

@Override
public RequestMatcher matcher(String pattern) {
public MvcRequestMatcher matcher(String pattern) {
MvcRequestMatcher matcher = new MvcRequestMatcher(this.introspector, pattern);
if (this.servletPath != null) {
matcher.setServletPath(this.servletPath);
Expand All @@ -228,7 +266,7 @@ public RequestMatcher matcher(String pattern) {
}

@Override
public RequestMatcher matcher(HttpMethod method, String pattern) {
public MvcRequestMatcher matcher(HttpMethod method, String pattern) {
MvcRequestMatcher matcher = new MvcRequestMatcher(this.introspector, pattern);
matcher.setMethod(method);
if (this.servletPath != null) {
Expand All @@ -244,20 +282,25 @@ private static final class AntPathRequestMatcherBuilder implements RequestMatche
private final String servletPath;

private AntPathRequestMatcherBuilder(String servletPath) {
this.servletPath = servletPath;
if (servletPath != null && servletPath.endsWith("/*")) {
this.servletPath = servletPath.substring(0, servletPath.length() - 2);
}
else {
this.servletPath = servletPath;
}
}

@Override
public RequestMatcher matcher(String pattern) {
public AntPathRequestMatcher matcher(String pattern) {
return matcher((String) null, pattern);
}

@Override
public RequestMatcher matcher(HttpMethod method, String pattern) {
public AntPathRequestMatcher matcher(HttpMethod method, String pattern) {
return matcher((method != null) ? method.name() : null, pattern);
}

private RequestMatcher matcher(String method, String pattern) {
private AntPathRequestMatcher matcher(String method, String pattern) {
return new AntPathRequestMatcher(prependServletPath(pattern), method);
}

Expand All @@ -273,4 +316,79 @@ private String prependServletPath(String pattern) {

}

private static final class ServletPathAwareRequestMatcherBuilder implements RequestMatcherBuilder {

private final MvcRequestMatcherBuilder mvc;

private final AntPathRequestMatcherBuilder ant;

private ServletPathAwareRequestMatcherBuilder(MvcRequestMatcherBuilder mvc, AntPathRequestMatcherBuilder ant) {
this.mvc = mvc;
this.ant = ant;
}

@Override
public RequestMatcher matcher(String pattern) {
MvcRequestMatcher mvc = this.mvc.matcher(pattern);
AntPathRequestMatcher ant = this.ant.matcher(pattern);
return new ServletPathAwareRequestMatcher(mvc, ant);
}

@Override
public RequestMatcher matcher(HttpMethod method, String pattern) {
MvcRequestMatcher mvc = this.mvc.matcher(method, pattern);
AntPathRequestMatcher ant = this.ant.matcher(method, pattern);
return new ServletPathAwareRequestMatcher(mvc, ant);
}

}

static final class ServletPathAwareRequestMatcher implements RequestMatcher {

final MvcRequestMatcher mvc;

final AntPathRequestMatcher ant;

ServletPathAwareRequestMatcher(MvcRequestMatcher mvc, AntPathRequestMatcher ant) {
this.mvc = mvc;
this.ant = ant;
}

@Override
public boolean matches(HttpServletRequest request) {
String servletName = request.getHttpServletMapping().getServletName();
ServletRegistration registration = request.getServletContext().getServletRegistration(servletName);
if (isDispatcherServlet(registration)) {
return this.mvc.matches(request);
}
return this.ant.matches(request);
}

@Override
public MatchResult matcher(HttpServletRequest request) {
String servletName = request.getHttpServletMapping().getServletName();
ServletRegistration registration = request.getServletContext().getServletRegistration(servletName);
if (isDispatcherServlet(registration)) {
return this.mvc.matcher(request);
}
return this.ant.matcher(request);
}

private static boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
}
}
catch (ClassNotFoundException ex) {
return false;
}
return false;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,20 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
servletContext.addServlet("servletOne", Servlet.class);
servletContext.addServlet("servletTwo", Servlet.class);
requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
}

@Test
public void requestMatchersWhenAmbiguousServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/*");
servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}
Expand Down
Loading

0 comments on commit a1ee14d

Please sign in to comment.