Skip to content

Commit

Permalink
Merge pull request #213 from justmobilize/no-retry-on-unauthorized
Browse files Browse the repository at this point in the history
Don't retry when MQTT response is unauthorized
  • Loading branch information
dhalbert authored May 20, 2024
2 parents ecfd228 + 16b6c6d commit d412e9a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 11 deletions.
42 changes: 32 additions & 10 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@
MQTT_PKT_TYPE_MASK = const(0xF0)


CONNACK_ERROR_INCORRECT_PROTOCOL = const(0x01)
CONNACK_ERROR_ID_REJECTED = const(0x02)
CONNACK_ERROR_SERVER_UNAVAILABLE = const(0x03)
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD = const(0x04)
CONNACK_ERROR_UNAUTHORIZED = const(0x05)

CONNACK_ERRORS = {
const(0x01): "Connection Refused - Incorrect Protocol Version",
const(0x02): "Connection Refused - ID Rejected",
const(0x03): "Connection Refused - Server unavailable",
const(0x04): "Connection Refused - Incorrect username/password",
const(0x05): "Connection Refused - Unauthorized",
CONNACK_ERROR_INCORRECT_PROTOCOL: "Connection Refused - Incorrect Protocol Version",
CONNACK_ERROR_ID_REJECTED: "Connection Refused - ID Rejected",
CONNACK_ERROR_SERVER_UNAVAILABLE: "Connection Refused - Server unavailable",
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD: "Connection Refused - Incorrect username/password",
CONNACK_ERROR_UNAUTHORIZED: "Connection Refused - Unauthorized",
}

_default_sock = None # pylint: disable=invalid-name
Expand All @@ -87,6 +93,10 @@
class MMQTTException(Exception):
"""MiniMQTT Exception class."""

def __init__(self, error, code=None):
super().__init__(error, code)
self.code = code


class NullLogger:
"""Fake logger class that does not do anything"""
Expand Down Expand Up @@ -428,17 +438,24 @@ def connect(
self.logger.warning(f"Socket error when connecting: {e}")
backoff = False
except MMQTTException as e:
last_exception = e
self._close_socket()
self.logger.info(f"MMQT error: {e}")
if e.code in [
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD,
CONNACK_ERROR_UNAUTHORIZED,
]:
# No sense trying these again, re-raise
raise
last_exception = e
backoff = True

if self._reconnect_attempts_max > 1:
exc_msg = "Repeated connect failures"
else:
exc_msg = "Connect failure"

if last_exception:
raise MMQTTException(exc_msg) from last_exception

raise MMQTTException(exc_msg)

# pylint: disable=too-many-branches, too-many-statements, too-many-locals
Expand Down Expand Up @@ -535,7 +552,7 @@ def _connect(
rc = self._sock_exact_recv(3)
assert rc[0] == 0x02
if rc[2] != 0x00:
raise MMQTTException(CONNACK_ERRORS[rc[2]])
raise MMQTTException(CONNACK_ERRORS[rc[2]], code=rc[2])
self._is_connected = True
result = rc[0] & 1
if self.on_connect is not None:
Expand All @@ -549,6 +566,12 @@ def _connect(
f"No data received from broker for {self._recv_timeout} seconds."
)

def _close_socket(self):
if self._sock:
self.logger.debug("Closing socket")
self._connection_manager.close_socket(self._sock)
self._sock = None

# pylint: disable=no-self-use
def _encode_remaining_length(
self, fixed_header: bytearray, remaining_length: int
Expand Down Expand Up @@ -577,8 +600,7 @@ def disconnect(self) -> None:
self._sock.send(MQTT_DISCONNECT)
except RuntimeError as e:
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
self.logger.debug("Closing socket")
self._connection_manager.close_socket(self._sock)
self._close_socket()
self._is_connected = False
self._subscribed_topics = []
self._last_msg_sent_timestamp = 0
Expand Down
40 changes: 39 additions & 1 deletion tests/test_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,24 @@ class TestExpBackOff:
"""basic exponential back-off test"""

connect_times = []
raise_exception = None

# pylint: disable=unused-argument
def fake_connect(self, arg):
"""connect() replacement that records the call times and always raises OSError"""
self.connect_times.append(time.monotonic())
raise OSError("this connect failed")
raise self.raise_exception

def test_failing_connect(self) -> None:
"""test that exponential back-off is used when connect() always raises OSError"""
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
host = "172.40.0.3"
port = 1883
self.connect_times = []
error_code = MQTT.CONNACK_ERROR_SERVER_UNAVAILABLE
self.raise_exception = MQTT.MMQTTException(
MQTT.CONNACK_ERRORS[error_code], code=error_code
)

with patch.object(socket.socket, "connect") as mock_method:
mock_method.side_effect = self.fake_connect
Expand All @@ -45,6 +51,7 @@ def test_failing_connect(self) -> None:
print("connecting")
with pytest.raises(MQTT.MMQTTException) as context:
mqtt_client.connect()
assert mqtt_client._sock is None
assert "Repeated connect failures" in str(context)

mock_method.assert_called()
Expand All @@ -54,3 +61,34 @@ def test_failing_connect(self) -> None:
print(f"connect() call times: {self.connect_times}")
for i in range(1, connect_retries):
assert self.connect_times[i] >= 2**i

def test_unauthorized(self) -> None:
"""test that exponential back-off is used when connect() always raises OSError"""
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
host = "172.40.0.3"
port = 1883
self.connect_times = []
error_code = MQTT.CONNACK_ERROR_UNAUTHORIZED
self.raise_exception = MQTT.MMQTTException(
MQTT.CONNACK_ERRORS[error_code], code=error_code
)

with patch.object(socket.socket, "connect") as mock_method:
mock_method.side_effect = self.fake_connect

connect_retries = 3
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
socket_pool=socket,
ssl_context=ssl.create_default_context(),
connect_retries=connect_retries,
)
print("connecting")
with pytest.raises(MQTT.MMQTTException) as context:
mqtt_client.connect()
assert mqtt_client._sock is None
assert "Connection Refused - Unauthorized" in str(context)

mock_method.assert_called()
assert len(self.connect_times) == 1

0 comments on commit d412e9a

Please sign in to comment.