From 4efe903a3f62e2627ff0a4dde03e75c75420c815 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Thu, 28 Mar 2024 22:18:48 +0000 Subject: [PATCH] Handle NetlinkDumpInterrupted, fix worker metrics going stale after exceptions --- wgkex/worker/mqtt.py | 25 ++++++++++++---- wgkex/worker/mqtt_test.py | 56 ++++++++++++++++++++++++++++++++++++ wgkex/worker/netlink.py | 17 ++++++++--- wgkex/worker/netlink_test.py | 29 ++++++++++++++++--- 4 files changed, 113 insertions(+), 14 deletions(-) diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index d5941cd..6c30c73 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -9,6 +9,7 @@ from typing import Any, Optional import paho.mqtt.client as mqtt +import pyroute2.netlink.exceptions from wgkex.common import logger from wgkex.common.mqtt import ( @@ -20,9 +21,7 @@ from wgkex.worker.msg_queue import q from wgkex.worker.netlink import ( get_device_data, - link_handler, get_connected_peers_count, - WireGuardClient, ) _HOSTNAME = socket.gethostname() @@ -206,9 +205,15 @@ def publish_metrics_loop( topic = TOPIC_CONNECTED_PEERS.format(domain=domain, worker=_HOSTNAME) while not exit_event.is_set(): - publish_metrics(client, topic, domain) - # This drifts slightly over time, doesn't matter for us - exit_event.wait(_METRICS_SEND_INTERVAL) + try: + publish_metrics(client, topic, domain) + except Exception as e: + # Don't crash the thread when an exception is encountered + logger.error(f"Exception during publish metrics task for {domain}:") + logger.error(e) + finally: + # This drifts slightly over time, doesn't matter for us + exit_event.wait(_METRICS_SEND_INTERVAL) # Set peers metric to -1 to mark worker as offline # Use QoS 1 (at least once) to make sure the broker notices @@ -227,7 +232,15 @@ def publish_metrics(client: mqtt.Client, topic: str, domain: str) -> None: f"Could not get interface name for domain {domain}. 
Skipping metrics publication" ) return - peer_count = get_connected_peers_count(iface) + + try: + peer_count = get_connected_peers_count(iface) + except pyroute2.netlink.exceptions.NetlinkDumpInterrupted: + # Handle gracefully, don't update metrics + logger.info( + f"Caught NetlinkDumpInterrupted exception while collecting metrics for domain {domain}" + ) + return # Publish metrics, retain it at MQTT broker so restarted wgkex broker has metrics right away client.publish(topic, peer_count, retain=True) diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index efdd4eb..127ec48 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -7,6 +7,7 @@ import mock import paho.mqtt.client +import pyroute2.netlink.exceptions from wgkex.common.mqtt import TOPIC_CONNECTED_PEERS from wgkex.worker import mqtt @@ -128,6 +129,61 @@ def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): self.assertFalse(thread.is_alive()) + @mock.patch.object(mqtt, "_METRICS_SEND_INTERVAL", 0.02) + @mock.patch.object(mqtt, "get_config") + @mock.patch.object(mqtt, "get_connected_peers_count") + def test_publish_metrics_loop_no_exception(self, conn_peers_mock, config_mock): + """Tests that an exception doesn't interrupt the loop""" + config_mock.return_value = _get_config_mock() + conn_peers_mock.side_effect = Exception("Mocked exception") + mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client) + + ee = threading.Event() + thread = threading.Thread( + target=mqtt.publish_metrics_loop, + args=(ee, mqtt_client, "_ffmuc_domain.one"), + ) + thread.start() + + i = 0 + while i < 20 and not len(conn_peers_mock.mock_calls) >= 2: + i += 1 + sleep(0.1) + + self.assertTrue( + len(conn_peers_mock.mock_calls) >= 2, + "get_connected_peers_count must be called at least twice", + ) + + mqtt_client.publish.assert_not_called() + + ee.set() + + i = 0 + while i < 20 and thread.is_alive(): + i += 1 + sleep(0.1) + + self.assertFalse(thread.is_alive()) + + 
@mock.patch.object(mqtt, "get_config") + @mock.patch.object(mqtt, "get_connected_peers_count") + def test_publish_metrics_NetlinkDumpInterrupted(self, conn_peers_mock, config_mock): + config_mock.return_value = _get_config_mock() + conn_peers_mock.side_effect = ( + pyroute2.netlink.exceptions.NetlinkDumpInterrupted() + ) + mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client) + + domain = mqtt.get_config().domains[0] + hostname = socket.gethostname() + topic = TOPIC_CONNECTED_PEERS.format(domain=domain, worker=hostname) + + # Must not raise NetlinkDumpInterrupted, but handle gracefully by doing nothing + mqtt.publish_metrics(mqtt_client, topic, domain) + + mqtt_client.publish.assert_not_called() + @mock.patch.object(mqtt, "get_config") def test_on_message_wireguard_success(self, config_mock): # Tests on_message for success. diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index 057e110..1e681aa 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -9,8 +9,7 @@ from textwrap import wrap from typing import Any, Dict, List, Tuple -import pyroute2 -import pyroute2.netlink +import pyroute2, pyroute2.netlink, pyroute2.netlink.exceptions from wgkex.common.utils import mac2eui64 from wgkex.common import logger @@ -218,12 +217,22 @@ def get_connected_peers_count(wg_interface: str) -> int: wg_interface: The WireGuard interface to query. Returns: - # The number of peers which have recently seen a handshake. + The number of peers which have recently seen a handshake. 
+ + Raises: + NetlinkDumpInterrupted if the interface data has changed while it was being returned by netlink """ three_mins_ago_in_secs = int((datetime.now() - timedelta(minutes=3)).timestamp()) logger.info("Counting connected wireguard peers for interface %s.", wg_interface) with pyroute2.WireGuard() as wg: - msgs = wg.info(wg_interface) + try: + msgs = wg.info(wg_interface) + except pyroute2.netlink.exceptions.NetlinkDumpInterrupted: + # Normal behaviour, data has changed while it was being returned by netlink. + # Retry once, don't catch the exception this time, and let the caller handle it. + # See https://github.com/svinota/pyroute2/issues/874 + msgs = wg.info(wg_interface) + logger.debug("Got infos for connected peers: %s.", msgs) count = 0 for msg in msgs: diff --git a/wgkex/worker/netlink_test.py b/wgkex/worker/netlink_test.py index 86673a7..177fb01 100644 --- a/wgkex/worker/netlink_test.py +++ b/wgkex/worker/netlink_test.py @@ -9,10 +9,11 @@ # any testing platform can execute tests. 
import sys -sys.modules["pyroute2"] = mock.MagicMock() -sys.modules["pyroute2.WireGuard"] = mock.MagicMock() -sys.modules["pyroute2.IPRoute"] = mock.MagicMock() -sys.modules["pyroute2.NDB"] = mock.MagicMock() +import pyroute2.netlink.exceptions as pyroute2_netlink_exceptions + +pyroute2_module_mock = mock.MagicMock() +pyroute2_module_mock.netlink.exceptions = pyroute2_netlink_exceptions +sys.modules["pyroute2"] = pyroute2_module_mock sys.modules["pyroute2.netlink"] = mock.MagicMock() from pyroute2 import WireGuard from pyroute2 import IPRoute @@ -229,6 +230,26 @@ def msg_get_attr(attr: str): ret = netlink.get_connected_peers_count("wg-welt") self.assertEqual(ret, 3) + @mock.patch("pyroute2.WireGuard") + def test_get_connected_peers_count_NetlinkDumpInterrupted(self, nl_wg_mock): + """Tests that NetlinkDumpInterrupted is re-raised after exactly one retry.""" + + nl_wg_mock_ctx = mock.MagicMock() + wg_info_mock = mock.MagicMock( + side_effect=(pyroute2_netlink_exceptions.NetlinkDumpInterrupted), + ) + nl_wg_mock_ctx.info = wg_info_mock + + nl_wg_mock_inst = nl_wg_mock.return_value + nl_wg_mock_inst.__enter__ = mock.MagicMock(return_value=nl_wg_mock_ctx) + + self.assertRaises( + pyroute2_netlink_exceptions.NetlinkDumpInterrupted, + netlink.get_connected_peers_count, + "wg-welt", + ) + self.assertTrue(len(wg_info_mock.mock_calls) == 2) + def test_get_device_data_success(self): def msg_get_attr(attr: str): if attr == "WGDEVICE_A_LISTEN_PORT":