diff --git a/.github/workflows/pytest_and_autopublish.yml b/.github/workflows/pytest_and_autopublish.yml index c7301fac..35f24e6c 100644 --- a/.github/workflows/pytest_and_autopublish.yml +++ b/.github/workflows/pytest_and_autopublish.yml @@ -39,7 +39,6 @@ jobs: # * sweep_utils_test: Depends on kxm # * lpips_test: Missing VGG weights # * partial_loader_test: Orbax partial checkpoint loader not yet open-sourced (TODO(epot): Restore) - # * typing tests: Not yet supported due to typeguard version issues. - name: Run core tests run: | pytest -vv -n auto \ @@ -48,9 +47,7 @@ jobs: --ignore=kauldron/xm/ \ --ignore=kauldron/metrics/lpips_test.py \ --ignore=kauldron/checkpoints/partial_loader_test.py \ - --ignore=kauldron/utils/sweep_utils_test.py \ - --ignore=kauldron/typing/shape_spec_test.py \ - --ignore=kauldron/typing/type_check_test.py + --ignore=kauldron/utils/sweep_utils_test.py # Auto-publish when version is increased publish-job: diff --git a/kauldron/typing/shape_spec.py b/kauldron/typing/shape_spec.py index cf252ef4..7403dabd 100644 --- a/kauldron/typing/shape_spec.py +++ b/kauldron/typing/shape_spec.py @@ -83,8 +83,8 @@ def _assert_caller_is_typechecked_func() -> None: if stack[i + 1].function != "_reraise_with_shape_info": caller_name = stack[i].function raise AssertionError( - "Dim and Shape not yet supported due to `typeguard` issue." - f" Raised in {caller_name!r}" + "Dim and Shape only work inside of @typechecked functions. But" + f" {caller_name!r} lacks @typechecked." ) diff --git a/kauldron/typing/type_check.py b/kauldron/typing/type_check.py index 1b881c1b..a04ede42 100644 --- a/kauldron/typing/type_check.py +++ b/kauldron/typing/type_check.py @@ -46,13 +46,10 @@ def check_type( expected_type: Any, ) -> None: """Ensure that value matches expected_type, alias for typeguard.check_type.""" - if True: # Typeguard not yet supported - return return typeguard.check_type(value, expected_type) -exc_cls = Exception -class TypeCheckError(exc_cls): +class TypeCheckError(typeguard.TypeCheckError): """Indicates a runtime typechecking error from the @typechecked decorator.""" def __init__( @@ -112,9 +109,6 @@ def _annotation_repr(ann: Any) -> str: def typechecked(fn): """Decorator to enable runtime type-checking and shape-checking.""" - if True: # Typeguard not yet supported - return fn - if hasattr(fn, "__wrapped__"): raise AssertionError("@typechecked should be the innermost decorator") @@ -126,17 +120,36 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs): # typchecking disabled globally or locally -> just return fn(...) return fn(*args, **kwargs) - # Find either the first Python wrapper or the actual function - python_func = inspect.unwrap(fn, stop=lambda f: hasattr(f, "__code__")) + sig = inspect.signature(fn) + bound_args = sig.bind(*args, **kwargs) # manually reproduce the functionality of typeguard.typechecked, so that # we get access to the returnvalue of the function localns = sys._getframe(1).f_locals # pylint: disable=protected-access - memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs) + globalns = fn.__globals__ # pylint: disable=protected-access + memo = typeguard.TypeCheckMemo(globalns, localns) retval = _undef + + annotations = typing.get_type_hints( + fn, + globalns=globalns, + localns=localns, + include_extras=True, + ) + annotated_arguments = { + k: (v, annotations[k]) + for k, v in bound_args.arguments.items() + if k in annotations + } + try: - typeguard.check_argument_types(memo) + typeguard._functions.check_argument_types( # pylint: disable=protected-access + fn.__name__, annotated_arguments, memo=memo + ) retval = fn(*args, **kwargs) - typeguard.check_return_type(retval, memo) + if "return" in annotations: + typeguard._functions.check_return_type( # pylint: disable=protected-access + fn.__name__, retval, annotations["return"], memo + ) return retval except typeguard.TypeCheckError as e: # Use function signature to construct a complete list of named arguments @@ -144,7 +157,6 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs): bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - annotations = {k: p.annotation for k, p in sig.parameters.items()} # TODO(klausg): filter the stacktrace to exclude all the typechecking raise TypeCheckError( str(e), @@ -396,7 +408,7 @@ def _custom_dataclass_checker( dataclass_as_typed_dict.__module__ = origin_type.__module__ values = {k.name: getattr(value, k.name) for k in fields} try: - return typeguard.check_type( + return typeguard.check_type_internal( dataclass_as_typed_dict(**values), dataclass_as_typed_dict, memo=memo, @@ -469,3 +481,7 @@ def add_custom_checker_lookup_fn(lookup_fn): break else: # prepend checker_lookup_fns[:0] = [lookup_fn] + + +add_custom_checker_lookup_fn(_array_spec_checker_lookup) +add_custom_checker_lookup_fn(_dataclass_checker_lookup) diff --git a/kauldron/typing/type_check_test.py b/kauldron/typing/type_check_test.py index 9bd172f3..0fd65512 100644 --- a/kauldron/typing/type_check_test.py +++ b/kauldron/typing/type_check_test.py @@ -16,7 +16,7 @@ import jaxtyping as jt from kauldron.typing import Float, TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member -from kauldron.typing import type_check +from kauldron.typing import type_check # pylint: disable=g-bad-import-order import numpy as np import pytest diff --git a/pyproject.toml b/pyproject.toml index fd705c91..3f512613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,9 +42,7 @@ dependencies = [ "tfds-nightly", # TODO(klausg): switch back to tensorflow_datasets>=4.9.7 # once released: https://github.com/tensorflow/datasets/commit/d4bfd59863c6cb5b64d043b7cb6ab566e7d92440 "tqdm", - # TODO(klausg): Restore typeguard or switch to something else - # closest match to the internal typeguard - # "typeguard@git+https://github.com/agronholm/typeguard@0dd7f7510b7c694e66a0d17d1d58d185125bad5d", + "typeguard>=4.4.1", "typing_extensions", "xmanager", # lazy deps (should ideally remove those)