From f56577cbdf59aec11045fc98bc28b50dd1ed4c24 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Mon, 20 Nov 2023 23:30:13 +0100 Subject: [PATCH 01/16] encode remaining length properly for SUBSCRIBE fixes #160 --- adafruit_minimqtt/adafruit_minimqtt.py | 51 +++++++++++++++++--------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 41e0160..d886944 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -60,7 +60,7 @@ MQTT_PINGREQ = b"\xc0\0" MQTT_PINGRESP = const(0xD0) MQTT_PUBLISH = const(0x30) -MQTT_SUB = b"\x82" +MQTT_SUB = const(0x82) MQTT_UNSUB = b"\xA2" MQTT_DISCONNECT = b"\xe0\0" @@ -626,18 +626,7 @@ def _connect( var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 var_header[6] |= self._lw_retain << 5 - # Remaining length calculation - large_rel_length = False - if remaining_length > 0x7F: - large_rel_length = True - # Calculate Remaining Length [2.2.3] - while remaining_length > 0: - encoded_byte = remaining_length % 0x80 - remaining_length = remaining_length // 0x80 - # if there is more data to encode, set the top bit of the byte - if remaining_length > 0: - encoded_byte |= 0x80 - fixed_header.append(encoded_byte) + large_rel_length = self.encode_remaining_length(fixed_header, remaining_length) if large_rel_length: fixed_header.append(0x00) else: @@ -680,6 +669,25 @@ def _connect( f"No data received from broker for {self._recv_timeout} seconds." ) + # pylint: disable=no-self-use + def encode_remaining_length(self, fixed_header, remaining_length): + """ + Encode Remaining Length [2.2.3] + """ + # Remaining length calculation + large_rel_length = False + if remaining_length > 0x7F: + large_rel_length = True + while remaining_length > 0: + encoded_byte = remaining_length % 0x80 + remaining_length = remaining_length // 0x80 + # if there is more data to encode, set the top bit of the byte + if remaining_length > 0: + encoded_byte |= 0x80 + fixed_header.append(encoded_byte) + + return large_rel_length + def disconnect(self) -> None: """Disconnects the MiniMQTT client from the MQTT broker.""" self._connected() @@ -812,7 +820,7 @@ def publish( def subscribe(self, topic: str, qos: int = 0) -> None: """Subscribes to a topic on the MQTT Broker. - This method can subscribe to one topics or multiple topics. + This method can subscribe to one topic or multiple topics. :param str|tuple|list topic: Unique MQTT topic identifier string. If this is a `tuple`, then the tuple should @@ -842,20 +850,27 @@ def subscribe(self, topic: str, qos: int = 0) -> None: self._valid_topic(t) topics.append((t, q)) # Assemble packet + self.logger.debug("Sending SUBSCRIBE to broker...") + fixed_header = bytearray([MQTT_SUB]) packet_length = 2 + (2 * len(topics)) + (1 * len(topics)) packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics) - packet_length_byte = packet_length.to_bytes(1, "big") + self.encode_remaining_length(fixed_header, remaining_length=packet_length) + self.logger.debug(f"Fixed Header: {fixed_header}") + self._sock.send(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 packet_id_bytes = self._pid.to_bytes(2, "big") - # Packet with variable and fixed headers - packet = MQTT_SUB + packet_length_byte + packet_id_bytes + var_header = packet_id_bytes + self.logger.debug(f"Variable Header: {var_header}") + self._sock.send(var_header) # attaching topic and QOS level to the packet + packet = bytes() for t, q in topics: topic_size = len(t.encode("utf-8")).to_bytes(2, "big") qos_byte = q.to_bytes(1, "big") packet += topic_size + t.encode() + qos_byte for t, q in topics: self.logger.debug("SUBSCRIBING to topic %s with QoS %d", t, q) + self.logger.debug(f"packet: {packet}") self._sock.send(packet) stamp = self.get_monotonic_time() while True: @@ -869,7 +884,7 @@ def subscribe(self, topic: str, qos: int = 0) -> None: if op == 0x90: rc = self._sock_exact_recv(3) # Check packet identifier. - assert rc[1] == packet[2] and rc[2] == packet[3] + assert rc[1] == var_header[0] and rc[2] == var_header[1] remaining_len = rc[0] - 2 assert remaining_len > 0 rc = self._sock_exact_recv(remaining_len) From 3ce387e05a26f4941fc07ac5877652854a776c2e Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Tue, 28 Nov 2023 22:07:17 +0100 Subject: [PATCH 02/16] use f-string for logging it seems the previous code does not properly work with Adafruit logging. This should fix it. --- adafruit_minimqtt/adafruit_minimqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index d886944..4355363 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -869,7 +869,7 @@ def subscribe(self, topic: str, qos: int = 0) -> None: qos_byte = q.to_bytes(1, "big") packet += topic_size + t.encode() + qos_byte for t, q in topics: - self.logger.debug("SUBSCRIBING to topic %s with QoS %d", t, q) + self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}") self.logger.debug(f"packet: {packet}") self._sock.send(packet) stamp = self.get_monotonic_time() From 1dc406b503ad0a80f7d5d636fd687b146fbaeed0 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Tue, 28 Nov 2023 22:31:04 +0100 Subject: [PATCH 03/16] fix short remaining length encoding --- adafruit_minimqtt/adafruit_minimqtt.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 4355363..00db309 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -626,13 +626,8 @@ def _connect( var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 var_header[6] |= self._lw_retain << 5 - large_rel_length = self.encode_remaining_length(fixed_header, remaining_length) - if large_rel_length: - fixed_header.append(0x00) - else: - fixed_header.append(remaining_length) - fixed_header.append(0x00) - + self.encode_remaining_length(fixed_header, remaining_length) + fixed_header.append(0x00) self.logger.debug("Sending CONNECT to broker...") self.logger.debug(f"Fixed Header: {fixed_header}") self.logger.debug(f"Variable Header: {var_header}") @@ -670,14 +665,12 @@ def _connect( ) # pylint: disable=no-self-use - def encode_remaining_length(self, fixed_header, remaining_length): + def encode_remaining_length(self, fixed_header: bytearray, remaining_length: int): """ Encode Remaining Length [2.2.3] """ # Remaining length calculation - large_rel_length = False if remaining_length > 0x7F: - large_rel_length = True while remaining_length > 0: encoded_byte = remaining_length % 0x80 remaining_length = remaining_length // 0x80 @@ -685,8 +678,8 @@ def encode_remaining_length(self, fixed_header, remaining_length): if remaining_length > 0: encoded_byte |= 0x80 fixed_header.append(encoded_byte) - - return large_rel_length + else: + fixed_header.append(remaining_length) def disconnect(self) -> None: """Disconnects the MiniMQTT client from the MQTT broker.""" From 008ad191fc2bb2f813a00adebd001b46ad77c5de Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Tue, 28 Nov 2023 22:31:22 +0100 Subject: [PATCH 04/16] rename the variable to match the purpose --- adafruit_minimqtt/adafruit_minimqtt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 00db309..a5f6e6a 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -856,15 +856,15 @@ def subscribe(self, topic: str, qos: int = 0) -> None: self.logger.debug(f"Variable Header: {var_header}") self._sock.send(var_header) # attaching topic and QOS level to the packet - packet = bytes() + payload = bytes() for t, q in topics: topic_size = len(t.encode("utf-8")).to_bytes(2, "big") qos_byte = q.to_bytes(1, "big") - packet += topic_size + t.encode() + qos_byte + payload += topic_size + t.encode() + qos_byte for t, q in topics: self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}") - self.logger.debug(f"packet: {packet}") - self._sock.send(packet) + self.logger.debug(f"payload: {payload}") + self._sock.send(payload) stamp = self.get_monotonic_time() while True: op = self._wait_for_msg() From 973df686b3a052ef463d752177e10e0286cc9acb Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Tue, 28 Nov 2023 22:39:06 +0100 Subject: [PATCH 05/16] the zero byte belongs to the variable header --- adafruit_minimqtt/adafruit_minimqtt.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index a5f6e6a..4f1e5f4 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -597,13 +597,12 @@ def _connect( self.broker, self.port, timeout=self._socket_timeout ) - # Fixed Header fixed_header = bytearray([0x10]) # Variable CONNECT header [MQTT 3.1.2] # The byte array is used as a template. - var_header = bytearray(b"\x04MQTT\x04\x02\0\0") - var_header[6] = clean_session << 1 + var_header = bytearray(b"\x00\x04MQTT\x04\x02\0\0") + var_header[7] = clean_session << 1 # Set up variable header and remaining_length remaining_length = 12 + len(self.client_id.encode("utf-8")) @@ -614,20 +613,19 @@ def _connect( + 2 + len(self._password.encode("utf-8")) ) - var_header[6] |= 0xC0 + var_header[7] |= 0xC0 if self.keep_alive: assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT - var_header[7] |= self.keep_alive >> 8 - var_header[8] |= self.keep_alive & 0x00FF + var_header[8] |= self.keep_alive >> 8 + var_header[9] |= self.keep_alive & 0x00FF if self._lw_topic: remaining_length += ( 2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg) ) - var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 - var_header[6] |= self._lw_retain << 5 + var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 + var_header[7] |= self._lw_retain << 5 self.encode_remaining_length(fixed_header, remaining_length) - fixed_header.append(0x00) self.logger.debug("Sending CONNECT to broker...") self.logger.debug(f"Fixed Header: {fixed_header}") self.logger.debug(f"Variable Header: {var_header}") From 7aacbfe048f595e091e9fee9eca9a674bf13c9d9 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Tue, 28 Nov 2023 22:43:58 +0100 Subject: [PATCH 06/16] deduplicate remaining length encoding for PUBLISH packet --- adafruit_minimqtt/adafruit_minimqtt.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 4f1e5f4..24d93f9 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -765,16 +765,7 @@ def publish( pub_hdr_var.append(self._pid >> 8) pub_hdr_var.append(self._pid & 0xFF) - # Calculate remaining length [2.2.3] - if remaining_length > 0x7F: - while remaining_length > 0: - encoded_byte = remaining_length % 0x80 - remaining_length = remaining_length // 0x80 - if remaining_length > 0: - encoded_byte |= 0x80 - pub_hdr_fixed.append(encoded_byte) - else: - pub_hdr_fixed.append(remaining_length) + self.encode_remaining_length(pub_hdr_fixed, remaining_length) self.logger.debug( "Sending PUBLISH\nTopic: %s\nMsg: %s\ From 163c7305c76145debede17ec9960547480d46f35 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Thu, 30 Nov 2023 23:06:36 +0100 Subject: [PATCH 07/16] PUBLISH can arrive before SUBACK fixes #192 --- adafruit_minimqtt/adafruit_minimqtt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 24d93f9..05fc199 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -880,11 +880,15 @@ def subscribe(self, topic: str, qos: int = 0) -> None: if self.on_subscribe is not None: self.on_subscribe(self, self.user_data, t, q) self._subscribed_topics.append(t) + return - raise MMQTTException( - f"invalid message received as response to SUBSCRIBE: {hex(op)}" - ) + if op != MQTT_PUBLISH: + # [3.8.4] The Server is permitted to start sending PUBLISH packets + # matching the Subscription before the Server sends the SUBACK Packet. + raise MMQTTException( + f"invalid message received as response to SUBSCRIBE: {hex(op)}" + ) def unsubscribe(self, topic: str) -> None: """Unsubscribes from a MQTT topic. From 62f66f2c58a01c451da3708705595ddedc41fa11 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Thu, 30 Nov 2023 23:40:30 +0100 Subject: [PATCH 08/16] UNSUBSCRIBE needs to encode remaining length correctly too --- adafruit_minimqtt/adafruit_minimqtt.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 05fc199..ebe43e6 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -61,7 +61,7 @@ MQTT_PINGRESP = const(0xD0) MQTT_PUBLISH = const(0x30) MQTT_SUB = const(0x82) -MQTT_UNSUB = b"\xA2" +MQTT_UNSUB = const(0xA2) MQTT_DISCONNECT = b"\xe0\0" MQTT_PKT_TYPE_MASK = const(0xF0) @@ -911,18 +911,25 @@ def unsubscribe(self, topic: str) -> None: "Topic must be subscribed to before attempting unsubscribe." ) # Assemble packet + self.logger.debug("Sending UNSUBSCRIBE to broker...") + fixed_header = bytearray([MQTT_UNSUB]) packet_length = 2 + (2 * len(topics)) packet_length += sum(len(topic.encode("utf-8")) for topic in topics) - packet_length_byte = packet_length.to_bytes(1, "big") + self.encode_remaining_length(fixed_header, remaining_length=packet_length) + self.logger.debug(f"Fixed Header: {fixed_header}") + self._sock.send(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 packet_id_bytes = self._pid.to_bytes(2, "big") - packet = MQTT_UNSUB + packet_length_byte + packet_id_bytes + var_header = packet_id_bytes + self.logger.debug(f"Variable Header: {var_header}") + self._sock.send(var_header) + payload = bytes() for t in topics: topic_size = len(t.encode("utf-8")).to_bytes(2, "big") - packet += topic_size + t.encode() + payload += topic_size + t.encode() for t in topics: - self.logger.debug("UNSUBSCRIBING from topic %s", t) - self._sock.send(packet) + self.logger.debug(f"UNSUBSCRIBING from topic {t}") + self._sock.send(payload) self.logger.debug("Waiting for UNSUBACK...") while True: stamp = self.get_monotonic_time() From d65b797d7f20a00f5e9cc14d295064c61dd3faf3 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Wed, 13 Dec 2023 23:19:24 +0100 Subject: [PATCH 09/16] improve type hints for subscribe() --- adafruit_minimqtt/adafruit_minimqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index ebe43e6..6bc8030 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -800,7 +800,7 @@ def publish( f"No data received from broker for {self._recv_timeout} seconds." ) - def subscribe(self, topic: str, qos: int = 0) -> None: + def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> None: """Subscribes to a topic on the MQTT Broker. This method can subscribe to one topic or multiple topics. From 2cf2f28afec654319303ff15b5032680502f7898 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Wed, 13 Dec 2023 23:19:43 +0100 Subject: [PATCH 10/16] add protocol level test for SUBSCRIBE --- tests/test_subscribe.py | 148 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/test_subscribe.py diff --git a/tests/test_subscribe.py b/tests/test_subscribe.py new file mode 100644 index 0000000..4455aad --- /dev/null +++ b/tests/test_subscribe.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""subscribe tests""" + +import logging +import ssl +from unittest import mock + +import pytest + +import adafruit_minimqtt.adafruit_minimqtt as MQTT + + +class Mocket: + """ + Mock Socket tailored for MiniMQTT testing. Records sent data, + hands out pre-recorded reply. + + Inspired by the Mocket class from Adafruit_CircuitPython_Requests + """ + + def __init__(self, to_send): + self._to_send = to_send + + self.sent = bytearray() + + self.timeout = mock.Mock() + self.connect = mock.Mock() + self.close = mock.Mock() + + def send(self, bytes_to_send): + """merely record the bytes. return the length of this bytearray.""" + self.sent.extend(bytes_to_send) + return len(bytes_to_send) + + # MiniMQTT checks for the presence of "recv_into" and switches behavior based on that. + def recv_into(self, retbuf, bufsize): + """return data from internal buffer""" + size = min(bufsize, len(self._to_send)) + if size == 0: + return size + chop = self._to_send[0:size] + retbuf[0:] = chop + self._to_send = self._to_send[size:] + return size + + +# pylint: disable=unused-argument +def handle_subscribe(client, user_data, topic, qos): + """ + Record topics into user data. + """ + assert topic + assert qos == 0 + + user_data.append(topic) + + +# The MQTT packet contents below were captured using Mosquitto client+server. +testdata = [ + # short topic with remaining length encoded as single byte + ( + "foo/bar", + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), + bytearray( + [ + 0x82, # fixed header + 0x0C, # remaining length + 0x00, + 0x01, # message ID + 0x00, + 0x07, # topic length + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x00, # QoS + ] + ), + ), + # remaining length is encoded as 2 bytes due to long topic name. + ( + "f" + "o" * 257, + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), + bytearray( + [ + 0x82, # fixed header + 0x87, # remaining length + 0x02, + 0x00, # message ID + 0x01, + 0x01, # topic length + 0x02, + 0x66, # topic + ] + + [0x6F] * 257 + + [0x00] # QoS + ), + ), +] + + +@pytest.mark.parametrize( + "topic,to_send,exp_recv", testdata, ids=["short_topic", "long_topic"] +) +def test_subscribe(topic, to_send, exp_recv) -> None: + """ + Protocol level testing of SUBSCRIBE and SUBACK packet handling. + + Nothing will travel over the wire, it is all fake. + """ + logging.basicConfig() + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + host = "localhost" + port = 1883 + + subscribed_topics = [] + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + ssl_context=ssl.create_default_context(), + connect_retries=1, + user_data=subscribed_topics, + ) + + mqtt_client.on_subscribe = handle_subscribe + + # patch is_connected() to avoid CONNECT/CONNACK handling. + mqtt_client.is_connected = lambda: True + mocket = Mocket(to_send) + # pylint: disable=protected-access + mqtt_client._sock = mocket + + mqtt_client.logger = logger + + # pylint: disable=logging-fstring-interpolation + logger.info(f"subscribing to {topic}") + mqtt_client.subscribe(topic) + + assert topic in subscribed_topics + assert mocket.sent == exp_recv From f016d342470023e861db58336ca60ed07473ea3e Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Thu, 14 Dec 2023 20:45:59 +0100 Subject: [PATCH 11/16] augment type hints for unsubscribe() --- adafruit_minimqtt/adafruit_minimqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 6bc8030..f971126 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -890,7 +890,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N f"invalid message received as response to SUBSCRIBE: {hex(op)}" ) - def unsubscribe(self, topic: str) -> None: + def unsubscribe(self, topic: Optional[Union[str, list]]) -> None: """Unsubscribes from a MQTT topic. :param str|list topic: Unique MQTT topic identifier string or list. From 1bb729b452aaafdb8367acc83dab2b2c792bbde0 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Thu, 14 Dec 2023 20:49:09 +0100 Subject: [PATCH 12/16] refactor Mocket into separate module --- tests/mocket.py | 41 +++++++++++++++++++++++++++++++++++++++++ tests/test_subscribe.py | 36 +----------------------------------- 2 files changed, 42 insertions(+), 35 deletions(-) create mode 100644 tests/mocket.py diff --git a/tests/mocket.py b/tests/mocket.py new file mode 100644 index 0000000..31b4101 --- /dev/null +++ b/tests/mocket.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""fake socket class for protocol level testing""" + +from unittest import mock + + +class Mocket: + """ + Mock Socket tailored for MiniMQTT testing. Records sent data, + hands out pre-recorded reply. + + Inspired by the Mocket class from Adafruit_CircuitPython_Requests + """ + + def __init__(self, to_send): + self._to_send = to_send + + self.sent = bytearray() + + self.timeout = mock.Mock() + self.connect = mock.Mock() + self.close = mock.Mock() + + def send(self, bytes_to_send): + """merely record the bytes. return the length of this bytearray.""" + self.sent.extend(bytes_to_send) + return len(bytes_to_send) + + # MiniMQTT checks for the presence of "recv_into" and switches behavior based on that. + def recv_into(self, retbuf, bufsize): + """return data from internal buffer""" + size = min(bufsize, len(self._to_send)) + if size == 0: + return size + chop = self._to_send[0:size] + retbuf[0:] = chop + self._to_send = self._to_send[size:] + return size diff --git a/tests/test_subscribe.py b/tests/test_subscribe.py index 4455aad..413f10a 100644 --- a/tests/test_subscribe.py +++ b/tests/test_subscribe.py @@ -6,47 +6,13 @@ import logging import ssl -from unittest import mock import pytest +from mocket import Mocket import adafruit_minimqtt.adafruit_minimqtt as MQTT -class Mocket: - """ - Mock Socket tailored for MiniMQTT testing. Records sent data, - hands out pre-recorded reply. - - Inspired by the Mocket class from Adafruit_CircuitPython_Requests - """ - - def __init__(self, to_send): - self._to_send = to_send - - self.sent = bytearray() - - self.timeout = mock.Mock() - self.connect = mock.Mock() - self.close = mock.Mock() - - def send(self, bytes_to_send): - """merely record the bytes. return the length of this bytearray.""" - self.sent.extend(bytes_to_send) - return len(bytes_to_send) - - # MiniMQTT checks for the presence of "recv_into" and switches behavior based on that. - def recv_into(self, retbuf, bufsize): - """return data from internal buffer""" - size = min(bufsize, len(self._to_send)) - if size == 0: - return size - chop = self._to_send[0:size] - retbuf[0:] = chop - self._to_send = self._to_send[size:] - return size - - # pylint: disable=unused-argument def handle_subscribe(client, user_data, topic, qos): """ From 8d50e96a4bb9f846b5e283f72572c420c127aaee Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Thu, 14 Dec 2023 21:18:00 +0100 Subject: [PATCH 13/16] add protocol tests for UNSUBSCRIBE packet --- tests/test_unsubscribe.py | 117 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/test_unsubscribe.py diff --git a/tests/test_unsubscribe.py b/tests/test_unsubscribe.py new file mode 100644 index 0000000..04c456b --- /dev/null +++ b/tests/test_unsubscribe.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""unsubscribe tests""" + +import logging +import ssl + +import pytest +from mocket import Mocket + +import adafruit_minimqtt.adafruit_minimqtt as MQTT + + +# pylint: disable=unused-argument +def handle_unsubscribe(client, user_data, topic, pid): + """ + Record topics into user data. + """ + assert topic + + user_data.append(topic) + + +# The MQTT packet contents below were captured using Mosquitto client+server. +# These are verbatim, except message ID that was changed from 2 to 1 since in the real world +# capture the UNSUBSCRIBE packet followed the SUBSCRIBE packet. +testdata = [ + # short topic with remaining length encoded as single byte + ( + "foo/bar", + bytearray([0xB0, 0x02, 0x00, 0x01]), + bytearray( + [ + 0xA2, # fixed header + 0x0B, # remaining length + 0x00, # message ID + 0x01, + 0x00, # topic length + 0x07, + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + ] + ), + ), + # remaining length is encoded as 2 bytes due to long topic name. + ( + "f" + "o" * 257, + bytearray([0xB0, 0x02, 0x00, 0x01]), + bytearray( + [ + 0xA2, # fixed header + 0x86, # remaining length + 0x02, + 0x00, # message ID + 0x01, + 0x01, # topic length + 0x02, + 0x66, # topic + ] + + [0x6F] * 257 + ), + ), +] + + +@pytest.mark.parametrize( + "topic,to_send,exp_recv", testdata, ids=["short_topic", "long_topic"] +) +def test_unsubscribe(topic, to_send, exp_recv) -> None: + """ + Protocol level testing of UNSUBSCRIBE and UNSUBACK packet handling. + + Nothing will travel over the wire, it is all fake. + Also, the topics are not subscribed into. + """ + logging.basicConfig() + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + host = "localhost" + port = 1883 + + unsubscribed_topics = [] + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + ssl_context=ssl.create_default_context(), + connect_retries=1, + user_data=unsubscribed_topics, + ) + + mqtt_client.on_unsubscribe = handle_unsubscribe + + # patch is_connected() to avoid CONNECT/CONNACK handling. + mqtt_client.is_connected = lambda: True + mocket = Mocket(to_send) + # pylint: disable=protected-access + mqtt_client._sock = mocket + + mqtt_client.logger = logger + + # pylint: disable=protected-access + mqtt_client._subscribed_topics = [topic] + + # pylint: disable=logging-fstring-interpolation + logger.info(f"unsubscribing from {topic}") + mqtt_client.unsubscribe(topic) + + assert topic in unsubscribed_topics + assert mocket.sent == exp_recv From 4c64236162a668aff77184c0269fe384f55cb173 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Thu, 14 Dec 2023 21:53:33 +0100 Subject: [PATCH 14/16] add test for PUBLISH received first --- tests/test_subscribe.py | 53 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/tests/test_subscribe.py b/tests/test_subscribe.py index 413f10a..9b00419 100644 --- a/tests/test_subscribe.py +++ b/tests/test_subscribe.py @@ -29,7 +29,7 @@ def handle_subscribe(client, user_data, topic, qos): # short topic with remaining length encoded as single byte ( "foo/bar", - bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK bytearray( [ 0x82, # fixed header @@ -52,7 +52,7 @@ def handle_subscribe(client, user_data, topic, qos): # remaining length is encoded as 2 bytes due to long topic name. ( "f" + "o" * 257, - bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK bytearray( [ 0x82, # fixed header @@ -68,11 +68,58 @@ def handle_subscribe(client, user_data, topic, qos): + [0x00] # QoS ), ), + # SUBSCRIBE responded to by PUBLISH followed by SUBACK + ( + "foo/bar", + bytearray( + [ + 0x30, # PUBLISH + 0x0C, + 0x00, + 0x07, + 0x66, + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x66, + 0x6F, + 0x6F, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + ] + ), + bytearray( + [ + 0x82, # fixed header + 0x0C, # remaining length + 0x00, + 0x01, # message ID + 0x00, + 0x07, # topic length + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x00, # QoS + ] + ), + ), ] @pytest.mark.parametrize( - "topic,to_send,exp_recv", testdata, ids=["short_topic", "long_topic"] + "topic,to_send,exp_recv", + testdata, + ids=["short_topic", "long_topic", "publish_first"], ) def test_subscribe(topic, to_send, exp_recv) -> None: """ From dceca0c567f57babf3b05fa95d6296d88e56aad4 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Fri, 15 Dec 2023 22:57:42 +0100 Subject: [PATCH 15/16] add test case with long list of topics --- tests/test_unsubscribe.py | 46 ++++++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/tests/test_unsubscribe.py b/tests/test_unsubscribe.py index 04c456b..d5b67b6 100644 --- a/tests/test_unsubscribe.py +++ b/tests/test_unsubscribe.py @@ -24,8 +24,10 @@ def handle_unsubscribe(client, user_data, topic, pid): # The MQTT packet contents below were captured using Mosquitto client+server. -# These are verbatim, except message ID that was changed from 2 to 1 since in the real world -# capture the UNSUBSCRIBE packet followed the SUBSCRIBE packet. +# These are all verbatim, except: +# - message ID that was changed from 2 to 1 since in the real world +# the UNSUBSCRIBE packet followed the SUBSCRIBE packet. +# - the long list topics is sent as individual UNSUBSCRIBE packets by Mosquitto testdata = [ # short topic with remaining length encoded as single byte ( @@ -67,11 +69,37 @@ def handle_unsubscribe(client, user_data, topic, pid): + [0x6F] * 257 ), ), + # use list of topics for more coverage. If the range was (1, 10000), that would be + # long enough to use 3 bytes for remaining length, however that would make the test + # run for many minutes even on modern systems, so 1000 is used instead. + # This results in 2 bytes for the remaining length. + ( + [f"foo/bar{x:04}" for x in range(1, 1000)], + bytearray([0xB0, 0x02, 0x00, 0x01]), + bytearray( + [ + 0xA2, # fixed header + 0xBD, # remaining length + 0x65, + 0x00, # message ID + 0x01, + ] + + sum( + [ + [0x00, 0x0B] + list(f"foo/bar{x:04}".encode("ascii")) + for x in range(1, 1000) + ], + [], + ) + ), + ), ] @pytest.mark.parametrize( - "topic,to_send,exp_recv", testdata, ids=["short_topic", "long_topic"] + "topic,to_send,exp_recv", + testdata, + ids=["short_topic", "long_topic", "topic_list_long"], ) def test_unsubscribe(topic, to_send, exp_recv) -> None: """ @@ -107,11 +135,19 @@ def test_unsubscribe(topic, to_send, exp_recv) -> None: mqtt_client.logger = logger # pylint: disable=protected-access - mqtt_client._subscribed_topics = [topic] + if isinstance(topic, str): + mqtt_client._subscribed_topics = [topic] + elif isinstance(topic, list): + mqtt_client._subscribed_topics = topic # pylint: disable=logging-fstring-interpolation logger.info(f"unsubscribing from {topic}") mqtt_client.unsubscribe(topic) - assert topic in unsubscribed_topics + if isinstance(topic, str): + assert topic in unsubscribed_topics + elif isinstance(topic, list): + for topic_name in topic: + assert topic_name in unsubscribed_topics assert mocket.sent == exp_recv + assert len(mocket._to_send) == 0 From 279387e60d3e08572c6c2093f177bef7518c604e Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Fri, 15 Dec 2023 22:58:58 +0100 Subject: [PATCH 16/16] add test case for long list of topics this uncovered a bug in SUBACK processing --- adafruit_minimqtt/adafruit_minimqtt.py | 39 ++++++++------ tests/test_subscribe.py | 71 +++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 19 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index f971126..7ffb723 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -625,7 +625,7 @@ def _connect( var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 var_header[7] |= self._lw_retain << 5 - self.encode_remaining_length(fixed_header, remaining_length) + self._encode_remaining_length(fixed_header, remaining_length) self.logger.debug("Sending CONNECT to broker...") self.logger.debug(f"Fixed Header: {fixed_header}") self.logger.debug(f"Variable Header: {var_header}") @@ -663,10 +663,13 @@ def _connect( ) # pylint: disable=no-self-use - def encode_remaining_length(self, fixed_header: bytearray, remaining_length: int): - """ - Encode Remaining Length [2.2.3] - """ + def _encode_remaining_length( + self, fixed_header: bytearray, remaining_length: int + ) -> None: + """Encode Remaining Length [2.2.3]""" + if remaining_length > 268_435_455: + raise MMQTTException("invalid remaining length") + # Remaining length calculation if remaining_length > 0x7F: while remaining_length > 0: @@ -765,7 +768,7 @@ def publish( pub_hdr_var.append(self._pid >> 8) pub_hdr_var.append(self._pid & 0xFF) - self.encode_remaining_length(pub_hdr_fixed, remaining_length) + self._encode_remaining_length(pub_hdr_fixed, remaining_length) self.logger.debug( "Sending PUBLISH\nTopic: %s\nMsg: %s\ @@ -836,7 +839,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N fixed_header = bytearray([MQTT_SUB]) packet_length = 2 + (2 * len(topics)) + (1 * len(topics)) packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics) - self.encode_remaining_length(fixed_header, remaining_length=packet_length) + self._encode_remaining_length(fixed_header, remaining_length=packet_length) self.logger.debug(f"Fixed Header: {fixed_header}") self._sock.send(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 @@ -864,13 +867,13 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N ) else: if op == 0x90: - rc = self._sock_exact_recv(3) - # Check packet identifier. - assert rc[1] == var_header[0] and rc[2] == var_header[1] - remaining_len = rc[0] - 2 + remaining_len = self._decode_remaining_length() assert remaining_len > 0 - rc = self._sock_exact_recv(remaining_len) - for i in range(0, remaining_len): + rc = self._sock_exact_recv(2) + # Check packet identifier. + assert rc[0] == var_header[0] and rc[1] == var_header[1] + rc = self._sock_exact_recv(remaining_len - 2) + for i in range(0, remaining_len - 2): if rc[i] not in [0, 1, 2]: raise MMQTTException( f"SUBACK Failure for topic {topics[i][0]}: {hex(rc[i])}" @@ -915,7 +918,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None: fixed_header = bytearray([MQTT_UNSUB]) packet_length = 2 + (2 * len(topics)) packet_length += sum(len(topic.encode("utf-8")) for topic in topics) - self.encode_remaining_length(fixed_header, remaining_length=packet_length) + self._encode_remaining_length(fixed_header, remaining_length=packet_length) self.logger.debug(f"Fixed Header: {fixed_header}") self._sock.send(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 @@ -1090,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]: return pkt_type # Handle only the PUBLISH packet type from now on. - sz = self._recv_len() + sz = self._decode_remaining_length() # topic length MSB & LSB topic_len_buf = self._sock_exact_recv(2) topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1]) @@ -1123,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]: return pkt_type - def _recv_len(self) -> int: - """Unpack MQTT message length.""" + def _decode_remaining_length(self) -> int: + """Decode Remaining Length [2.2.3]""" n = 0 sh = 0 while True: + if sh > 28: + raise MMQTTException("invalid remaining length encoding") b = self._sock_exact_recv(1)[0] n |= (b & 0x7F) << sh if not b & 0x80: diff --git a/tests/test_subscribe.py b/tests/test_subscribe.py index 9b00419..a66e7a8 100644 --- a/tests/test_subscribe.py +++ b/tests/test_subscribe.py @@ -49,6 +49,29 @@ def handle_subscribe(client, user_data, topic, qos): ] ), ), + # same as before but with tuple + ( + ("foo/bar", 0), + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK + bytearray( + [ + 0x82, # fixed header + 0x0C, # remaining length + 0x00, + 0x01, # message ID + 0x00, + 0x07, # topic length + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x00, # QoS + ] + ), + ), # remaining length is encoded as 2 bytes due to long topic name. ( "f" + "o" * 257, @@ -113,13 +136,52 @@ def handle_subscribe(client, user_data, topic, qos): ] ), ), + # use list of topics for more coverage. If the range was (1, 10000), that would be + # long enough to use 3 bytes for remaining length, however that would make the test + # run for many minutes even on modern systems, so 1001 is used instead. + # This results in 2 bytes for the remaining length. + ( + [(f"foo/bar{x:04}", 0) for x in range(1, 1001)], + bytearray( + [ + 0x90, + 0xEA, # remaining length + 0x07, + 0x00, # message ID + 0x01, + ] + + [0x00] * 1000 # success for all topics + ), + bytearray( + [ + 0x82, # fixed header + 0xB2, # remaining length + 0x6D, + 0x00, # message ID + 0x01, + ] + + sum( + [ + [0x00, 0x0B] + list(f"foo/bar{x:04}".encode("ascii")) + [0x00] + for x in range(1, 1001) + ], + [], + ) + ), + ), ] @pytest.mark.parametrize( "topic,to_send,exp_recv", testdata, - ids=["short_topic", "long_topic", "publish_first"], + ids=[ + "short_topic", + "short_topic_tuple", + "long_topic", + "publish_first", + "topic_list_long", + ], ) def test_subscribe(topic, to_send, exp_recv) -> None: """ @@ -157,5 +219,10 @@ def test_subscribe(topic, to_send, exp_recv) -> None: logger.info(f"subscribing to {topic}") mqtt_client.subscribe(topic) - assert topic in subscribed_topics + if isinstance(topic, str): + assert topic in subscribed_topics + elif isinstance(topic, list): + for topic_name, _ in topic: + assert topic_name in subscribed_topics assert mocket.sent == exp_recv + assert len(mocket._to_send) == 0