Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle partial socket send()'s #231

Merged
merged 3 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,21 @@ def connect(
raise MMQTTException(exc_msg) from last_exception
raise MMQTTException(exc_msg)

def _send_bytes(
self,
buffer: Union[bytes, bytearray, memoryview],
):
bytes_sent: int = 0
bytes_to_send = len(buffer)
view = memoryview(buffer)
while bytes_sent < bytes_to_send:
try:
bytes_sent += self._sock.send(view[bytes_sent:])
except OSError as exc:
if exc.errno == EAGAIN:
continue
raise

def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
self,
clean_session: bool = True,
Expand Down Expand Up @@ -529,8 +544,8 @@ def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
self.logger.debug("Sending CONNECT to broker...")
self.logger.debug(f"Fixed Header: {fixed_header}")
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(fixed_header)
self._sock.send(var_header)
self._send_bytes(fixed_header)
self._send_bytes(var_header)
# [MQTT-3.1.3-4]
self._send_str(self.client_id)
if self._lw_topic:
Expand Down Expand Up @@ -591,7 +606,7 @@ def disconnect(self) -> None:
self._connected()
self.logger.debug("Sending DISCONNECT packet to broker")
try:
self._sock.send(MQTT_DISCONNECT)
self._send_bytes(MQTT_DISCONNECT)
except (MemoryError, OSError, RuntimeError) as e:
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
self._close_socket()
Expand All @@ -608,7 +623,7 @@ def ping(self) -> list[int]:
"""
self._connected()
self.logger.debug("Sending PINGREQ")
self._sock.send(MQTT_PINGREQ)
self._send_bytes(MQTT_PINGREQ)
ping_timeout = self.keep_alive
stamp = ticks_ms()

Expand Down Expand Up @@ -683,9 +698,9 @@ def publish( # noqa: PLR0912, Too many branches
qos,
retain,
)
self._sock.send(pub_hdr_fixed)
self._sock.send(pub_hdr_var)
self._sock.send(msg)
self._send_bytes(pub_hdr_fixed)
self._send_bytes(pub_hdr_var)
self._send_bytes(msg)
self._last_msg_sent_timestamp = ticks_ms()
if qos == 0 and self.on_publish is not None:
self.on_publish(self, self.user_data, topic, self._pid)
Expand Down Expand Up @@ -749,12 +764,12 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
self.logger.debug(f"Fixed Header: {fixed_header}")
self._sock.send(fixed_header)
self._send_bytes(fixed_header)
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
var_header = packet_id_bytes
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(var_header)
self._send_bytes(var_header)
# attaching topic and QOS level to the packet
payload = b""
for t, q in topics:
Expand All @@ -764,7 +779,7 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
for t, q in topics:
self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}")
self.logger.debug(f"payload: {payload}")
self._sock.send(payload)
self._send_bytes(payload)
stamp = ticks_ms()
self._last_msg_sent_timestamp = stamp
while True:
Expand Down Expand Up @@ -829,19 +844,19 @@ def unsubscribe( # noqa: PLR0912, Too many branches
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
self.logger.debug(f"Fixed Header: {fixed_header}")
self._sock.send(fixed_header)
self._send_bytes(fixed_header)
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
var_header = packet_id_bytes
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(var_header)
self._send_bytes(var_header)
payload = b""
for t in topics:
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
payload += topic_size + t.encode()
for t in topics:
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
self._sock.send(payload)
self._send_bytes(payload)
self._last_msg_sent_timestamp = ticks_ms()
self.logger.debug("Waiting for UNSUBACK...")
while True:
Expand Down Expand Up @@ -1028,7 +1043,7 @@ def _wait_for_msg( # noqa: PLR0912, Too many branches
if res[0] & 0x06 == 0x02:
pkt = bytearray(b"\x40\x02\0\0")
struct.pack_into("!H", pkt, 2, pid)
self._sock.send(pkt)
self._send_bytes(pkt)
elif res[0] & 6 == 4:
assert 0

Expand Down Expand Up @@ -1109,11 +1124,11 @@ def _send_str(self, string: str) -> None:

"""
if isinstance(string, str):
self._sock.send(struct.pack("!H", len(string.encode("utf-8"))))
self._sock.send(str.encode(string, "utf-8"))
self._send_bytes(struct.pack("!H", len(string.encode("utf-8"))))
self._send_bytes(str.encode(string, "utf-8"))
else:
self._sock.send(struct.pack("!H", len(string)))
self._sock.send(string)
self._send_bytes(struct.pack("!H", len(string)))
self._send_bytes(string)

@staticmethod
def _valid_topic(topic: str) -> None:
Expand Down
10 changes: 4 additions & 6 deletions tests/test_recv_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from unittest import TestCase, main
from unittest.mock import Mock

from mocket import Mocket

import adafruit_minimqtt.adafruit_minimqtt as MQTT


Expand All @@ -34,7 +36,7 @@ def test_recv_timeout_vs_keepalive(self) -> None:
)

# Create a mock socket that will accept anything and return nothing.
socket_mock = Mock()
socket_mock = Mocket(b"")
socket_mock.recv_into = Mock(side_effect=side_effect)
mqtt_client._sock = socket_mock

Expand All @@ -43,12 +45,8 @@ def test_recv_timeout_vs_keepalive(self) -> None:
with self.assertRaises(MQTT.MMQTTException):
mqtt_client.ping()

# Verify the mock interactions.
socket_mock.send.assert_called_once()
socket_mock.recv_into.assert_called()

now = time.monotonic()
assert recv_timeout <= (now - start) <= (keep_alive + 0.1)
assert recv_timeout <= (now - start) <= (keep_alive + 0.2)


if __name__ == "__main__":
Expand Down
Loading