Skip to content

Commit

Permalink
Add support for declearn 'Vector' in 'Serializer'.
Browse files: browse the repository at this point in the history
  • Loading branch information
pandrey-fr committed May 9, 2023
1 parent a1f9285 commit 0c46f4b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
9 changes: 6 additions & 3 deletions fedbiomed/common/serializer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,13 +3,14 @@

"""MsgPack serialization utils, wrapped into a namespace class."""

from math import ceil
from typing import Any

import msgpack
import numpy as np
import torch
from declearn.model.api import Vector

from math import ceil
from fedbiomed.common.exceptions import FedbiomedTypeError
from fedbiomed.common.logger import logger

Expand Down Expand Up @@ -97,7 +98,6 @@ def _default(obj: Any) -> Any:
return {"__type__": "int", "value": obj.to_bytes(
length=ceil(obj.bit_length()/8),
byteorder="big")}

if isinstance(obj, tuple):
return {"__type__": "tuple", "value": list(obj)}
if isinstance(obj, np.ndarray):
Expand All @@ -110,8 +110,9 @@ def _default(obj: Any) -> Any:
obj = obj.cpu().numpy()
spec = [obj.tobytes(), obj.dtype.name, list(obj.shape)]
return {"__type__": "torch.Tensor", "value": spec}
if isinstance(obj, Vector):
return {"__type__": "Vector", "value": obj.coefs}
# Raise on unsupported types.

raise FedbiomedTypeError(
f"Cannot serialize object of type '{type(obj)}'."
)
Expand All @@ -136,6 +137,8 @@ def _object_hook(obj: Any) -> Any:
data, dtype, shape = obj["value"]
array = np.frombuffer(data, dtype=dtype).reshape(shape).copy()
return torch.from_numpy(array)
if objtype == "Vector":
return Vector.build(obj["value"])
logger.warning(
"Encountered an object that cannot be properly deserialized."
)
Expand Down
34 changes: 32 additions & 2 deletions tests/test_serializer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,6 +11,8 @@

import numpy as np
import torch
from declearn.model.sklearn import NumpyVector
from declearn.model.torch import TorchVector

from fedbiomed.common.exceptions import FedbiomedTypeError
from fedbiomed.common.logger import logger
Expand Down Expand Up @@ -102,7 +104,23 @@ def test_serializer_05_file_dump_load(self) -> None:
all(np.all(a == b) for a, b in zip(data["arrays"], datb["arrays"]))
)

def test_serializer_06_raises_dump_error(self) -> None:
def test_serializer_06_numpy_vector(self) -> None:
    """Exercise 'Serializer' on a 'NumpyVector' wrapping random arrays."""
    # Two named coefficients of different ranks (2-d and 1-d), so that
    # more than a single array shape is covered.
    coefs = {
        "a": np.random.normal(size=(32, 128)),
        "b": np.random.normal(size=(32,)),
    }
    # NOTE(review): 'assert_serializable' is defined outside this view;
    # presumably it checks a dump/load round trip - confirm.
    self.assert_serializable(NumpyVector(coefs))

def test_serializer_07_torch_vector(self) -> None:
    """Exercise 'Serializer' on a 'TorchVector' wrapping random tensors."""
    # Two named coefficients of different ranks (2-d and 1-d), so that
    # more than a single tensor shape is covered.
    coefs = {
        "a": torch.randn(size=(32, 128)),
        "b": torch.randn(size=(32,)),
    }
    # NOTE(review): 'assert_serializable' is defined outside this view;
    # presumably it checks a dump/load round trip - confirm.
    self.assert_serializable(TorchVector(coefs))

def test_serializer_08_raises_dump_error(self) -> None:
"""Test that 'Serializer.dumps' raises the expected error."""

class UnsupportedType:
Expand All @@ -111,7 +129,7 @@ class UnsupportedType:
with self.assertRaises(FedbiomedTypeError):
Serializer.dumps(UnsupportedType())

def test_serializer_07_warns_load_error(self) -> None:
def test_serializer_09_warns_load_error(self) -> None:
"""Test that 'Serializer.loads' logs the expected warning."""
# Build a dict that looks like the specification for a non-standard
# type dump (e.g. numpy array, torch tensor...).
Expand All @@ -123,6 +141,18 @@ def test_serializer_07_warns_load_error(self) -> None:
p_logger.warning.assert_called_once()
self.assertDictEqual(obj, bis)

def test_serializer_10_nested_dict(self) -> None:
    """Exercise 'Serializer' on a dict that nests vectors at two depths."""
    # Build the vectors first: one sits at the top level of the
    # structure, the other inside a nested dict.
    torch_vec = TorchVector({"a": torch.randn(size=(4, 8))})
    numpy_vec = NumpyVector({"a": np.random.normal(size=(4, 8))})
    # NOTE(review): the "int" entry actually holds a float (0.) - confirm
    # whether an integer value was intended here.
    obj = {
        "int": 0.,
        "vec": torch_vec,
        "dct": {"vec": numpy_vec, "str": "test"},
    }
    self.assert_serializable(obj)


# Run the unit tests when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit 0c46f4b

Please sign in to comment.