Skip to content

Commit

Permalink
Add test for zero-dim tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
catherio committed Jan 15, 2025
1 parent 78b8b2d commit dfa372b
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion test/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,53 @@ def test_tensor_on_gpu(self):
if not torch.cuda.is_available():
self.skipTest("CUDA not available")
tensor = torch.tensor([1, 2, 3]).cuda()
self.assertEqual(orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY), b'[1,2,3]')
self.assertEqual(orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY), b'[1,2,3]')

def test_tensor_zero_dim(self):
"""
Test 0-dimensional tensors are properly serialized as scalar values
"""
# Test float scalar tensor
tensor_float = torch.tensor(0.03)
self.assertEqual(orjson.dumps(tensor_float, option=orjson.OPT_SERIALIZE_NUMPY), b'0.03')

# Test int scalar tensor
tensor_int = torch.tensor(42)
self.assertEqual(orjson.dumps(tensor_int, option=orjson.OPT_SERIALIZE_NUMPY), b'42')

# Test in a nested structure
data = {
"scalar_float": torch.tensor(0.03),
"scalar_int": torch.tensor(42),
"array": torch.tensor([1, 2, 3]),
}
self.assertEqual(
orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY),
b'{"scalar_float":0.03,"scalar_int":42,"array":[1,2,3]}'
)

def test_tensor_special_values(self):
"""
Test that special values (nan, inf) are properly serialized
"""
# Test nan
tensor_nan = torch.tensor(float('nan'))
self.assertEqual(orjson.dumps(tensor_nan, option=orjson.OPT_SERIALIZE_NUMPY), b'NaN')

# Test inf
tensor_inf = torch.tensor(float('inf'))
self.assertEqual(orjson.dumps(tensor_inf, option=orjson.OPT_SERIALIZE_NUMPY), b'Infinity')
tensor_neg_inf = torch.tensor(float('-inf'))
self.assertEqual(orjson.dumps(tensor_neg_inf, option=orjson.OPT_SERIALIZE_NUMPY), b'-Infinity')

# Test in a nested structure
data = {
"nan": torch.tensor(float('nan')),
"inf": torch.tensor(float('inf')),
"neg_inf": torch.tensor(float('-inf')),
"mixed": torch.tensor([1.0, float('nan'), float('inf'), float('-inf')]),
}
self.assertEqual(
orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY),
b'{"nan":NaN,"inf":Infinity,"neg_inf":-Infinity,"mixed":[1.0,NaN,Infinity,-Infinity]}'
)

0 comments on commit dfa372b

Please sign in to comment.