diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 09f7662..5028dcd 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -461,6 +461,21 @@ def connect( raise MMQTTException(exc_msg) from last_exception raise MMQTTException(exc_msg) + def _send_bytes( + self, + buffer: Union[bytes, bytearray, memoryview], + ): + bytes_sent: int = 0 + bytes_to_send = len(buffer) + view = memoryview(buffer) + while bytes_sent < bytes_to_send: + try: + bytes_sent += self._sock.send(view[bytes_sent:]) + except OSError as exc: + if exc.errno == EAGAIN: + continue + raise + def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements self, clean_session: bool = True, @@ -529,8 +544,8 @@ def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements self.logger.debug("Sending CONNECT to broker...") self.logger.debug(f"Fixed Header: {fixed_header}") self.logger.debug(f"Variable Header: {var_header}") - self._sock.send(fixed_header) - self._sock.send(var_header) + self._send_bytes(fixed_header) + self._send_bytes(var_header) # [MQTT-3.1.3-4] self._send_str(self.client_id) if self._lw_topic: @@ -591,7 +606,7 @@ def disconnect(self) -> None: self._connected() self.logger.debug("Sending DISCONNECT packet to broker") try: - self._sock.send(MQTT_DISCONNECT) + self._send_bytes(MQTT_DISCONNECT) except (MemoryError, OSError, RuntimeError) as e: self.logger.warning(f"Unable to send DISCONNECT packet: {e}") self._close_socket() @@ -608,7 +623,7 @@ def ping(self) -> list[int]: """ self._connected() self.logger.debug("Sending PINGREQ") - self._sock.send(MQTT_PINGREQ) + self._send_bytes(MQTT_PINGREQ) ping_timeout = self.keep_alive stamp = ticks_ms() @@ -683,9 +698,9 @@ def publish( # noqa: PLR0912, Too many branches qos, retain, ) - self._sock.send(pub_hdr_fixed) - self._sock.send(pub_hdr_var) - self._sock.send(msg) + self._send_bytes(pub_hdr_fixed) + self._send_bytes(pub_hdr_var) + self._send_bytes(msg) self._last_msg_sent_timestamp = ticks_ms() if qos == 0 and self.on_publish is not None: self.on_publish(self, self.user_data, topic, self._pid) @@ -749,12 +764,12 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics) self._encode_remaining_length(fixed_header, remaining_length=packet_length) self.logger.debug(f"Fixed Header: {fixed_header}") - self._sock.send(fixed_header) + self._send_bytes(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 packet_id_bytes = self._pid.to_bytes(2, "big") var_header = packet_id_bytes self.logger.debug(f"Variable Header: {var_header}") - self._sock.send(var_header) + self._send_bytes(var_header) # attaching topic and QOS level to the packet payload = b"" for t, q in topics: @@ -764,7 +779,7 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements for t, q in topics: self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}") self.logger.debug(f"payload: {payload}") - self._sock.send(payload) + self._send_bytes(payload) stamp = ticks_ms() self._last_msg_sent_timestamp = stamp while True: @@ -829,19 +844,19 @@ def unsubscribe( # noqa: PLR0912, Too many branches packet_length += sum(len(topic.encode("utf-8")) for topic in topics) self._encode_remaining_length(fixed_header, remaining_length=packet_length) self.logger.debug(f"Fixed Header: {fixed_header}") - self._sock.send(fixed_header) + self._send_bytes(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 packet_id_bytes = self._pid.to_bytes(2, "big") var_header = packet_id_bytes self.logger.debug(f"Variable Header: {var_header}") - self._sock.send(var_header) + self._send_bytes(var_header) payload = b"" for t in topics: topic_size = len(t.encode("utf-8")).to_bytes(2, "big") payload += topic_size + t.encode() for t in topics: self.logger.debug(f"UNSUBSCRIBING from topic {t}") - self._sock.send(payload) + self._send_bytes(payload) self._last_msg_sent_timestamp = ticks_ms() self.logger.debug("Waiting for UNSUBACK...") while True: @@ -1028,7 +1043,7 @@ def _wait_for_msg( # noqa: PLR0912, Too many branches if res[0] & 0x06 == 0x02: pkt = bytearray(b"\x40\x02\0\0") struct.pack_into("!H", pkt, 2, pid) - self._sock.send(pkt) + self._send_bytes(pkt) elif res[0] & 6 == 4: assert 0 @@ -1109,11 +1124,11 @@ def _send_str(self, string: str) -> None: """ if isinstance(string, str): - self._sock.send(struct.pack("!H", len(string.encode("utf-8")))) - self._sock.send(str.encode(string, "utf-8")) + self._send_bytes(struct.pack("!H", len(string.encode("utf-8")))) + self._send_bytes(str.encode(string, "utf-8")) else: - self._sock.send(struct.pack("!H", len(string))) - self._sock.send(string) + self._send_bytes(struct.pack("!H", len(string))) + self._send_bytes(string) @staticmethod def _valid_topic(topic: str) -> None: diff --git a/tests/test_recv_timeout.py b/tests/test_recv_timeout.py index 1855525..099a504 100644 --- a/tests/test_recv_timeout.py +++ b/tests/test_recv_timeout.py @@ -9,6 +9,8 @@ from unittest import TestCase, main from unittest.mock import Mock +from mocket import Mocket + import adafruit_minimqtt.adafruit_minimqtt as MQTT @@ -34,7 +36,7 @@ def test_recv_timeout_vs_keepalive(self) -> None: ) # Create a mock socket that will accept anything and return nothing. - socket_mock = Mock() + socket_mock = Mocket(b"") socket_mock.recv_into = Mock(side_effect=side_effect) mqtt_client._sock = socket_mock @@ -43,12 +45,8 @@ def test_recv_timeout_vs_keepalive(self) -> None: with self.assertRaises(MQTT.MMQTTException): mqtt_client.ping() - # Verify the mock interactions. - socket_mock.send.assert_called_once() - socket_mock.recv_into.assert_called() - now = time.monotonic() - assert recv_timeout <= (now - start) <= (keep_alive + 0.1) + assert recv_timeout <= (now - start) <= (keep_alive + 0.2) if __name__ == "__main__":