From f71e7d47946f8d5fe2a294b70f6f40a6b6ff7ab6 Mon Sep 17 00:00:00 2001 From: DrSlump Date: Mon, 19 Sep 2016 17:19:42 +0200 Subject: [PATCH] Support Python 3 kwonlyargs --- di/main.py | 48 +++++++++++++++++++++++++++++++++++------------ tests/di_tests.py | 5 ++--- tests/py3.py | 40 +++++++++++++++++++++++++++++++++------ 3 files changed, 72 insertions(+), 21 deletions(-) diff --git a/di/main.py b/di/main.py index f6b1e08..17342bc 100644 --- a/di/main.py +++ b/di/main.py @@ -71,7 +71,37 @@ def __ne__(self, other): return self.value != other -def injector(dependencies): +def get_callable_defaults(fn, follow_wrapped=False): + """ Helper function to extracts a map of name:default from the signature + of a function. + """ + try: # PY35 + sign = inspect.signature(fn, follow_wrapped=follow_wrapped) + defaults = dict( + (p.name, p.annotation if p.default is Key else p.default) + for p in sign.parameters.values() + if p.default is not p.empty + ) + except (TypeError, ValueError, AttributeError) as ex: + if follow_wrapped and not isinstance(ex, ValueError): + raise RuntimeError( + 'injector is configured to follow wrapped methods but your Python ' + 'version does not support this feature') + + try: # PY3 + args, _, _, defaults, _, kwonlydefaults, _ = inspect.getfullargspec(fn) + except AttributeError: # PY2 + args, _, _, defaults = inspect.getargspec(fn) + kwonlydefaults = None + + defaults = dict(zip(reversed(args), reversed(defaults))) if defaults else {} + if kwonlydefaults: + defaults.update(kwonlydefaults) + + return defaults + + +def injector(dependencies, warn=True, follow_wrapped=False): """ Factory for the dependency injection decorator. It's meant to be initialized with the map of dependencies to use on decorated functions. @@ -114,23 +144,17 @@ def __init__(self, config=ConfigManager): # Prepare the dependencies storage stack deps_stack = [dependencies] - def wrapper(fn, warn=True): - # Extract default values for keyword arguments - args, varargs, keywords, defaults = inspect.getargspec(fn) - if defaults: - defaults = dict(zip(reversed(args), reversed(defaults))) - else: - defaults = {} - + def wrapper(fn, __warn__=warn, follow_wrapped=follow_wrapped): # Mapping for injectable values (classes used as default value) mapping = {} + defaults = get_callable_defaults(fn, follow_wrapped=follow_wrapped) for name, default in defaults.items(): if isinstance(default, Key): mapping[name] = default.value elif inspect.isclass(default): mapping[name] = default - if warn and not mapping: + if __warn__ and not mapping: warnings.warn('{0}: No injectable params found. You can safely remove the decorator.'.format(fn.__name__), stacklevel=2) return fn @@ -147,7 +171,7 @@ def inner(*args, **kwargs): deps = deps_stack[-1] # Adapt for deprecated property - if deps is not wrapper.dependencies: + if __warn__ and deps is not wrapper.dependencies: warnings.warn('dependencies property is deprecated, please use patch/unpatch', stacklevel=2) patch(wrapper.dependencies) deps = wrapper.dependencies @@ -223,7 +247,7 @@ def __new__(cls, name, bases, dct): methods = ((k, v) for (k, v) in dct.items() if is_user_function(k, v)) for m, fn in methods: - dct[m] = inject_fn(fn, warn=False) + dct[m] = inject_fn(fn, __warn__=False) return type.__new__(cls, name, bases, dct) diff --git a/tests/di_tests.py b/tests/di_tests.py index 8147373..18a7ae4 100644 --- a/tests/di_tests.py +++ b/tests/di_tests.py @@ -109,13 +109,12 @@ def test_positional_arg(self): foo(self) def test_no_injectable_params(self): - foo = self.inject(lambda: True, warn=False) + foo = self.inject(lambda: True, __warn__=False) foo() | should.be_True - foo = self.inject(lambda x: x, warn=False) + foo = self.inject(lambda x: x, __warn__=False) foo(10) | should.be(10) - @pytest.mark.skipif(PY3, reason='Python 3 deprecates getargspec generating an extra warning') def test_warns_when_unneeded(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/tests/py3.py b/tests/py3.py index 85af31d..54402e4 100644 --- a/tests/py3.py +++ b/tests/py3.py @@ -9,6 +9,8 @@ :license: see LICENSE for more details. """ +from typing import Any + import unittest from pyshould import should @@ -19,14 +21,40 @@ KeyC = Key('C') -def test_py3_kwonly(): +def test_py3_kwonlydefaults(): inject = injector({KeyA: 'A', KeyB: 'B', KeyC: 'C'}) - # While we don't fix the inject decorator we'll receive an error - with should.throw(ValueError): - @inject - def foo(a, b=KeyB, *args, c=KeyC): - pass + @inject + def foo(a, b=KeyB, *, c=KeyC): + return (b, c) + + foo(10) | should.eql(('B', 'C')) + + @inject + def bar(a, *, b=KeyB): + return b + + bar(10) | should.eql('B') + + +def test_py3_annotations(): + + class Foo: pass + class Bar: pass + class Baz: pass + class Qux: pass + + inject = injector({Foo: Foo(), Bar: Bar(), Baz: Baz(), Qux: Qux()}) + + @inject + def foo(a: Foo = Key, b: Any = Bar, c: Baz = Baz, d=Qux): + return (a, b, c, d) + foo() | should.eql(( + should.be_a(Foo), + should.be_a(Bar), + should.be_a(Baz), + should.be_a(Qux) + ))