Skip to content

Commit

Permalink
Add named_product test utility to generate test cases with names. (#…
Browse files Browse the repository at this point in the history
…18626)

This is the be used with `absl.parameterized.named_parameters` and instead of `absl.parameterized.product`.
It creates testcases that have intuitive names.

Tests in `numpy_test.py` and `dtypes_test.py` that were using `product` now use `named_product`.

Also standardized the way we create `KerasTensors` in `numpy_test.py` for readability.
  • Loading branch information
hertschuh authored Oct 16, 2023
1 parent aad448e commit 6e05182
Show file tree
Hide file tree
Showing 4 changed files with 739 additions and 562 deletions.
21 changes: 14 additions & 7 deletions keras/backend/common/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.torch.core import to_torch_dtype
from keras.testing import test_case
from keras.testing.test_utils import named_product


class DtypesTest(test_case.TestCase, parameterized.TestCase):
"""Test the dtype to verify that the behavior matches JAX."""

if backend.backend() == "torch":
# TODO: torch doesn't support uint64.
ALL_DTYPES = [
str(to_torch_dtype(x)).split(".")[-1]
for x in ALLOWED_DTYPES
if x not in ["string", "uint64"]
] + [None]
ALL_DTYPES = []
for x in ALLOWED_DTYPES:
if x not in ["string", "uint64"]:
x = str(to_torch_dtype(x)).split(".")[-1]
if x not in ALL_DTYPES: # skip duplicates created by remapping
ALL_DTYPES.append(x)
ALL_DTYPES += [None]
else:
ALL_DTYPES = [x for x in ALLOWED_DTYPES if x != "string"] + [None]

Expand All @@ -32,15 +35,19 @@ def tearDown(self) -> None:
self.jax_enable_x64.__exit__(None, None, None)
return super().tearDown()

@parameterized.product(dtype1=ALL_DTYPES, dtype2=[bool, int, float])
@parameterized.named_parameters(
named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float])
)
def test_result_type_with_python_scalar_types(self, dtype1, dtype2):
import jax.numpy as jnp

out = backend.result_type(dtype1, dtype2)
expected = jnp.result_type(dtype1, dtype2).name
self.assertEqual(out, expected)

@parameterized.product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES)
@parameterized.named_parameters(
named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES)
)
def test_result_type_with_tensor(self, dtype1, dtype2):
import jax.numpy as jnp

Expand Down
Loading

0 comments on commit 6e05182

Please sign in to comment.