From 6d290c93041fddfc4598b830022439dd1d7c27a1 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Thu, 29 Feb 2024 14:14:02 -0700 Subject: [PATCH] Add Authorization Proxy Support Closes gh-14596 --- .../object/AuthorizationProxyFactory.java | 136 ++++++++++++++++ .../AuthorizationProxyFactoryTests.java | 147 ++++++++++++++++++ 2 files changed, 283 insertions(+) create mode 100644 core/src/main/java/org/springframework/security/authorization/object/AuthorizationProxyFactory.java create mode 100644 core/src/test/java/org/springframework/security/authorization/object/AuthorizationProxyFactoryTests.java diff --git a/core/src/main/java/org/springframework/security/authorization/object/AuthorizationProxyFactory.java b/core/src/main/java/org/springframework/security/authorization/object/AuthorizationProxyFactory.java new file mode 100644 index 00000000000..a6991858c9e --- /dev/null +++ b/core/src/main/java/org/springframework/security/authorization/object/AuthorizationProxyFactory.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.authorization.object; + +import java.lang.reflect.Array; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import org.springframework.aop.Advisor; +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.util.ClassUtils; + +public final class AuthorizationProxyFactory { + + private final Collection advisors; + + public AuthorizationProxyFactory(Advisor... advisors) { + this.advisors = List.of(advisors); + } + + public AuthorizationProxyFactory(Collection advisors) { + this.advisors = List.copyOf(advisors); + } + + public AuthorizationProxyFactory withAdvisors(Advisor... advisors) { + List merged = new ArrayList<>(this.advisors.size() + 1); + merged.addAll(this.advisors); + merged.addAll(List.of(advisors)); + AnnotationAwareOrderComparator.sort(merged); + return new AuthorizationProxyFactory(merged); + } + + public Object proxy(Object target) { + if (target == null) { + return target; + } + if (ClassUtils.isSimpleValueType(target.getClass())) { + return target; + } + if (target instanceof Iterator iterator) { + return proxyIterator(iterator); + } + if (target instanceof Collection collection) { + return proxyCollection(collection); + } + if (target.getClass().isArray()) { + return proxyArray((Object[]) target); + } + if (target instanceof Map) { + return proxyMap((Map) target); + } + if (target instanceof Stream) { + return proxyStream((Stream) target); + } + return proxySingle(target); + } + + private Iterator proxyIterator(Iterator iterator) { + return new Iterator<>() { + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public T next() { + return proxySingle(iterator.next()); + } + }; + } + + private Collection proxyCollection(Collection collection) { + Collection proxies = new ArrayList<>(collection.size()); + for (T toProxy : collection) { + proxies.add(proxySingle(toProxy)); + } + collection.clear(); + collection.addAll(proxies); + return proxies; + } + + private Object proxyArray(Object[] objects) { + List retain = new ArrayList<>(objects.length); + for (Object object : objects) { + retain.add(proxySingle(object)); + } + Object[] proxies = (Object[]) Array.newInstance(objects.getClass().getComponentType(), retain.size()); + for (int i = 0; i < retain.size(); i++) { + proxies[i] = retain.get(i); + } + return proxies; + } + + private Object proxyMap(Map entries) { + Map proxies = new LinkedHashMap<>(entries.size()); + for (Map.Entry entry : entries.entrySet()) { + proxies.put(entry.getKey(), proxySingle(entry.getValue())); + } + entries.clear(); + entries.putAll(proxies); + return entries; + } + + private Object proxyStream(Stream stream) { + return stream.map(this::proxySingle).onClose(stream::close); + } + + private T proxySingle(T target) { + ProxyFactory factory = new ProxyFactory(target); + factory.addAdvisors(this.advisors); + factory.setProxyTargetClass(!Modifier.isFinal(target.getClass().getModifiers())); + return (T) factory.getProxy(); + } + +} diff --git a/core/src/test/java/org/springframework/security/authorization/object/AuthorizationProxyFactoryTests.java b/core/src/test/java/org/springframework/security/authorization/object/AuthorizationProxyFactoryTests.java new file mode 100644 index 00000000000..a4891d2d903 --- /dev/null +++ b/core/src/test/java/org/springframework/security/authorization/object/AuthorizationProxyFactoryTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.authorization.object; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.security.authentication.TestAuthentication; +import org.springframework.security.authorization.method.AuthorizationManagerBeforeMethodInterceptor; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +public class AuthorizationProxyFactoryTests { + + private final Authentication user = TestAuthentication.authenticatedUser(); + + private final Authentication admin = TestAuthentication.authenticatedAdmin(); + + @Test + public void proxyWhenPreAuthorizeThenHonors() { + SecurityContextHolder.getContext().setAuthentication(this.user); + AuthorizationManagerBeforeMethodInterceptor preAuthorize = AuthorizationManagerBeforeMethodInterceptor + .preAuthorize(); + AuthorizationProxyFactory factory = new AuthorizationProxyFactory(preAuthorize); + Flight flight = new Flight(); + assertThat(flight.getAltitude()).isEqualTo(35000d); + Flight secured = (Flight) factory.proxy(flight); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> secured.getAltitude()); + SecurityContextHolder.clearContext(); + } + + @Test + public void proxyWhenPreAuthorizeOnInterfaceThenHonors() { + SecurityContextHolder.getContext().setAuthentication(this.user); + AuthorizationManagerBeforeMethodInterceptor preAuthorize = AuthorizationManagerBeforeMethodInterceptor + .preAuthorize(); + AuthorizationProxyFactory factory = new AuthorizationProxyFactory(preAuthorize); + User user = new User("user", "First", "Last"); + assertThat(user.getFirstName()).isEqualTo("First"); + User secured = (User) factory.proxy(user); + assertThat(secured.getFirstName()).isEqualTo("First"); + SecurityContextHolder.getContext().setAuthentication(authenticated("wrong")); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> secured.getFirstName()); + SecurityContextHolder.getContext().setAuthentication(this.admin); + assertThat(secured.getFirstName()).isEqualTo("First"); + SecurityContextHolder.clearContext(); + } + + @Test + public void proxyWhenPreAuthorizeOnRecordThenHonors() { + SecurityContextHolder.getContext().setAuthentication(this.user); + AuthorizationManagerBeforeMethodInterceptor preAuthorize = AuthorizationManagerBeforeMethodInterceptor + .preAuthorize(); + AuthorizationProxyFactory factory = new AuthorizationProxyFactory(preAuthorize); + HasSecret repo = new Repository("secret"); + assertThat(repo.secret()).isEqualTo("secret"); + HasSecret secured = (HasSecret) factory.proxy(repo); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> secured.secret()); + SecurityContextHolder.getContext().setAuthentication(this.user); + assertThat(repo.secret()).isEqualTo("secret"); + SecurityContextHolder.clearContext(); + } + + private Authentication authenticated(String user, String... authorities) { + return TestAuthentication.authenticated(TestAuthentication.withUsername(user).authorities(authorities).build()); + } + + static class Flight { + + @PreAuthorize("hasRole('PILOT')") + Double getAltitude() { + return 35000d; + } + + } + + interface Identifiable { + + String getId(); + + @PreAuthorize("authentication.name == this.id || hasRole('ADMIN')") + String getFirstName(); + + @PreAuthorize("authentication.name == this.id || hasRole('ADMIN')") + String getLastName(); + + } + + static class User implements Identifiable { + + private final String id; + + private final String firstName; + + private final String lastName; + + User(String id, String firstName, String lastName) { + this.id = id; + this.firstName = firstName; + this.lastName = lastName; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public String getFirstName() { + return this.firstName; + } + + @Override + public String getLastName() { + return this.lastName; + } + + } + + interface HasSecret { + + String secret(); + + } + + record Repository(@PreAuthorize("hasRole('ADMIN')") String secret) implements HasSecret { + } + +}