diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 41e0160..7ffb723 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -60,8 +60,8 @@ MQTT_PINGREQ = b"\xc0\0" MQTT_PINGRESP = const(0xD0) MQTT_PUBLISH = const(0x30) -MQTT_SUB = b"\x82" -MQTT_UNSUB = b"\xA2" +MQTT_SUB = const(0x82) +MQTT_UNSUB = const(0xA2) MQTT_DISCONNECT = b"\xe0\0" MQTT_PKT_TYPE_MASK = const(0xF0) @@ -597,13 +597,12 @@ def _connect( self.broker, self.port, timeout=self._socket_timeout ) - # Fixed Header fixed_header = bytearray([0x10]) # Variable CONNECT header [MQTT 3.1.2] # The byte array is used as a template. - var_header = bytearray(b"\x04MQTT\x04\x02\0\0") - var_header[6] = clean_session << 1 + var_header = bytearray(b"\x00\x04MQTT\x04\x02\0\0") + var_header[7] = clean_session << 1 # Set up variable header and remaining_length remaining_length = 12 + len(self.client_id.encode("utf-8")) @@ -614,36 +613,19 @@ def _connect( + 2 + len(self._password.encode("utf-8")) ) - var_header[6] |= 0xC0 + var_header[7] |= 0xC0 if self.keep_alive: assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT - var_header[7] |= self.keep_alive >> 8 - var_header[8] |= self.keep_alive & 0x00FF + var_header[8] |= self.keep_alive >> 8 + var_header[9] |= self.keep_alive & 0x00FF if self._lw_topic: remaining_length += ( 2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg) ) - var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 - var_header[6] |= self._lw_retain << 5 - - # Remaining length calculation - large_rel_length = False - if remaining_length > 0x7F: - large_rel_length = True - # Calculate Remaining Length [2.2.3] - while remaining_length > 0: - encoded_byte = remaining_length % 0x80 - remaining_length = remaining_length // 0x80 - # if there is more data to encode, set the top bit of the byte - if remaining_length > 0: - encoded_byte |= 0x80 - fixed_header.append(encoded_byte) - if large_rel_length: - fixed_header.append(0x00) - else: - fixed_header.append(remaining_length) - fixed_header.append(0x00) + var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3 + var_header[7] |= self._lw_retain << 5 + self._encode_remaining_length(fixed_header, remaining_length) self.logger.debug("Sending CONNECT to broker...") self.logger.debug(f"Fixed Header: {fixed_header}") self.logger.debug(f"Variable Header: {var_header}") @@ -680,6 +662,26 @@ def _connect( f"No data received from broker for {self._recv_timeout} seconds." ) + # pylint: disable=no-self-use + def _encode_remaining_length( + self, fixed_header: bytearray, remaining_length: int + ) -> None: + """Encode Remaining Length [2.2.3]""" + if remaining_length > 268_435_455: + raise MMQTTException("invalid remaining length") + + # Remaining length calculation + if remaining_length > 0x7F: + while remaining_length > 0: + encoded_byte = remaining_length % 0x80 + remaining_length = remaining_length // 0x80 + # if there is more data to encode, set the top bit of the byte + if remaining_length > 0: + encoded_byte |= 0x80 + fixed_header.append(encoded_byte) + else: + fixed_header.append(remaining_length) + def disconnect(self) -> None: """Disconnects the MiniMQTT client from the MQTT broker.""" self._connected() @@ -766,16 +768,7 @@ def publish( pub_hdr_var.append(self._pid >> 8) pub_hdr_var.append(self._pid & 0xFF) - # Calculate remaining length [2.2.3] - if remaining_length > 0x7F: - while remaining_length > 0: - encoded_byte = remaining_length % 0x80 - remaining_length = remaining_length // 0x80 - if remaining_length > 0: - encoded_byte |= 0x80 - pub_hdr_fixed.append(encoded_byte) - else: - pub_hdr_fixed.append(remaining_length) + self._encode_remaining_length(pub_hdr_fixed, remaining_length) self.logger.debug( "Sending PUBLISH\nTopic: %s\nMsg: %s\ @@ -810,9 +803,9 @@ def publish( f"No data received from broker for {self._recv_timeout} seconds." ) - def subscribe(self, topic: str, qos: int = 0) -> None: + def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> None: """Subscribes to a topic on the MQTT Broker. - This method can subscribe to one topics or multiple topics. + This method can subscribe to one topic or multiple topics. :param str|tuple|list topic: Unique MQTT topic identifier string. If this is a `tuple`, then the tuple should @@ -842,21 +835,28 @@ def subscribe(self, topic: str, qos: int = 0) -> None: self._valid_topic(t) topics.append((t, q)) # Assemble packet + self.logger.debug("Sending SUBSCRIBE to broker...") + fixed_header = bytearray([MQTT_SUB]) packet_length = 2 + (2 * len(topics)) + (1 * len(topics)) packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics) - packet_length_byte = packet_length.to_bytes(1, "big") + self._encode_remaining_length(fixed_header, remaining_length=packet_length) + self.logger.debug(f"Fixed Header: {fixed_header}") + self._sock.send(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 packet_id_bytes = self._pid.to_bytes(2, "big") - # Packet with variable and fixed headers - packet = MQTT_SUB + packet_length_byte + packet_id_bytes + var_header = packet_id_bytes + self.logger.debug(f"Variable Header: {var_header}") + self._sock.send(var_header) # attaching topic and QOS level to the packet + payload = bytes() for t, q in topics: topic_size = len(t.encode("utf-8")).to_bytes(2, "big") qos_byte = q.to_bytes(1, "big") - packet += topic_size + t.encode() + qos_byte + payload += topic_size + t.encode() + qos_byte for t, q in topics: - self.logger.debug("SUBSCRIBING to topic %s with QoS %d", t, q) - self._sock.send(packet) + self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}") + self.logger.debug(f"payload: {payload}") + self._sock.send(payload) stamp = self.get_monotonic_time() while True: op = self._wait_for_msg() @@ -867,13 +867,13 @@ def subscribe(self, topic: str, qos: int = 0) -> None: ) else: if op == 0x90: - rc = self._sock_exact_recv(3) - # Check packet identifier. - assert rc[1] == packet[2] and rc[2] == packet[3] - remaining_len = rc[0] - 2 + remaining_len = self._decode_remaining_length() assert remaining_len > 0 - rc = self._sock_exact_recv(remaining_len) - for i in range(0, remaining_len): + rc = self._sock_exact_recv(2) + # Check packet identifier. + assert rc[0] == var_header[0] and rc[1] == var_header[1] + rc = self._sock_exact_recv(remaining_len - 2) + for i in range(0, remaining_len - 2): if rc[i] not in [0, 1, 2]: raise MMQTTException( f"SUBACK Failure for topic {topics[i][0]}: {hex(rc[i])}" @@ -883,13 +883,17 @@ def subscribe(self, topic: str, qos: int = 0) -> None: if self.on_subscribe is not None: self.on_subscribe(self, self.user_data, t, q) self._subscribed_topics.append(t) + return - raise MMQTTException( - f"invalid message received as response to SUBSCRIBE: {hex(op)}" - ) + if op != MQTT_PUBLISH: + # [3.8.4] The Server is permitted to start sending PUBLISH packets + # matching the Subscription before the Server sends the SUBACK Packet. + raise MMQTTException( + f"invalid message received as response to SUBSCRIBE: {hex(op)}" + ) - def unsubscribe(self, topic: str) -> None: + def unsubscribe(self, topic: Optional[Union[str, list]]) -> None: """Unsubscribes from a MQTT topic. :param str|list topic: Unique MQTT topic identifier string or list. @@ -910,18 +914,25 @@ def unsubscribe(self, topic: str) -> None: "Topic must be subscribed to before attempting unsubscribe." ) # Assemble packet + self.logger.debug("Sending UNSUBSCRIBE to broker...") + fixed_header = bytearray([MQTT_UNSUB]) packet_length = 2 + (2 * len(topics)) packet_length += sum(len(topic.encode("utf-8")) for topic in topics) - packet_length_byte = packet_length.to_bytes(1, "big") + self._encode_remaining_length(fixed_header, remaining_length=packet_length) + self.logger.debug(f"Fixed Header: {fixed_header}") + self._sock.send(fixed_header) self._pid = self._pid + 1 if self._pid < 0xFFFF else 1 packet_id_bytes = self._pid.to_bytes(2, "big") - packet = MQTT_UNSUB + packet_length_byte + packet_id_bytes + var_header = packet_id_bytes + self.logger.debug(f"Variable Header: {var_header}") + self._sock.send(var_header) + payload = bytes() for t in topics: topic_size = len(t.encode("utf-8")).to_bytes(2, "big") - packet += topic_size + t.encode() + payload += topic_size + t.encode() for t in topics: - self.logger.debug("UNSUBSCRIBING from topic %s", t) - self._sock.send(packet) + self.logger.debug(f"UNSUBSCRIBING from topic {t}") + self._sock.send(payload) self.logger.debug("Waiting for UNSUBACK...") while True: stamp = self.get_monotonic_time() @@ -1082,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]: return pkt_type # Handle only the PUBLISH packet type from now on. - sz = self._recv_len() + sz = self._decode_remaining_length() # topic length MSB & LSB topic_len_buf = self._sock_exact_recv(2) topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1]) @@ -1115,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]: return pkt_type - def _recv_len(self) -> int: - """Unpack MQTT message length.""" + def _decode_remaining_length(self) -> int: + """Decode Remaining Length [2.2.3]""" n = 0 sh = 0 while True: + if sh > 28: + raise MMQTTException("invalid remaining length encoding") b = self._sock_exact_recv(1)[0] n |= (b & 0x7F) << sh if not b & 0x80: diff --git a/tests/mocket.py b/tests/mocket.py new file mode 100644 index 0000000..31b4101 --- /dev/null +++ b/tests/mocket.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""fake socket class for protocol level testing""" + +from unittest import mock + + +class Mocket: + """ + Mock Socket tailored for MiniMQTT testing. Records sent data, + hands out pre-recorded reply. + + Inspired by the Mocket class from Adafruit_CircuitPython_Requests + """ + + def __init__(self, to_send): + self._to_send = to_send + + self.sent = bytearray() + + self.timeout = mock.Mock() + self.connect = mock.Mock() + self.close = mock.Mock() + + def send(self, bytes_to_send): + """merely record the bytes. return the length of this bytearray.""" + self.sent.extend(bytes_to_send) + return len(bytes_to_send) + + # MiniMQTT checks for the presence of "recv_into" and switches behavior based on that. + def recv_into(self, retbuf, bufsize): + """return data from internal buffer""" + size = min(bufsize, len(self._to_send)) + if size == 0: + return size + chop = self._to_send[0:size] + retbuf[0:] = chop + self._to_send = self._to_send[size:] + return size diff --git a/tests/test_subscribe.py b/tests/test_subscribe.py new file mode 100644 index 0000000..a66e7a8 --- /dev/null +++ b/tests/test_subscribe.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""subscribe tests""" + +import logging +import ssl + +import pytest +from mocket import Mocket + +import adafruit_minimqtt.adafruit_minimqtt as MQTT + + +# pylint: disable=unused-argument +def handle_subscribe(client, user_data, topic, qos): + """ + Record topics into user data. + """ + assert topic + assert qos == 0 + + user_data.append(topic) + + +# The MQTT packet contents below were captured using Mosquitto client+server. +testdata = [ + # short topic with remaining length encoded as single byte + ( + "foo/bar", + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK + bytearray( + [ + 0x82, # fixed header + 0x0C, # remaining length + 0x00, + 0x01, # message ID + 0x00, + 0x07, # topic length + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x00, # QoS + ] + ), + ), + # same as before but with tuple + ( + ("foo/bar", 0), + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK + bytearray( + [ + 0x82, # fixed header + 0x0C, # remaining length + 0x00, + 0x01, # message ID + 0x00, + 0x07, # topic length + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x00, # QoS + ] + ), + ), + # remaining length is encoded as 2 bytes due to long topic name. + ( + "f" + "o" * 257, + bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK + bytearray( + [ + 0x82, # fixed header + 0x87, # remaining length + 0x02, + 0x00, # message ID + 0x01, + 0x01, # topic length + 0x02, + 0x66, # topic + ] + + [0x6F] * 257 + + [0x00] # QoS + ), + ), + # SUBSCRIBE responded to by PUBLISH followed by SUBACK + ( + "foo/bar", + bytearray( + [ + 0x30, # PUBLISH + 0x0C, + 0x00, + 0x07, + 0x66, + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x66, + 0x6F, + 0x6F, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + ] + ), + bytearray( + [ + 0x82, # fixed header + 0x0C, # remaining length + 0x00, + 0x01, # message ID + 0x00, + 0x07, # topic length + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + 0x00, # QoS + ] + ), + ), + # use list of topics for more coverage. If the range was (1, 10000), that would be + # long enough to use 3 bytes for remaining length, however that would make the test + # run for many minutes even on modern systems, so 1001 is used instead. + # This results in 2 bytes for the remaining length. + ( + [(f"foo/bar{x:04}", 0) for x in range(1, 1001)], + bytearray( + [ + 0x90, + 0xEA, # remaining length + 0x07, + 0x00, # message ID + 0x01, + ] + + [0x00] * 1000 # success for all topics + ), + bytearray( + [ + 0x82, # fixed header + 0xB2, # remaining length + 0x6D, + 0x00, # message ID + 0x01, + ] + + sum( + [ + [0x00, 0x0B] + list(f"foo/bar{x:04}".encode("ascii")) + [0x00] + for x in range(1, 1001) + ], + [], + ) + ), + ), +] + + +@pytest.mark.parametrize( + "topic,to_send,exp_recv", + testdata, + ids=[ + "short_topic", + "short_topic_tuple", + "long_topic", + "publish_first", + "topic_list_long", + ], +) +def test_subscribe(topic, to_send, exp_recv) -> None: + """ + Protocol level testing of SUBSCRIBE and SUBACK packet handling. + + Nothing will travel over the wire, it is all fake. + """ + logging.basicConfig() + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + host = "localhost" + port = 1883 + + subscribed_topics = [] + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + ssl_context=ssl.create_default_context(), + connect_retries=1, + user_data=subscribed_topics, + ) + + mqtt_client.on_subscribe = handle_subscribe + + # patch is_connected() to avoid CONNECT/CONNACK handling. + mqtt_client.is_connected = lambda: True + mocket = Mocket(to_send) + # pylint: disable=protected-access + mqtt_client._sock = mocket + + mqtt_client.logger = logger + + # pylint: disable=logging-fstring-interpolation + logger.info(f"subscribing to {topic}") + mqtt_client.subscribe(topic) + + if isinstance(topic, str): + assert topic in subscribed_topics + elif isinstance(topic, list): + for topic_name, _ in topic: + assert topic_name in subscribed_topics + assert mocket.sent == exp_recv + assert len(mocket._to_send) == 0 diff --git a/tests/test_unsubscribe.py b/tests/test_unsubscribe.py new file mode 100644 index 0000000..d5b67b6 --- /dev/null +++ b/tests/test_unsubscribe.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""unsubscribe tests""" + +import logging +import ssl + +import pytest +from mocket import Mocket + +import adafruit_minimqtt.adafruit_minimqtt as MQTT + + +# pylint: disable=unused-argument +def handle_unsubscribe(client, user_data, topic, pid): + """ + Record topics into user data. + """ + assert topic + + user_data.append(topic) + + +# The MQTT packet contents below were captured using Mosquitto client+server. +# These are all verbatim, except: +# - message ID that was changed from 2 to 1 since in the real world +# the UNSUBSCRIBE packet followed the SUBSCRIBE packet. +# - the long list topics is sent as individual UNSUBSCRIBE packets by Mosquitto +testdata = [ + # short topic with remaining length encoded as single byte + ( + "foo/bar", + bytearray([0xB0, 0x02, 0x00, 0x01]), + bytearray( + [ + 0xA2, # fixed header + 0x0B, # remaining length + 0x00, # message ID + 0x01, + 0x00, # topic length + 0x07, + 0x66, # topic + 0x6F, + 0x6F, + 0x2F, + 0x62, + 0x61, + 0x72, + ] + ), + ), + # remaining length is encoded as 2 bytes due to long topic name. + ( + "f" + "o" * 257, + bytearray([0xB0, 0x02, 0x00, 0x01]), + bytearray( + [ + 0xA2, # fixed header + 0x86, # remaining length + 0x02, + 0x00, # message ID + 0x01, + 0x01, # topic length + 0x02, + 0x66, # topic + ] + + [0x6F] * 257 + ), + ), + # use list of topics for more coverage. If the range was (1, 10000), that would be + # long enough to use 3 bytes for remaining length, however that would make the test + # run for many minutes even on modern systems, so 1000 is used instead. + # This results in 2 bytes for the remaining length. + ( + [f"foo/bar{x:04}" for x in range(1, 1000)], + bytearray([0xB0, 0x02, 0x00, 0x01]), + bytearray( + [ + 0xA2, # fixed header + 0xBD, # remaining length + 0x65, + 0x00, # message ID + 0x01, + ] + + sum( + [ + [0x00, 0x0B] + list(f"foo/bar{x:04}".encode("ascii")) + for x in range(1, 1000) + ], + [], + ) + ), + ), +] + + +@pytest.mark.parametrize( + "topic,to_send,exp_recv", + testdata, + ids=["short_topic", "long_topic", "topic_list_long"], +) +def test_unsubscribe(topic, to_send, exp_recv) -> None: + """ + Protocol level testing of UNSUBSCRIBE and UNSUBACK packet handling. + + Nothing will travel over the wire, it is all fake. + Also, the topics are not subscribed into. + """ + logging.basicConfig() + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + host = "localhost" + port = 1883 + + unsubscribed_topics = [] + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + ssl_context=ssl.create_default_context(), + connect_retries=1, + user_data=unsubscribed_topics, + ) + + mqtt_client.on_unsubscribe = handle_unsubscribe + + # patch is_connected() to avoid CONNECT/CONNACK handling. + mqtt_client.is_connected = lambda: True + mocket = Mocket(to_send) + # pylint: disable=protected-access + mqtt_client._sock = mocket + + mqtt_client.logger = logger + + # pylint: disable=protected-access + if isinstance(topic, str): + mqtt_client._subscribed_topics = [topic] + elif isinstance(topic, list): + mqtt_client._subscribed_topics = topic + + # pylint: disable=logging-fstring-interpolation + logger.info(f"unsubscribing from {topic}") + mqtt_client.unsubscribe(topic) + + if isinstance(topic, str): + assert topic in unsubscribed_topics + elif isinstance(topic, list): + for topic_name in topic: + assert topic_name in unsubscribed_topics + assert mocket.sent == exp_recv + assert len(mocket._to_send) == 0