Skip to content

Commit

Permalink
Work on sparse matrix serialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-wilkins committed Oct 23, 2024
1 parent d20e9f3 commit ebac04c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
38 changes: 35 additions & 3 deletions sasdata/quantities/numerical_encoding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix, csc_matrix, coo_array, csr_array, csc_array

import base64
import struct


def numerical_encode(obj: int | float | np.ndarray):
def numerical_encode(obj: int | float | np.ndarray | coo_matrix | coo_array | csr_matrix | csr_array | csc_matrix | csc_array):

if isinstance(obj, int):
return {"type": "int",
Expand All @@ -22,11 +23,38 @@ def numerical_encode(obj: int | float | np.ndarray):
"shape": list(obj.shape)
}

elif isinstance(obj, (coo_matrix, coo_array, csr_matrix, csr_array, csc_matrix, csc_array)):

output = {
"type": obj.__class__.__name__, # not robust to name changes, but more concise
"dtype": obj.dtype.str,
"shape": list(obj.shape)
}

if isinstance(obj, (coo_array, coo_matrix)):

output["data"] = numerical_encode(obj.data)
output["coords"] = [numerical_encode(coord) for coord in obj.coords]


elif isinstance(obj, (csr_array, csr_matrix)):
pass


elif isinstance(obj, (csc_array, csc_matrix)):

pass


return output

else:
raise TypeError(f"Cannot serialise object of type: {type(obj)}")

def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np.ndarray:
match data["type"]:
def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np.ndarray | coo_matrix | coo_array | csr_matrix | csr_array | csc_matrix | csc_array:
obj_type = data["type"]

match obj_type:
case "int":
return int(data["value"])

Expand All @@ -38,3 +66,7 @@ def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np
dtype = np.dtype(data["dtype"])
shape = tuple(data["shape"])
return np.frombuffer(value, dtype=dtype).reshape(*shape)

case _:
raise ValueError(f"Cannot decode objects of type '{obj_type}'")

9 changes: 9 additions & 0 deletions sasdata/quantities/test_numerical_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ def test_numpy_dtypes_encode_decode(dtype):
decoded = numerical_decode(encoded)

assert decoded.dtype == test_matrix.dtype

@pytest.mark.parametrize("dtype", [int, float, complex])
@pytest.mark.parametrize("shape, n, m", [
((8, 8), (1,3,5),(2,5,7)),
((6, 8), (1,0,5),(0,5,0)),
((6, 1), (1, 0, 5), (0, 0, 0)),
])
def test_coo_matrix_encode_decode(shape, n, m, dtype):
test_matrix = np.arange()

0 comments on commit ebac04c

Please sign in to comment.