diff --git a/src/__tests__/auto-injectable.test.ts b/src/__tests__/auto-injectable.test.ts index 84c0d1d..c55df81 100644 --- a/src/__tests__/auto-injectable.test.ts +++ b/src/__tests__/auto-injectable.test.ts @@ -1,6 +1,7 @@ import {autoInjectable, injectable, singleton} from "../decorators"; import {instance as globalContainer} from "../dependency-container"; import injectAll from "../decorators/inject-all"; +import {constructor} from "../types"; afterEach(() => { globalContainer.reset(); @@ -179,3 +180,26 @@ test("@autoInjectable resolves multiple transient dependencies", () => { expect(bar.foo!.length).toBe(1); expect(bar.foo![0]).toBeInstanceOf(Foo); }); + +test("@autoInjectable with factory allows factory to see target class", () => { + class Bar { + public target?: constructor; + } + @autoInjectable() + class Foo { + constructor(public myBar?: Bar) {} + } + + globalContainer.register(Bar, { + useFactory: (_, target) => { + const bar = new Bar(); + bar.target = target; + return bar; + } + }); + + const myFoo = new Foo(); + + // It is impossible to compare Foo type to target, because Foo here is extended by @autoInjectable + expect(myFoo.myBar!.target!.name).toBe(Foo.name); +}); diff --git a/src/decorators/auto-injectable.ts b/src/decorators/auto-injectable.ts index b123926..450d6d2 100644 --- a/src/decorators/auto-injectable.ts +++ b/src/decorators/auto-injectable.ts @@ -23,10 +23,10 @@ function autoInjectable(): (target: constructor) => any { try { if (isTokenDescriptor(type)) { return type.multiple - ? globalContainer.resolveAll(type.token) - : globalContainer.resolve(type.token); + ? globalContainer.resolveAll(type.token, target) + : globalContainer.resolve(type.token, target); } - return globalContainer.resolve(type); + return globalContainer.resolve(type, target); } catch (e) { const argIndex = index + args.length; diff --git a/src/dependency-container.ts b/src/dependency-container.ts index c2759b1..08570d9 100644 --- a/src/dependency-container.ts +++ b/src/dependency-container.ts @@ -145,9 +145,10 @@ class InternalDependencyContainer implements DependencyContainer { * Resolve a token into an instance * * @param token {InjectionToken} The dependency token + * @param target {constructor} Constructor resolving the dependency token * @return {T} An instance of the dependency */ - public resolve(token: InjectionToken): T { + public resolve(token: InjectionToken, target?: constructor): T { const registration = this.getRegistration(token); if (!registration && isNormalToken(token)) { @@ -155,23 +156,27 @@ class InternalDependencyContainer implements DependencyContainer { } if (registration) { - return this.resolveRegistration(registration); + return this.resolveRegistration(registration, target); } // No registration for this token, but since it's a constructor, return an instance return this.construct(>token); } - private resolveRegistration(registration: Registration): T { + private resolveRegistration( + registration: Registration, + target?: constructor + ): T { if (isValueProvider(registration.provider)) { return registration.provider.useValue; } else if (isTokenProvider(registration.provider)) { return registration.options.singleton ? registration.instance || (registration.instance = this.resolve( - registration.provider.useToken + registration.provider.useToken, + target )) - : this.resolve(registration.provider.useToken); + : this.resolve(registration.provider.useToken, target); } else if (isClassProvider(registration.provider)) { return registration.options.singleton ? registration.instance || @@ -180,13 +185,16 @@ class InternalDependencyContainer implements DependencyContainer { )) : this.construct(registration.provider.useClass); } else if (isFactoryProvider(registration.provider)) { - return registration.provider.useFactory(this); + return registration.provider.useFactory(this, target); } else { return this.construct(registration.provider); } } - public resolveAll(token: InjectionToken): T[] { + public resolveAll( + token: InjectionToken, + parent?: constructor + ): T[] { const registration = this.getAllRegistrations(token); if (!registration && isNormalToken(token)) { @@ -194,7 +202,9 @@ class InternalDependencyContainer implements DependencyContainer { } if (registration) { - return registration.map(item => this.resolveRegistration(item)); + return registration.map(item => + this.resolveRegistration(item, parent) + ); } // No registration for this token, but since it's a constructor, return an instance @@ -261,10 +271,10 @@ class InternalDependencyContainer implements DependencyContainer { const params = paramInfo.map(param => { if (isTokenDescriptor(param)) { return param.multiple - ? this.resolveAll(param.token) - : this.resolve(param.token); + ? this.resolveAll(param.token, ctor) + : this.resolve(param.token, ctor); } - return this.resolve(param); + return this.resolve(param, ctor); }); return new ctor(...params); diff --git a/src/factories/factory-function.ts b/src/factories/factory-function.ts index ebd9b58..7f218e6 100644 --- a/src/factories/factory-function.ts +++ b/src/factories/factory-function.ts @@ -1,5 +1,9 @@ import DependencyContainer from "../types/dependency-container"; +import {constructor} from "../types"; -type FactoryFunction = (dependencyContainer: DependencyContainer) => T; +type FactoryFunction = ( + dependencyContainer: DependencyContainer, + target?: constructor +) => T; export default FactoryFunction; diff --git a/src/factories/instance-caching-factory.ts b/src/factories/instance-caching-factory.ts index 8c0eb64..654a20b 100644 --- a/src/factories/instance-caching-factory.ts +++ b/src/factories/instance-caching-factory.ts @@ -1,13 +1,17 @@ import DependencyContainer from "../types/dependency-container"; import FactoryFunction from "./factory-function"; +import {constructor} from "../types"; export default function instanceCachingFactory( factoryFunc: FactoryFunction ): FactoryFunction { let instance: T; - return (dependencyContainer: DependencyContainer) => { + return ( + dependencyContainer: DependencyContainer, + target?: constructor + ) => { if (instance == undefined) { - instance = factoryFunc(dependencyContainer); + instance = factoryFunc(dependencyContainer, target); } return instance; }; diff --git a/src/factories/predicate-aware-class-factory.ts b/src/factories/predicate-aware-class-factory.ts index 0d84ffb..492839c 100644 --- a/src/factories/predicate-aware-class-factory.ts +++ b/src/factories/predicate-aware-class-factory.ts @@ -3,20 +3,34 @@ import constructor from "../types/constructor"; import FactoryFunction from "./factory-function"; export default function predicateAwareClassFactory( - predicate: (dependencyContainer: DependencyContainer) => boolean, + predicate: ( + dependencyContainer: DependencyContainer, + target?: constructor + ) => boolean, trueConstructor: constructor, falseConstructor: constructor, useCaching = true ): FactoryFunction { let instance: T; let previousPredicate: boolean; - return (dependencyContainer: DependencyContainer) => { - const currentPredicate = predicate(dependencyContainer); + return ( + dependencyContainer: DependencyContainer, + target?: constructor + ) => { + const currentPredicate = predicate(dependencyContainer, target); if (!useCaching || previousPredicate !== currentPredicate) { if ((previousPredicate = currentPredicate)) { - instance = dependencyContainer.resolve(trueConstructor); + if (target) { + instance = dependencyContainer.resolve(trueConstructor, target); + } else { + instance = dependencyContainer.resolve(trueConstructor); + } } else { - instance = dependencyContainer.resolve(falseConstructor); + if (target) { + instance = dependencyContainer.resolve(falseConstructor, target); + } else { + instance = dependencyContainer.resolve(falseConstructor); + } } } return instance; diff --git a/src/providers/factory-provider.ts b/src/providers/factory-provider.ts index f6ba34c..44a849e 100644 --- a/src/providers/factory-provider.ts +++ b/src/providers/factory-provider.ts @@ -1,5 +1,6 @@ import DependencyContainer from "../types/dependency-container"; import Provider from "./provider"; +import {constructor} from "../types"; /** * Provide a dependency using a factory. @@ -7,7 +8,10 @@ import Provider from "./provider"; * you need instance caching, your factory method must implement it. */ export default interface FactoryProvider { - useFactory: (dependencyContainer: DependencyContainer) => T; + useFactory: ( + dependencyContainer: DependencyContainer, + target?: constructor + ) => T; } export function isFactoryProvider( diff --git a/src/types/dependency-container.ts b/src/types/dependency-container.ts index be327b8..e2939cf 100644 --- a/src/types/dependency-container.ts +++ b/src/types/dependency-container.ts @@ -39,7 +39,9 @@ export default interface DependencyContainer { instance: T ): DependencyContainer; resolve(token: InjectionToken): T; + resolve(token: InjectionToken, parent: constructor): T; resolveAll(token: InjectionToken): T[]; + resolveAll(token: InjectionToken, parent: constructor): T[]; isRegistered(token: InjectionToken): boolean; reset(): void; createChildContainer(): DependencyContainer;