Skip to content

Commit

Permalink
Add AuthorizeReturnObject
Browse files Browse the repository at this point in the history
  • Loading branch information
jzheaux committed Mar 15, 2024
1 parent c611b7e commit 3b46db4
Show file tree
Hide file tree
Showing 14 changed files with 675 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.ArrayList;
import java.util.List;

import org.aopalliance.intercept.MethodInterceptor;

import org.springframework.aop.framework.AopInfrastructureBean;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.config.BeanDefinition;
Expand All @@ -27,6 +29,7 @@
import org.springframework.context.annotation.Role;
import org.springframework.security.authorization.AuthorizationAdvisorProxyFactory;
import org.springframework.security.authorization.method.AuthorizationAdvisor;
import org.springframework.security.authorization.method.AuthorizeReturnObjectMethodInterceptor;

@Configuration(proxyBeanMethods = false)
final class AuthorizationProxyConfiguration implements AopInfrastructureBean {
Expand All @@ -41,4 +44,17 @@ static AuthorizationAdvisorProxyFactory authorizationProxyFactory(ObjectProvider
return factory;
}

@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
static MethodInterceptor authorizeReturnObjectMethodInterceptor(ObjectProvider<AuthorizationAdvisor> provider,
AuthorizationAdvisorProxyFactory authorizationProxyFactory) {
AuthorizeReturnObjectMethodInterceptor interceptor = new AuthorizeReturnObjectMethodInterceptor(
authorizationProxyFactory);
List<AuthorizationAdvisor> advisors = new ArrayList<>();
provider.forEach(advisors::add);
advisors.add(interceptor);
authorizationProxyFactory.setAdvisors(advisors);
return interceptor;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, B
registerAsAdvisor("postAuthorizeAuthorization", registry);
registerAsAdvisor("securedAuthorization", registry);
registerAsAdvisor("jsr250Authorization", registry);
registerAsAdvisor("authorizeReturnObject", registry);
}

private void registerAsAdvisor(String prefix, BeanDefinitionRegistry registry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.ArrayList;
import java.util.List;

import org.aopalliance.intercept.MethodInterceptor;

import org.springframework.aop.framework.AopInfrastructureBean;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.config.BeanDefinition;
Expand All @@ -27,6 +29,7 @@
import org.springframework.context.annotation.Role;
import org.springframework.security.authorization.ReactiveAuthorizationAdvisorProxyFactory;
import org.springframework.security.authorization.method.AuthorizationAdvisor;
import org.springframework.security.authorization.method.AuthorizeReturnObjectMethodInterceptor;

@Configuration(proxyBeanMethods = false)
final class ReactiveAuthorizationProxyConfiguration implements AopInfrastructureBean {
Expand All @@ -42,4 +45,17 @@ static ReactiveAuthorizationAdvisorProxyFactory authorizationProxyFactory(
return factory;
}

@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
static MethodInterceptor authorizeReturnObjectMethodInterceptor(ObjectProvider<AuthorizationAdvisor> provider,
ReactiveAuthorizationAdvisorProxyFactory authorizationProxyFactory) {
AuthorizeReturnObjectMethodInterceptor interceptor = new AuthorizeReturnObjectMethodInterceptor(
authorizationProxyFactory);
List<AuthorizationAdvisor> advisors = new ArrayList<>();
provider.forEach(advisors::add);
advisors.add(interceptor);
authorizationProxyFactory.setAdvisors(advisors);
return interceptor;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Supplier;

Expand Down Expand Up @@ -60,6 +63,7 @@
import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.authorization.method.AuthorizationInterceptorsOrder;
import org.springframework.security.authorization.method.AuthorizationManagerBeforeMethodInterceptor;
import org.springframework.security.authorization.method.AuthorizeReturnObject;
import org.springframework.security.authorization.method.MethodInvocationResult;
import org.springframework.security.authorization.method.PrePostTemplateDefaults;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
Expand All @@ -80,6 +84,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -662,6 +667,79 @@ public void methodWhenPostFilterMetaAnnotationThenFilters() {
.containsExactly("dave");
}

@Test
@WithMockUser(authorities = "airplane:read")
public void findByIdWhenAuthorizedResultThenAuthorizes() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
Flight flight = flights.findById("1");
assertThatNoException().isThrownBy(flight::getAltitude);
assertThatNoException().isThrownBy(flight::getSeats);
}

@Test
@WithMockUser(authorities = "seating:read")
public void findByIdWhenUnauthorizedResultThenDenies() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
Flight flight = flights.findById("1");
assertThatNoException().isThrownBy(flight::getSeats);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude);
}

@Test
@WithMockUser(authorities = "seating:read")
public void findAllWhenUnauthorizedResultThenDenies() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll().forEachRemaining((flight) -> {
assertThatNoException().isThrownBy(flight::getSeats);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude);
});
}

@Test
public void removeWhenAuthorizedResultThenRemoves() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.remove("1");
}

@Test
@WithMockUser(authorities = "airplane:read")
public void findAllWhenPostFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll()
.forEachRemaining((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName)
.doesNotContain("Kevin Mitnick"));
}

@Test
@WithMockUser(authorities = "airplane:read")
public void findAllWhenPreFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll().forEachRemaining((flight) -> {
flight.board(new ArrayList<>(List.of("John")));
assertThat(flight.getPassengers()).extracting(Passenger::getName).doesNotContain("John");
flight.board(new ArrayList<>(List.of("John Doe")));
assertThat(flight.getPassengers()).extracting(Passenger::getName).contains("John Doe");
});
}

@Test
@WithMockUser(authorities = "seating:read")
public void findAllWhenNestedPreAuthorizeThenAuthorizes() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findAll().forEachRemaining((flight) -> {
List<Passenger> passengers = flight.getPassengers();
passengers.forEach((passenger) -> assertThatExceptionOfType(AccessDeniedException.class)
.isThrownBy(passenger::getName));
});
}

private static Consumer<ConfigurableWebApplicationContext> disallowBeanOverriding() {
return (context) -> ((AnnotationConfigWebApplicationContext) context).setAllowBeanDefinitionOverriding(false);
}
Expand Down Expand Up @@ -1061,4 +1139,113 @@ List<String> resultsContainDave(List<String> list) {

}

@EnableMethodSecurity
@Configuration
static class AuthorizeResultConfig {

@Bean
FlightRepository flights() {
FlightRepository flights = new FlightRepository();
Flight one = new Flight("1", 35000d, 35);
one.board(new ArrayList<>(List.of("Marie Curie", "Kevin Mitnick", "Ada Lovelace")));
flights.save(one);
Flight two = new Flight("2", 32000d, 72);
two.board(new ArrayList<>(List.of("Albert Einstein")));
flights.save(two);
return flights;
}

@Bean
RoleHierarchy roleHierarchy() {
return RoleHierarchyImpl.withRolePrefix("").role("airplane:read").implies("seating:read").build();
}

}

@AuthorizeReturnObject
static class FlightRepository {

private final Map<String, Flight> flights = new ConcurrentHashMap<>();

Iterator<Flight> findAll() {
return this.flights.values().iterator();
}

Flight findById(String id) {
return this.flights.get(id);
}

Flight save(Flight flight) {
this.flights.put(flight.getId(), flight);
return flight;
}

void remove(String id) {
this.flights.remove(id);
}

}

static class Flight {

private final String id;

private final Double altitude;

private final Integer seats;

private final List<Passenger> passengers = new ArrayList<>();

Flight(String id, Double altitude, Integer seats) {
this.id = id;
this.altitude = altitude;
this.seats = seats;
}

String getId() {
return this.id;
}

@PreAuthorize("hasAuthority('airplane:read')")
Double getAltitude() {
return this.altitude;
}

@PreAuthorize("hasAuthority('seating:read')")
Integer getSeats() {
return this.seats;
}

@AuthorizeReturnObject
@PostAuthorize("hasAuthority('seating:read')")
@PostFilter("filterObject.name != 'Kevin Mitnick'")
List<Passenger> getPassengers() {
return this.passengers;
}

@PreAuthorize("hasAuthority('seating:read')")
@PreFilter("filterObject.contains(' ')")
void board(List<String> passengers) {
for (String passenger : passengers) {
this.passengers.add(new Passenger(passenger));
}
}

}

public static class Passenger {

String name;

public Passenger(String name) {
this.name = name;
}

@PreAuthorize("hasAuthority('airplane:read')")
public String getName() {
return this.name;
}

}

}
Loading

0 comments on commit 3b46db4

Please sign in to comment.