diff --git a/tests/utils/test_data_type_utils.py b/tests/utils/test_data_type_utils.py index 94b1e61e..0714001a 100644 --- a/tests/utils/test_data_type_utils.py +++ b/tests/utils/test_data_type_utils.py @@ -29,5 +29,6 @@ def test_str_to_torch_dtype_exit(): def test_get_torch_dtype(): for t in dtype_dict.keys(): + # When passed a string, it gets converted to torch.dtype assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t) assert data_type_utils.get_torch_dtype(dtype_dict.get(t)) == dtype_dict.get(t)