From 906e6764d871897bb175f5e878b7e59370d3b884 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Mon, 13 May 2024 14:32:32 -0700 Subject: [PATCH] Don't retry when MQTT response is unauthorized --- adafruit_minimqtt/adafruit_minimqtt.py | 30 +++++++++++++++----- tests/test_backoff.py | 38 +++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 622c6e6..82b43ab 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -72,12 +72,18 @@ MQTT_PKT_TYPE_MASK = const(0xF0) +CONNACK_ERROR_INCORRECT_PROTOCOL = const(0x01) +CONNACK_ERROR_ID_REJECTED = const(0x02) +CONNACK_ERROR_SERVER_UNAVAILABLE = const(0x03) +CONNACK_ERROR_INCORECT_USERNAME_PASSWORD = const(0x04) +CONNACK_ERROR_UNAUTHORIZED = const(0x05) + CONNACK_ERRORS = { - const(0x01): "Connection Refused - Incorrect Protocol Version", - const(0x02): "Connection Refused - ID Rejected", - const(0x03): "Connection Refused - Server unavailable", - const(0x04): "Connection Refused - Incorrect username/password", - const(0x05): "Connection Refused - Unauthorized", + CONNACK_ERROR_INCORRECT_PROTOCOL: "Connection Refused - Incorrect Protocol Version", + CONNACK_ERROR_ID_REJECTED: "Connection Refused - ID Rejected", + CONNACK_ERROR_SERVER_UNAVAILABLE: "Connection Refused - Server unavailable", + CONNACK_ERROR_INCORECT_USERNAME_PASSWORD: "Connection Refused - Incorrect username/password", + CONNACK_ERROR_UNAUTHORIZED: "Connection Refused - Unauthorized", } _default_sock = None # pylint: disable=invalid-name @@ -87,6 +93,10 @@ class MMQTTException(Exception): """MiniMQTT Exception class.""" + def __init__(self, error, code=None): + super().__init__(error, code) + self.code = code + class NullLogger: """Fake logger class that does not do anything""" @@ -428,8 +438,14 @@ def connect( self.logger.warning(f"Socket error when connecting: {e}") backoff = False except MMQTTException as e: - last_exception = e self.logger.info(f"MMQT error: {e}") + if e.code in [ + CONNACK_ERROR_INCORECT_USERNAME_PASSWORD, + CONNACK_ERROR_UNAUTHORIZED, + ]: + # No sense trying these again, re-raise + raise + last_exception = e backoff = True if self._reconnect_attempts_max > 1: @@ -535,7 +551,7 @@ def _connect( rc = self._sock_exact_recv(3) assert rc[0] == 0x02 if rc[2] != 0x00: - raise MMQTTException(CONNACK_ERRORS[rc[2]]) + raise MMQTTException(CONNACK_ERRORS[rc[2]], code=rc[2]) self._is_connected = True result = rc[0] & 1 if self.on_connect is not None: diff --git a/tests/test_backoff.py b/tests/test_backoff.py index e26d07a..ce6097f 100644 --- a/tests/test_backoff.py +++ b/tests/test_backoff.py @@ -18,18 +18,24 @@ class TestExpBackOff: """basic exponential back-off test""" connect_times = [] + raise_exception = None # pylint: disable=unused-argument def fake_connect(self, arg): """connect() replacement that records the call times and always raises OSError""" self.connect_times.append(time.monotonic()) - raise OSError("this connect failed") + raise self.raise_exception def test_failing_connect(self) -> None: """test that exponential back-off is used when connect() always raises OSError""" # use RFC 1918 address to avoid dealing with IPv6 in the call list below host = "172.40.0.3" port = 1883 + self.connect_times = [] + error_code = MQTT.CONNACK_ERROR_SERVER_UNAVAILABLE + self.raise_exception = MQTT.MMQTTException( + MQTT.CONNACK_ERRORS[error_code], code=error_code + ) with patch.object(socket.socket, "connect") as mock_method: mock_method.side_effect = self.fake_connect @@ -54,3 +60,33 @@ def test_failing_connect(self) -> None: print(f"connect() call times: {self.connect_times}") for i in range(1, connect_retries): assert self.connect_times[i] >= 2**i + + def test_unauthorized(self) -> None: + """test that exponential back-off is used when connect() always raises OSError""" + # use RFC 1918 address to avoid dealing with IPv6 in the call list below + host = "172.40.0.3" + port = 1883 + self.connect_times = [] + error_code = MQTT.CONNACK_ERROR_UNAUTHORIZED + self.raise_exception = MQTT.MMQTTException( + MQTT.CONNACK_ERRORS[error_code], code=error_code + ) + + with patch.object(socket.socket, "connect") as mock_method: + mock_method.side_effect = self.fake_connect + + connect_retries = 3 + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + socket_pool=socket, + ssl_context=ssl.create_default_context(), + connect_retries=connect_retries, + ) + print("connecting") + with pytest.raises(MQTT.MMQTTException) as context: + mqtt_client.connect() + assert "Connection Refused - Unauthorized" in str(context) + + mock_method.assert_called() + assert len(self.connect_times) == 1