Skip to content

Commit

Permalink
Support Python 3 kwonlyargs
Browse files Browse the repository at this point in the history
  • Loading branch information
drslump committed Sep 19, 2016
1 parent 8374a29 commit f71e7d4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 21 deletions.
48 changes: 36 additions & 12 deletions di/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,37 @@ def __ne__(self, other):
return self.value != other


def injector(dependencies):
def get_callable_defaults(fn, follow_wrapped=False):
""" Helper function to extracts a map of name:default from the signature
of a function.
"""
try: # PY35
sign = inspect.signature(fn, follow_wrapped=follow_wrapped)
defaults = dict(
(p.name, p.annotation if p.default is Key else p.default)
for p in sign.parameters.values()
if p.default is not p.empty
)
except (TypeError, ValueError, AttributeError) as ex:
if follow_wrapped and not isinstance(ex, ValueError):
raise RuntimeError(
'injector is configured to follow wrapped methods but your Python '
'version does not support this feature')

try: # PY3
args, _, _, defaults, _, kwonlydefaults, _ = inspect.getfullargspec(fn)
except AttributeError: # PY2
args, _, _, defaults = inspect.getargspec(fn)
kwonlydefaults = None

defaults = dict(zip(reversed(args), reversed(defaults))) if defaults else {}
if kwonlydefaults:
defaults.update(kwonlydefaults)

return defaults


def injector(dependencies, warn=True, follow_wrapped=False):
""" Factory for the dependency injection decorator. It's meant to be
initialized with the map of dependencies to use on decorated functions.
Expand Down Expand Up @@ -114,23 +144,17 @@ def __init__(self, config=ConfigManager):
# Prepare the dependencies storage stack
deps_stack = [dependencies]

def wrapper(fn, warn=True):
# Extract default values for keyword arguments
args, varargs, keywords, defaults = inspect.getargspec(fn)
if defaults:
defaults = dict(zip(reversed(args), reversed(defaults)))
else:
defaults = {}

def wrapper(fn, __warn__=warn, follow_wrapped=follow_wrapped):
# Mapping for injectable values (classes used as default value)
mapping = {}
defaults = get_callable_defaults(fn, follow_wrapped=follow_wrapped)
for name, default in defaults.items():
if isinstance(default, Key):
mapping[name] = default.value
elif inspect.isclass(default):
mapping[name] = default

if warn and not mapping:
if __warn__ and not mapping:
warnings.warn('{0}: No injectable params found. You can safely remove the decorator.'.format(fn.__name__), stacklevel=2)
return fn

Expand All @@ -147,7 +171,7 @@ def inner(*args, **kwargs):
deps = deps_stack[-1]

# Adapt for deprecated property
if deps is not wrapper.dependencies:
if __warn__ and deps is not wrapper.dependencies:
warnings.warn('dependencies property is deprecated, please use patch/unpatch', stacklevel=2)
patch(wrapper.dependencies)
deps = wrapper.dependencies
Expand Down Expand Up @@ -223,7 +247,7 @@ def __new__(cls, name, bases, dct):
methods = ((k, v) for (k, v) in dct.items() if is_user_function(k, v))

for m, fn in methods:
dct[m] = inject_fn(fn, warn=False)
dct[m] = inject_fn(fn, __warn__=False)

return type.__new__(cls, name, bases, dct)

Expand Down
5 changes: 2 additions & 3 deletions tests/di_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,12 @@ def test_positional_arg(self):
foo(self)

def test_no_injectable_params(self):
foo = self.inject(lambda: True, warn=False)
foo = self.inject(lambda: True, __warn__=False)
foo() | should.be_True

foo = self.inject(lambda x: x, warn=False)
foo = self.inject(lambda x: x, __warn__=False)
foo(10) | should.be(10)

@pytest.mark.skipif(PY3, reason='Python 3 deprecates getargspec generating an extra warning')
def test_warns_when_unneeded(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
Expand Down
40 changes: 34 additions & 6 deletions tests/py3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
:license: see LICENSE for more details.
"""

from typing import Any

import unittest
from pyshould import should

Expand All @@ -19,14 +21,40 @@
KeyC = Key('C')


def test_py3_kwonly():
def test_py3_kwonlydefaults():

inject = injector({KeyA: 'A', KeyB: 'B', KeyC: 'C'})

# While we don't fix the inject decorator we'll receive an error
with should.throw(ValueError):
@inject
def foo(a, b=KeyB, *args, c=KeyC):
pass
@inject
def foo(a, b=KeyB, *, c=KeyC):
return (b, c)

foo(10) | should.eql(('B', 'C'))

@inject
def bar(a, *, b=KeyB):
return b

bar(10) | should.eql('B')


def test_py3_annotations():

class Foo: pass
class Bar: pass
class Baz: pass
class Qux: pass

inject = injector({Foo: Foo(), Bar: Bar(), Baz: Baz(), Qux: Qux()})

@inject
def foo(a: Foo = Key, b: Any = Bar, c: Baz = Baz, d=Qux):
return (a, b, c, d)

foo() | should.eql((
should.be_a(Foo),
should.be_a(Bar),
should.be_a(Baz),
should.be_a(Qux)
))

0 comments on commit f71e7d4

Please sign in to comment.