diff --git a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/ApplicationContext.java b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/ApplicationContext.java index 04de60817..61b261c72 100644 --- a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/ApplicationContext.java +++ b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/ApplicationContext.java @@ -31,6 +31,8 @@ public class ApplicationContext { protected static final Map TEST_ENV_VARS = new ConcurrentHashMap<>(); protected static final Set IMPLEMENTATIONS = new HashSet<>(); + protected static boolean skipMissingImplementations = false; + protected ApplicationContext() {} public static void register(Class clazz, Object implementation) { @@ -53,17 +55,39 @@ public static Set> getImplementors(Class interfaze) { } public static void injectRegisteredImplementations() { - injectRegisteredImplementations(false); + doInjectRegisteredImplementations(); } - protected static void injectRegisteredImplementations(boolean skipMissingImplementations) { + protected static void doInjectRegisteredImplementations() { var fields = Reflection.getFieldsAnnotatedWith(Inject.class); - fields.forEach(field -> injectIntoField(field, skipMissingImplementations)); + fields.forEach(ApplicationContext::injectIntoField); + } + + public static void injectIntoNonSingleton(Object instance) { + var fields = Reflection.getFieldsAnnotatedWithInstance(instance.getClass(), Inject.class); + + fields.forEach(field -> injectIntoField(field, instance)); } - private static void injectIntoField(Field field, boolean skipMissingImplementations) { + private static void injectIntoField(Field field, Object instance) { var fieldType = field.getType(); + + Object fieldImplementation = getFieldImplementation(fieldType); + if (fieldImplementation == null) { + return; + } + + field.trySetAccessible(); + try { + field.set(instance, fieldImplementation); + } catch (IllegalAccessException | IllegalArgumentException exception) { + throw new IllegalArgumentException( + "unable to inject " + fieldType + " into " + instance.getClass(), exception); + } + } + + private static void injectIntoField(Field field) { var declaringClass = field.getDeclaringClass(); if (!IMPLEMENTATIONS.contains(declaringClass)) { @@ -76,29 +100,16 @@ private static void injectIntoField(Field field, boolean skipMissingImplementati declaringClassesToTry.add(declaringClass); declaringClassesToTry.addAll(Arrays.asList(declaringClass.getInterfaces())); - Object fieldImplementation = getFieldImplementation(fieldType, skipMissingImplementations); - if (fieldImplementation == null) { - return; - } - Object declaringClassImplementation = - getDeclaringClassImplementation(declaringClassesToTry, skipMissingImplementations); + getDeclaringClassImplementation(declaringClassesToTry); if (declaringClassImplementation == null) { return; } - field.trySetAccessible(); - - try { - field.set(declaringClassImplementation, fieldImplementation); - } catch (IllegalAccessException | IllegalArgumentException exception) { - throw new IllegalArgumentException( - "Unable to inject " + fieldType + " into " + declaringClass, exception); - } + injectIntoField(field, declaringClassImplementation); } - private static Object getFieldImplementation( - Class fieldType, boolean skipMissingImplementations) { + private static Object getFieldImplementation(Class fieldType) { Object fieldImplementation; try { @@ -116,8 +127,7 @@ private static Object getFieldImplementation( return fieldImplementation; } - private static Object getDeclaringClassImplementation( - List> declaringClassesToTry, boolean skipMissingImplementations) { + private static Object getDeclaringClassImplementation(List> declaringClassesToTry) { Object declaringClassImplementation = declaringClassesToTry.stream() .map( diff --git a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/Reflection.java b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/Reflection.java index 482dade7a..de6d58176 100644 --- a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/Reflection.java +++ b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/context/Reflection.java @@ -3,8 +3,11 @@ import static org.reflections.scanners.Scanners.FieldsAnnotated; import static org.reflections.scanners.Scanners.SubTypes; +import java.lang.annotation.Annotation; import java.lang.reflect.Field; +import java.util.Arrays; import java.util.Set; +import java.util.stream.Collectors; import org.reflections.Reflections; /** @@ -27,4 +30,10 @@ public static Set> getImplementors(Class interfaze) { public static Set getFieldsAnnotatedWith(Class annotation) { return REFLECTIONS.get(FieldsAnnotated.with(annotation).as(Field.class)); } + + public static Set getFieldsAnnotatedWithInstance(Class clazz, Class annotation) { + return Arrays.stream(clazz.getDeclaredFields()) + .filter(field -> field.isAnnotationPresent(annotation.asSubclass(Annotation.class))) + .collect(Collectors.toSet()); + } } diff --git a/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/context/ApplicationContextTest.groovy b/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/context/ApplicationContextTest.groovy index 59902b4f8..72c45dca7 100644 --- a/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/context/ApplicationContextTest.groovy +++ b/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/context/ApplicationContextTest.groovy @@ -1,5 +1,6 @@ package gov.hhs.cdc.trustedintermediary.context +import gov.hhs.cdc.trustedintermediary.wrappers.Logger import spock.lang.Specification import javax.inject.Inject @@ -8,6 +9,34 @@ import java.nio.file.Paths class ApplicationContextTest extends Specification { + interface TestingInterface { + void test() + } + + class NonSingletonClazz { + @Inject + Logger logger + void test() {} + } + + static class DogCow implements TestingInterface { + + @Override + void test() { + print("test()") + } + } + + static class DogCowTwo implements TestingInterface { + + @Override + void test() { + print("testTwo()") + } + } + def DOGCOW = new DogCow() + def DOGCOWTWO = new DogCowTwo() + def setup() { TestApplicationContext.reset() } @@ -21,6 +50,48 @@ class ApplicationContextTest extends Specification { result == ApplicationContext.getImplementation(String.class) } + def "implementors retrieval test"() { + setup: + def dogCow = DOGCOW + def dogCowTwo = DOGCOWTWO + def implementors = new HashSet() + implementors.add(DogCow) + implementors.add(DogCowTwo) + + expect: + implementors == ApplicationContext.getImplementors(TestingInterface) + } + + def "injectIntoNonSingleton unhappy path"() { + given: + def nonSingletonClass = new NonSingletonClazz() + def object = new Object() + ApplicationContext.register(Logger, object) + when: + ApplicationContext.injectIntoNonSingleton(nonSingletonClass) + then: + thrown(IllegalArgumentException) + } + + def "injectIntoNonSingleton unhappy path when fieldImplementation runs into an error"() { + given: + def nonSingletonClass = new NonSingletonClazz() + when: + ApplicationContext.injectIntoNonSingleton(nonSingletonClass) + then: + thrown(IllegalArgumentException) + } + + def "injectIntoNonSingleton unhappy path when fieldImplementation is null"() { + given: + def nonSingletonClass = new NonSingletonClazz() + when: + ApplicationContext.skipMissingImplementations = true + ApplicationContext.injectIntoNonSingleton(nonSingletonClass) + then: + noExceptionThrown() + } + def "implementation injection test"() { given: def injectedValue = "DogCow" @@ -133,6 +204,25 @@ class ApplicationContextTest extends Specification { Files.deleteIfExists(directoryPath) } + def "registering an unsupported injection class"() { + given: + def injectedValue = "DogCow" + def injectionInstantiation = new InjectionDeclaringClass() + + TestApplicationContext.register(List.class, injectionInstantiation) + // notice above that I'm registering the injectionInstantiation object as a List class. + // injectionInstantiation is of class InjectionDeclaringClass, + // and InjectionDeclaringClass doesn't implement List (it only implements AFieldInterface). + TestApplicationContext.register(String.class, injectedValue) + + when: + TestApplicationContext.injectRegisteredImplementations() + injectionInstantiation.getAField() + + then: + noExceptionThrown() + } + class InjectionDeclaringClass { @Inject private String aField diff --git a/shared/src/testFixtures/groovy/gov/hhs/cdc/trustedintermediary/context/TestApplicationContext.groovy b/shared/src/testFixtures/groovy/gov/hhs/cdc/trustedintermediary/context/TestApplicationContext.groovy index 24c37e2a8..735e78fc7 100644 --- a/shared/src/testFixtures/groovy/gov/hhs/cdc/trustedintermediary/context/TestApplicationContext.groovy +++ b/shared/src/testFixtures/groovy/gov/hhs/cdc/trustedintermediary/context/TestApplicationContext.groovy @@ -20,7 +20,8 @@ class TestApplicationContext extends ApplicationContext { } def static injectRegisteredImplementations() { - injectRegisteredImplementations(true) + skipMissingImplementations = true + ApplicationContext.injectRegisteredImplementations() } def static addEnvironmentVariable(String key, String value) {