Skip to content

Commit

Permalink
Merge pull request #187 from vladak/subscribe_vs_remaining_len
Browse files Browse the repository at this point in the history
encode/decode remaining length properly for {,UN}SUBSCRIBE/SUBACK
  • Loading branch information
FoamyGuy authored Dec 17, 2023
2 parents 4a52082 + 279387e commit 70faa4f
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 64 deletions.
141 changes: 77 additions & 64 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
MQTT_PINGREQ = b"\xc0\0"
MQTT_PINGRESP = const(0xD0)
MQTT_PUBLISH = const(0x30)
MQTT_SUB = b"\x82"
MQTT_UNSUB = b"\xA2"
MQTT_SUB = const(0x82)
MQTT_UNSUB = const(0xA2)
MQTT_DISCONNECT = b"\xe0\0"

MQTT_PKT_TYPE_MASK = const(0xF0)
Expand Down Expand Up @@ -597,13 +597,12 @@ def _connect(
self.broker, self.port, timeout=self._socket_timeout
)

# Fixed Header
fixed_header = bytearray([0x10])

# Variable CONNECT header [MQTT 3.1.2]
# The byte array is used as a template.
var_header = bytearray(b"\x04MQTT\x04\x02\0\0")
var_header[6] = clean_session << 1
var_header = bytearray(b"\x00\x04MQTT\x04\x02\0\0")
var_header[7] = clean_session << 1

# Set up variable header and remaining_length
remaining_length = 12 + len(self.client_id.encode("utf-8"))
Expand All @@ -614,36 +613,19 @@ def _connect(
+ 2
+ len(self._password.encode("utf-8"))
)
var_header[6] |= 0xC0
var_header[7] |= 0xC0
if self.keep_alive:
assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT
var_header[7] |= self.keep_alive >> 8
var_header[8] |= self.keep_alive & 0x00FF
var_header[8] |= self.keep_alive >> 8
var_header[9] |= self.keep_alive & 0x00FF
if self._lw_topic:
remaining_length += (
2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg)
)
var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
var_header[6] |= self._lw_retain << 5

# Remaining length calculation
large_rel_length = False
if remaining_length > 0x7F:
large_rel_length = True
# Calculate Remaining Length [2.2.3]
while remaining_length > 0:
encoded_byte = remaining_length % 0x80
remaining_length = remaining_length // 0x80
# if there is more data to encode, set the top bit of the byte
if remaining_length > 0:
encoded_byte |= 0x80
fixed_header.append(encoded_byte)
if large_rel_length:
fixed_header.append(0x00)
else:
fixed_header.append(remaining_length)
fixed_header.append(0x00)
var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
var_header[7] |= self._lw_retain << 5

self._encode_remaining_length(fixed_header, remaining_length)
self.logger.debug("Sending CONNECT to broker...")
self.logger.debug(f"Fixed Header: {fixed_header}")
self.logger.debug(f"Variable Header: {var_header}")
Expand Down Expand Up @@ -680,6 +662,26 @@ def _connect(
f"No data received from broker for {self._recv_timeout} seconds."
)

# pylint: disable=no-self-use
def _encode_remaining_length(
self, fixed_header: bytearray, remaining_length: int
) -> None:
"""Encode Remaining Length [2.2.3]"""
if remaining_length > 268_435_455:
raise MMQTTException("invalid remaining length")

# Remaining length calculation
if remaining_length > 0x7F:
while remaining_length > 0:
encoded_byte = remaining_length % 0x80
remaining_length = remaining_length // 0x80
# if there is more data to encode, set the top bit of the byte
if remaining_length > 0:
encoded_byte |= 0x80
fixed_header.append(encoded_byte)
else:
fixed_header.append(remaining_length)

def disconnect(self) -> None:
"""Disconnects the MiniMQTT client from the MQTT broker."""
self._connected()
Expand Down Expand Up @@ -766,16 +768,7 @@ def publish(
pub_hdr_var.append(self._pid >> 8)
pub_hdr_var.append(self._pid & 0xFF)

# Calculate remaining length [2.2.3]
if remaining_length > 0x7F:
while remaining_length > 0:
encoded_byte = remaining_length % 0x80
remaining_length = remaining_length // 0x80
if remaining_length > 0:
encoded_byte |= 0x80
pub_hdr_fixed.append(encoded_byte)
else:
pub_hdr_fixed.append(remaining_length)
self._encode_remaining_length(pub_hdr_fixed, remaining_length)

self.logger.debug(
"Sending PUBLISH\nTopic: %s\nMsg: %s\
Expand Down Expand Up @@ -810,9 +803,9 @@ def publish(
f"No data received from broker for {self._recv_timeout} seconds."
)

def subscribe(self, topic: str, qos: int = 0) -> None:
def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> None:
"""Subscribes to a topic on the MQTT Broker.
This method can subscribe to one topics or multiple topics.
This method can subscribe to one topic or multiple topics.
:param str|tuple|list topic: Unique MQTT topic identifier string. If
this is a `tuple`, then the tuple should
Expand Down Expand Up @@ -842,21 +835,28 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
self._valid_topic(t)
topics.append((t, q))
# Assemble packet
self.logger.debug("Sending SUBSCRIBE to broker...")
fixed_header = bytearray([MQTT_SUB])
packet_length = 2 + (2 * len(topics)) + (1 * len(topics))
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
packet_length_byte = packet_length.to_bytes(1, "big")
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
self.logger.debug(f"Fixed Header: {fixed_header}")
self._sock.send(fixed_header)
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
# Packet with variable and fixed headers
packet = MQTT_SUB + packet_length_byte + packet_id_bytes
var_header = packet_id_bytes
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(var_header)
# attaching topic and QOS level to the packet
payload = bytes()
for t, q in topics:
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
qos_byte = q.to_bytes(1, "big")
packet += topic_size + t.encode() + qos_byte
payload += topic_size + t.encode() + qos_byte
for t, q in topics:
self.logger.debug("SUBSCRIBING to topic %s with QoS %d", t, q)
self._sock.send(packet)
self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}")
self.logger.debug(f"payload: {payload}")
self._sock.send(payload)
stamp = self.get_monotonic_time()
while True:
op = self._wait_for_msg()
Expand All @@ -867,13 +867,13 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
)
else:
if op == 0x90:
rc = self._sock_exact_recv(3)
# Check packet identifier.
assert rc[1] == packet[2] and rc[2] == packet[3]
remaining_len = rc[0] - 2
remaining_len = self._decode_remaining_length()
assert remaining_len > 0
rc = self._sock_exact_recv(remaining_len)
for i in range(0, remaining_len):
rc = self._sock_exact_recv(2)
# Check packet identifier.
assert rc[0] == var_header[0] and rc[1] == var_header[1]
rc = self._sock_exact_recv(remaining_len - 2)
for i in range(0, remaining_len - 2):
if rc[i] not in [0, 1, 2]:
raise MMQTTException(
f"SUBACK Failure for topic {topics[i][0]}: {hex(rc[i])}"
Expand All @@ -883,13 +883,17 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
if self.on_subscribe is not None:
self.on_subscribe(self, self.user_data, t, q)
self._subscribed_topics.append(t)

return

raise MMQTTException(
f"invalid message received as response to SUBSCRIBE: {hex(op)}"
)
if op != MQTT_PUBLISH:
# [3.8.4] The Server is permitted to start sending PUBLISH packets
# matching the Subscription before the Server sends the SUBACK Packet.
raise MMQTTException(
f"invalid message received as response to SUBSCRIBE: {hex(op)}"
)

def unsubscribe(self, topic: str) -> None:
def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
"""Unsubscribes from a MQTT topic.
:param str|list topic: Unique MQTT topic identifier string or list.
Expand All @@ -910,18 +914,25 @@ def unsubscribe(self, topic: str) -> None:
"Topic must be subscribed to before attempting unsubscribe."
)
# Assemble packet
self.logger.debug("Sending UNSUBSCRIBE to broker...")
fixed_header = bytearray([MQTT_UNSUB])
packet_length = 2 + (2 * len(topics))
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
packet_length_byte = packet_length.to_bytes(1, "big")
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
self.logger.debug(f"Fixed Header: {fixed_header}")
self._sock.send(fixed_header)
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
packet = MQTT_UNSUB + packet_length_byte + packet_id_bytes
var_header = packet_id_bytes
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(var_header)
payload = bytes()
for t in topics:
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
packet += topic_size + t.encode()
payload += topic_size + t.encode()
for t in topics:
self.logger.debug("UNSUBSCRIBING from topic %s", t)
self._sock.send(packet)
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
self._sock.send(payload)
self.logger.debug("Waiting for UNSUBACK...")
while True:
stamp = self.get_monotonic_time()
Expand Down Expand Up @@ -1082,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]:
return pkt_type

# Handle only the PUBLISH packet type from now on.
sz = self._recv_len()
sz = self._decode_remaining_length()
# topic length MSB & LSB
topic_len_buf = self._sock_exact_recv(2)
topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1])
Expand Down Expand Up @@ -1115,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]:

return pkt_type

def _recv_len(self) -> int:
"""Unpack MQTT message length."""
def _decode_remaining_length(self) -> int:
"""Decode Remaining Length [2.2.3]"""
n = 0
sh = 0
while True:
if sh > 28:
raise MMQTTException("invalid remaining length encoding")
b = self._sock_exact_recv(1)[0]
n |= (b & 0x7F) << sh
if not b & 0x80:
Expand Down
41 changes: 41 additions & 0 deletions tests/mocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
#
# SPDX-License-Identifier: Unlicense

"""fake socket class for protocol level testing"""

from unittest import mock


class Mocket:
"""
Mock Socket tailored for MiniMQTT testing. Records sent data,
hands out pre-recorded reply.
Inspired by the Mocket class from Adafruit_CircuitPython_Requests
"""

def __init__(self, to_send):
self._to_send = to_send

self.sent = bytearray()

self.timeout = mock.Mock()
self.connect = mock.Mock()
self.close = mock.Mock()

def send(self, bytes_to_send):
"""merely record the bytes. return the length of this bytearray."""
self.sent.extend(bytes_to_send)
return len(bytes_to_send)

# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
def recv_into(self, retbuf, bufsize):
"""return data from internal buffer"""
size = min(bufsize, len(self._to_send))
if size == 0:
return size
chop = self._to_send[0:size]
retbuf[0:] = chop
self._to_send = self._to_send[size:]
return size
Loading

0 comments on commit 70faa4f

Please sign in to comment.