Skip to content

Commit

Permalink
Explicitly check dtype in assert_array_equal. Adjust a few test cas…
Browse files Browse the repository at this point in the history
…es that had the wrong dtype expectation (`np.arange` usually uses a different dtype from `np.zeros`).

PiperOrigin-RevId: 726228211
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Feb 12, 2025
1 parent 1ee10cb commit e190177
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,12 @@ def test_reshape_padding(self, strict: bool):
axes = jax.sharding.PartitionSpec(
'x',
)
tree = {'x': test_utils.create_sharded_array(np.arange(8), mesh, axes)}
dtype = np.float32
tree = {
'x': test_utils.create_sharded_array(
np.arange(8, dtype=dtype), mesh, axes
)
}
restore_args = {
'x': ArrayRestoreArgs(
mesh=mesh, mesh_axes=axes, global_shape=(16,), strict=strict
Expand All @@ -1533,7 +1538,11 @@ def test_reshape_padding(self, strict: bool):
)
expected = {
'x': test_utils.create_sharded_array(
np.concatenate((np.arange(8), np.zeros(8))), mesh, axes
np.concatenate(
(np.arange(8, dtype=dtype), np.zeros(8, dtype=dtype))
),
mesh,
axes,
)
}
self.validate_restore(expected, restored)
Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def f(arr):

def assert_array_equal(testclass, v_expected, v_actual):
"""Asserts that two arrays are equal."""
if hasattr(v_expected, 'dtype'):
testclass.assertEqual(v_expected.dtype, v_actual.dtype)
testclass.assertIsInstance(v_actual, type(v_expected))
if isinstance(v_expected, jax.Array):
testclass.assertEqual(
Expand Down

0 comments on commit e190177

Please sign in to comment.