Skip to content

Commit

Permalink
Fix publish() a bytearray payload
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreF committed Apr 29, 2024
1 parent 29c1d43 commit 7795dcd
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def _force_bytes(s: str | bytes) -> bytes:
return s


def _encode_payload(payload: str | bytes | bytearray | int | float | None) -> bytes:
def _encode_payload(payload: str | bytes | bytearray | int | float | None) -> bytes|bytearray:
if isinstance(payload, str):
return payload.encode("utf-8")

Expand Down Expand Up @@ -3368,7 +3368,7 @@ def _send_publish(
self,
mid: int,
topic: bytes,
payload: bytes = b"",
payload: bytes|bytearray = b"",
qos: int = 0,
retain: bool = False,
dup: bool = False,
Expand All @@ -3378,7 +3378,7 @@ def _send_publish(
# we assume that topic and payload are already properly encoded
if not isinstance(topic, bytes):
raise TypeError('topic must be bytes, not str')
if payload and not isinstance(payload, bytes):
if payload and not isinstance(payload, (bytes, bytearray)):
raise TypeError('payload must be bytes if set')

if self._sock is None:
Expand Down
3 changes: 2 additions & 1 deletion tests/paho_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_
pack_format = pack_format + "%ds"%(len(properties))

if payload is not None:
payload = payload.encode("utf-8")
if isinstance(payload, str):
payload = payload.encode("utf-8")
rl = rl + len(payload)
pack_format = pack_format + str(len(payload)) + "s"
else:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,59 @@ def on_connect(mqttc, obj, flags, rc):
packet_in = fake_broker.receive_packet(1)
assert not packet_in # Check connection is closed

@pytest.mark.parametrize("user_payload,sent_payload", [
("string", b"string"),
(b"byte", b"byte"),
(bytearray(b"bytearray"), b"bytearray"),
(42, b"42"),
(4.2, b"4.2"),
(None, b""),
])
def test_publish_various_payload(self, user_payload: client.PayloadType, sent_payload: bytes, fake_broker: FakeBroker) -> None:
mqttc = client.Client(
CallbackAPIVersion.VERSION2,
"test_publish_various_payload",
transport=fake_broker.transport,
)

mqttc.connect("localhost", fake_broker.port)
mqttc.loop_start()
mqttc.enable_logger()

try:
fake_broker.start()

connect_packet = paho_test.gen_connect(
"test_publish_various_payload", keepalive=60,
proto_ver=client.MQTTv311)
fake_broker.expect_packet("connect", connect_packet)

connack_packet = paho_test.gen_connack(rc=0)
count = fake_broker.send_packet(connack_packet)
assert count # Check connection was not closed
assert count == len(connack_packet)

mqttc.publish("test", user_payload)

publish_packet = paho_test.gen_publish(
b"test", payload=sent_payload, qos=0
)
fake_broker.expect_packet("publish", publish_packet)

mqttc.disconnect()

disconnect_packet = paho_test.gen_disconnect()
packet_in = fake_broker.receive_packet(1000)
assert packet_in # Check connection was not closed
assert packet_in == disconnect_packet

finally:
mqttc.loop_stop()

packet_in = fake_broker.receive_packet(1)
assert not packet_in # Check connection is closed


@pytest.mark.parametrize("callback_version", [
(CallbackAPIVersion.VERSION1),
(CallbackAPIVersion.VERSION2),
Expand Down

0 comments on commit 7795dcd

Please sign in to comment.