Skip to content

Commit

Permalink
Handle numpy arrays when serializing (#579)
Browse files Browse the repository at this point in the history
Fixes #578
  • Loading branch information
DominicOram authored Jul 31, 2024
1 parent ca718d8 commit 8733c72
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/blueapi/messaging/stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from threading import Event
from typing import Any

import orjson
import stomp
from pydantic import parse_obj_as
from stomp.exception import ConnectFailedException
Expand Down Expand Up @@ -122,17 +123,20 @@ def send(
correlation_id: str | None = None,
) -> None:
self._send_str(
destination, json.dumps(serialize(obj)), on_reply, correlation_id
destination,
orjson.dumps(serialize(obj), option=orjson.OPT_SERIALIZE_NUMPY),
on_reply,
correlation_id,
)

def _send_str(
self,
destination: str,
message: str,
message: bytes,
on_reply: MessageListener | None = None,
correlation_id: str | None = None,
) -> None:
LOGGER.info(f"SENDING {message} to {destination}")
LOGGER.info(f"SENDING {message!r} to {destination}")

headers: dict[str, Any] = {"JMSType": "TextMessage"}
if on_reply is not None:
Expand Down
10 changes: 9 additions & 1 deletion tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from unittest.mock import ANY, MagicMock, call, patch

import numpy as np
import pytest
from pydantic import BaseModel, BaseSettings, Field
from stomp import Connection
Expand Down Expand Up @@ -148,7 +149,12 @@ class Foo(BaseModel):
@pytest.mark.stomp
@pytest.mark.parametrize(
"message,message_type",
[("test", str), (1, int), (Foo(a=1, b="test"), Foo)],
[
("test", str),
(1, int),
(Foo(a=1, b="test"), Foo),
(np.array([1, 2, 3]), list),
],
)
def test_deserialization(
template: MessagingTemplate, test_queue: str, message: Any, message_type: type
Expand All @@ -163,6 +169,8 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply = template.send_and_receive(test_queue, message, message_type).result(
timeout=_TIMEOUT
)
if type(message) == np.ndarray:
message = message.tolist()
assert reply == message


Expand Down

0 comments on commit 8733c72

Please sign in to comment.