Skip to content

Commit

Permalink
Don't retry when MQTT response is unauthorized
Browse files Browse the repository at this point in the history
  • Loading branch information
justmobilize committed May 13, 2024
1 parent ecfd228 commit 906e676
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
30 changes: 23 additions & 7 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@
MQTT_PKT_TYPE_MASK = const(0xF0)


CONNACK_ERROR_INCORRECT_PROTOCOL = const(0x01)
CONNACK_ERROR_ID_REJECTED = const(0x02)
CONNACK_ERROR_SERVER_UNAVAILABLE = const(0x03)
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD = const(0x04)
CONNACK_ERROR_UNAUTHORIZED = const(0x05)

CONNACK_ERRORS = {
const(0x01): "Connection Refused - Incorrect Protocol Version",
const(0x02): "Connection Refused - ID Rejected",
const(0x03): "Connection Refused - Server unavailable",
const(0x04): "Connection Refused - Incorrect username/password",
const(0x05): "Connection Refused - Unauthorized",
CONNACK_ERROR_INCORRECT_PROTOCOL: "Connection Refused - Incorrect Protocol Version",
CONNACK_ERROR_ID_REJECTED: "Connection Refused - ID Rejected",
CONNACK_ERROR_SERVER_UNAVAILABLE: "Connection Refused - Server unavailable",
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD: "Connection Refused - Incorrect username/password",
CONNACK_ERROR_UNAUTHORIZED: "Connection Refused - Unauthorized",
}

_default_sock = None # pylint: disable=invalid-name
Expand All @@ -87,6 +93,10 @@
class MMQTTException(Exception):
"""MiniMQTT Exception class."""

def __init__(self, error, code=None):
super().__init__(error, code)
self.code = code


class NullLogger:
"""Fake logger class that does not do anything"""
Expand Down Expand Up @@ -428,8 +438,14 @@ def connect(
self.logger.warning(f"Socket error when connecting: {e}")
backoff = False
except MMQTTException as e:
last_exception = e
self.logger.info(f"MMQT error: {e}")
if e.code in [
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD,
CONNACK_ERROR_UNAUTHORIZED,
]:
# No sense trying these again, re-raise
raise
last_exception = e
backoff = True

if self._reconnect_attempts_max > 1:
Expand Down Expand Up @@ -535,7 +551,7 @@ def _connect(
rc = self._sock_exact_recv(3)
assert rc[0] == 0x02
if rc[2] != 0x00:
raise MMQTTException(CONNACK_ERRORS[rc[2]])
raise MMQTTException(CONNACK_ERRORS[rc[2]], code=rc[2])
self._is_connected = True
result = rc[0] & 1
if self.on_connect is not None:
Expand Down
38 changes: 37 additions & 1 deletion tests/test_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,24 @@ class TestExpBackOff:
"""basic exponential back-off test"""

connect_times = []
raise_exception = None

# pylint: disable=unused-argument
def fake_connect(self, arg):
"""connect() replacement that records the call times and always raises OSError"""
self.connect_times.append(time.monotonic())
raise OSError("this connect failed")
raise self.raise_exception

def test_failing_connect(self) -> None:
"""test that exponential back-off is used when connect() always raises OSError"""
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
host = "172.40.0.3"
port = 1883
self.connect_times = []
error_code = MQTT.CONNACK_ERROR_SERVER_UNAVAILABLE
self.raise_exception = MQTT.MMQTTException(
MQTT.CONNACK_ERRORS[error_code], code=error_code
)

with patch.object(socket.socket, "connect") as mock_method:
mock_method.side_effect = self.fake_connect
Expand All @@ -54,3 +60,33 @@ def test_failing_connect(self) -> None:
print(f"connect() call times: {self.connect_times}")
for i in range(1, connect_retries):
assert self.connect_times[i] >= 2**i

def test_unauthorized(self) -> None:
"""test that exponential back-off is used when connect() always raises OSError"""
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
host = "172.40.0.3"
port = 1883
self.connect_times = []
error_code = MQTT.CONNACK_ERROR_UNAUTHORIZED
self.raise_exception = MQTT.MMQTTException(
MQTT.CONNACK_ERRORS[error_code], code=error_code
)

with patch.object(socket.socket, "connect") as mock_method:
mock_method.side_effect = self.fake_connect

connect_retries = 3
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
socket_pool=socket,
ssl_context=ssl.create_default_context(),
connect_retries=connect_retries,
)
print("connecting")
with pytest.raises(MQTT.MMQTTException) as context:
mqtt_client.connect()
assert "Connection Refused - Unauthorized" in str(context)

mock_method.assert_called()
assert len(self.connect_times) == 1

0 comments on commit 906e676

Please sign in to comment.