diff --git a/Makefile b/Makefile index cbed5e7..e7dfb2e 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,8 @@ check: - make -j2 types style + make -j3 test types style + +test: + python -m pytest types: mypy --strict --implicit-optional . diff --git a/mr_proper/utils/ast.py b/mr_proper/utils/ast.py index f7dc3ef..cd18c0e 100644 --- a/mr_proper/utils/ast.py +++ b/mr_proper/utils/ast.py @@ -62,7 +62,8 @@ def is_imported_from_stdlib(name: str, file_ast_tree: ast.Module) -> Optional[bo def get_local_var_names_from_funcdef(funcdef_node: AnyFuncdef) -> List[str]: local_vars_names: List[str] = [] for assign_node in get_nodes_from_funcdef_body(funcdef_node, [ast.Assign]): - local_vars_names += [t.id for t in assign_node.targets if isinstance(t, ast.Name)] + for target in assign_node.targets: + local_vars_names += [n.id for n in ast.walk(target) if isinstance(n, ast.Name)] for annassign_node in get_nodes_from_funcdef_body(funcdef_node, [ast.AnnAssign]): if isinstance(annassign_node.target, ast.Name): local_vars_names.append(annassign_node.target.id) @@ -74,7 +75,7 @@ def get_local_var_names_from_funcdef(funcdef_node: AnyFuncdef) -> List[str]: for n in ast.walk(funcdef_node) if isinstance(n, ast.ExceptHandler) and n.name } - return local_vars_names + return list(set(local_vars_names)) def get_local_var_names_from_loop(loop_node: Union[ast.comprehension, ast.For]) -> List[str]: diff --git a/tests/test_pure_checker.py b/tests/test_pure_checker.py new file mode 100644 index 0000000..fd82a19 --- /dev/null +++ b/tests/test_pure_checker.py @@ -0,0 +1,12 @@ +import ast + +from mr_proper.public_api import is_function_pure + + +def test_ok_for_destructive_assignment(): + funcdef = ast.parse(""" + def foo(a): + b, c = a + return b * c + """.strip()).body[0] + assert is_function_pure(funcdef)