diff --git a/test/scripts/fuzzer.py b/test/scripts/fuzzer.py index 6207f4f..fe7de27 100755 --- a/test/scripts/fuzzer.py +++ b/test/scripts/fuzzer.py @@ -263,7 +263,7 @@ def compare_numpy( ) return False - num_differences = (~np.isclose(expected, found, rtol=rtol)).sum() + num_differences = (~np.isclose(expected, found, rtol=rtol, equal_nan=True)).sum() if num_differences != 0: logging.warning( "[%d] %s, %s (%d nnz): FAIL! Found %d differences!", @@ -333,7 +333,7 @@ def compare_csr( ) return False - num_differences = (~np.isclose(expected.data, found.data, rtol=rtol)).sum() + num_differences = (~np.isclose(expected.data, found.data, rtol=rtol, equal_nan=True)).sum() if num_differences != 0: logging.warning( "[%d] %s, %s (%d nnz): FAIL! Found %d differences!",