Skip to content

Commit

Permalink
Handle NetlinkDumpInterrupted, fix worker metrics going stale after e…
Browse files Browse the repository at this point in the history
…xceptions
  • Loading branch information
DasSkelett committed Apr 2, 2024
1 parent d06f8ea commit 4efe903
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 14 deletions.
25 changes: 19 additions & 6 deletions wgkex/worker/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Optional

import paho.mqtt.client as mqtt
import pyroute2.netlink.exceptions

from wgkex.common import logger
from wgkex.common.mqtt import (
Expand All @@ -20,9 +21,7 @@
from wgkex.worker.msg_queue import q
from wgkex.worker.netlink import (
get_device_data,
link_handler,
get_connected_peers_count,
WireGuardClient,
)

_HOSTNAME = socket.gethostname()
Expand Down Expand Up @@ -206,9 +205,15 @@ def publish_metrics_loop(
topic = TOPIC_CONNECTED_PEERS.format(domain=domain, worker=_HOSTNAME)

while not exit_event.is_set():
publish_metrics(client, topic, domain)
# This drifts slightly over time, doesn't matter for us
exit_event.wait(_METRICS_SEND_INTERVAL)
try:
publish_metrics(client, topic, domain)
except Exception as e:
# Don't crash the thread when an exception is encountered
logger.error(f"Exception during publish metrics task for {domain}:")
logger.error(e)
finally:
# This drifts slightly over time, doesn't matter for us
exit_event.wait(_METRICS_SEND_INTERVAL)

# Set peers metric to -1 to mark worker as offline
# Use QoS 1 (at least once) to make sure the broker notices
Expand All @@ -227,7 +232,15 @@ def publish_metrics(client: mqtt.Client, topic: str, domain: str) -> None:
f"Could not get interface name for domain {domain}. Skipping metrics publication"
)
return
peer_count = get_connected_peers_count(iface)

try:
peer_count = get_connected_peers_count(iface)
except pyroute2.netlink.exceptions.NetlinkDumpInterrupted:
# Handle gracefully, don't update metrics
logger.info(
"Caught NetlinkDumpInterrupted exception while collecting metrics for domain {domain}"
)
return

# Publish metrics, retain it at MQTT broker so restarted wgkex broker has metrics right away
client.publish(topic, peer_count, retain=True)
Expand Down
56 changes: 56 additions & 0 deletions wgkex/worker/mqtt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import mock
import paho.mqtt.client
import pyroute2.netlink.exceptions

from wgkex.common.mqtt import TOPIC_CONNECTED_PEERS
from wgkex.worker import mqtt
Expand Down Expand Up @@ -128,6 +129,61 @@ def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock):

self.assertFalse(thread.is_alive())

@mock.patch.object(mqtt, "_METRICS_SEND_INTERVAL", 0.02)
@mock.patch.object(mqtt, "get_config")
@mock.patch.object(mqtt, "get_connected_peers_count")
def test_publish_metrics_loop_no_exception(self, conn_peers_mock, config_mock):
"""Tests that an exception doesn't interrupt the loop"""
config_mock.return_value = _get_config_mock()
conn_peers_mock.side_effect = Exception("Mocked exception")
mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client)

ee = threading.Event()
thread = threading.Thread(
target=mqtt.publish_metrics_loop,
args=(ee, mqtt_client, "_ffmuc_domain.one"),
)
thread.start()

i = 0
while i < 20 and not len(conn_peers_mock.mock_calls) >= 2:
i += 1
sleep(0.1)

self.assertTrue(
len(conn_peers_mock.mock_calls) >= 2,
"get_connected_peers_count must be called at least twice",
)

mqtt_client.publish.assert_not_called()

ee.set()

i = 0
while i < 20 and thread.is_alive():
i += 1
sleep(0.1)

self.assertFalse(thread.is_alive())

@mock.patch.object(mqtt, "get_config")
@mock.patch.object(mqtt, "get_connected_peers_count")
def test_publish_metrics_NetlinkDumpInterrupted(self, conn_peers_mock, config_mock):
config_mock.return_value = _get_config_mock()
conn_peers_mock.side_effect = (
pyroute2.netlink.exceptions.NetlinkDumpInterrupted()
)
mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client)

domain = mqtt.get_config().domains[0]
hostname = socket.gethostname()
topic = TOPIC_CONNECTED_PEERS.format(domain=domain, worker=hostname)

# Must not raise NetlinkDumpInterrupted, but handle gracefully by doing nothing
mqtt.publish_metrics(mqtt_client, topic, domain)

mqtt_client.publish.assert_not_called()

@mock.patch.object(mqtt, "get_config")
def test_on_message_wireguard_success(self, config_mock):
# Tests on_message for success.
Expand Down
17 changes: 13 additions & 4 deletions wgkex/worker/netlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from textwrap import wrap
from typing import Any, Dict, List, Tuple

import pyroute2
import pyroute2.netlink
import pyroute2, pyroute2.netlink, pyroute2.netlink.exceptions

from wgkex.common.utils import mac2eui64
from wgkex.common import logger
Expand Down Expand Up @@ -218,12 +217,22 @@ def get_connected_peers_count(wg_interface: str) -> int:
wg_interface: The WireGuard interface to query.
Returns:
# The number of peers which have recently seen a handshake.
The number of peers which have recently seen a handshake.
Raises:
NetlinkDumpInterrupted if the interface data has changed while it was being returned by netlink
"""
three_mins_ago_in_secs = int((datetime.now() - timedelta(minutes=3)).timestamp())
logger.info("Counting connected wireguard peers for interface %s.", wg_interface)
with pyroute2.WireGuard() as wg:
msgs = wg.info(wg_interface)
try:
msgs = wg.info(wg_interface)
except pyroute2.netlink.exceptions.NetlinkDumpInterrupted:
# Normal behaviour, data has changed while it was being returned by netlink.
# Retry once, don't catch the exception this time, and let the caller handle it.
# See https://github.com/svinota/pyroute2/issues/874
msgs = wg.info(wg_interface)

logger.debug("Got infos for connected peers: %s.", msgs)
count = 0
for msg in msgs:
Expand Down
29 changes: 25 additions & 4 deletions wgkex/worker/netlink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
# any testing platform can execute tests.
import sys

sys.modules["pyroute2"] = mock.MagicMock()
sys.modules["pyroute2.WireGuard"] = mock.MagicMock()
sys.modules["pyroute2.IPRoute"] = mock.MagicMock()
sys.modules["pyroute2.NDB"] = mock.MagicMock()
import pyroute2.netlink.exceptions as pyroute2_netlink_exceptions

pyroute2_module_mock = mock.MagicMock()
pyroute2_module_mock.netlink.exceptions = pyroute2_netlink_exceptions
sys.modules["pyroute2"] = pyroute2_module_mock
sys.modules["pyroute2.netlink"] = mock.MagicMock()
from pyroute2 import WireGuard
from pyroute2 import IPRoute
Expand Down Expand Up @@ -229,6 +230,26 @@ def msg_get_attr(attr: str):
ret = netlink.get_connected_peers_count("wg-welt")
self.assertEqual(ret, 3)

@mock.patch("pyroute2.WireGuard")
def test_get_connected_peers_count_NetlinkDumpInterrupted(self, nl_wg_mock):
"""Tests getting the correct number of connected peers for an interface."""

nl_wg_mock_ctx = mock.MagicMock()
wg_info_mock = mock.MagicMock(
side_effect=(pyroute2_netlink_exceptions.NetlinkDumpInterrupted),
)
nl_wg_mock_ctx.info = wg_info_mock

nl_wg_mock_inst = nl_wg_mock.return_value
nl_wg_mock_inst.__enter__ = mock.MagicMock(return_value=nl_wg_mock_ctx)

self.assertRaises(
pyroute2_netlink_exceptions.NetlinkDumpInterrupted,
netlink.get_connected_peers_count,
"wg-welt",
)
self.assertTrue(len(wg_info_mock.mock_calls) == 2)

def test_get_device_data_success(self):
def msg_get_attr(attr: str):
if attr == "WGDEVICE_A_LISTEN_PORT":
Expand Down

0 comments on commit 4efe903

Please sign in to comment.