diff --git a/wgkex/broker/app.py b/wgkex/broker/app.py index 1d753ff..e8122cc 100644 --- a/wgkex/broker/app.py +++ b/wgkex/broker/app.py @@ -131,8 +131,15 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]: } }, 400 + # Update number of peers locally to interpolate data between MQTT updates from the worker + # TODO fix data race + current_peers_domain = ( + worker_metrics.get(best_worker) + .get_domain_metrics(domain) + .get(CONNECTED_PEERS_METRIC, 0) + ) worker_metrics.update( - best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1 + best_worker, domain, CONNECTED_PEERS_METRIC, current_peers_domain + 1 ) logger.debug( f"Chose worker {best_worker} with {current_peers} connected clients ({diff})" @@ -200,10 +207,10 @@ def handle_mqtt_message_status( _, worker, _ = message.topic.split("/", 2) status = int(message.payload) - if status < 1: + if status < 1 and worker_metrics.get(worker).is_online(): logger.warning(f"Marking worker as offline: {worker}") worker_metrics.set_offline(worker) - else: + elif status >= 1 and not worker_metrics.get(worker).is_online(): logger.warning(f"Marking worker as online: {worker}") worker_metrics.set_online(worker) diff --git a/wgkex/broker/metrics.py b/wgkex/broker/metrics.py index a2e2893..7ca52d1 100644 --- a/wgkex/broker/metrics.py +++ b/wgkex/broker/metrics.py @@ -34,10 +34,25 @@ def set_metric(self, domain: str, metric: str, value: Any) -> None: else: self.domain_data[domain] = {metric: value} + def get_peer_count(self) -> int: + """ Returns the sum of connected peers on this worker over all domains + """ + total = 0 + for data in self.domain_data.values(): + total += max( + data.get( + CONNECTED_PEERS_METRIC, 0 + ), + 0, + ) + + return total @dataclasses.dataclass class WorkerMetricsCollection: - """A container for all worker metrics""" + """A container for all worker metrics + # TODO make threadsafe / fix data races + """ # worker -> WorkerMetrics data: Dict[str, WorkerMetrics] = dataclasses.field(default_factory=dict) @@ -68,7 +83,9 @@ def set_offline(self, worker: str) -> None: if worker in self.data: self.data[worker].online = False - def get_total_peers(self) -> int: + def get_total_peer_count(self) -> int: + """ Returns the sum of connected peers over all workers and domains + """ total = 0 for worker in self.data: worker_data = self.data.get(worker) @@ -96,22 +113,23 @@ def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]: # Map metrics to a list of (target diff, peer count, worker) tuples for online workers peers_worker_tuples = [] - total_peers = self.get_total_peers() + total_peers = self.get_total_peer_count() worker_cfg = config.get_config().workers for wm in self.data.values(): if not wm.is_online(domain): continue - peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC) + peers = wm.get_peer_count() rel_weight = worker_cfg.relative_worker_weight(wm.worker) target = rel_weight * total_peers diff = peers - target logger.debug( - f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}" + f"Best worker candidate {wm.worker}: current {peers}, target {target} (total {total_peers}, rel weight {rel_weight}), diff {diff}" ) peers_worker_tuples.append((diff, peers, wm.worker)) + # Sort by diff (ascending), workers with most peers missing to target are sorted first peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0)) if len(peers_worker_tuples) > 0: diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index caf7011..d5941cd 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -127,6 +127,7 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: logger.info(f"Subscribing to topic {topic}") client.subscribe(topic) + for domain in domains: # Publish worker data (WG pubkeys, ports, local addresses) iface = wg_interface_name(domain) if iface: diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index 366d430..bb413f1 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -231,11 +231,10 @@ def get_connected_peers_count(wg_interface: str) -> int: if peers: for peer in peers: if ( - peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get( - "tv_sec", int() - ) - > three_mins_ago_in_secs - ): + hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME") + ) is not None and hshk_time.get( + "tv_sec", int() + ) > three_mins_ago_in_secs: count += 1 return count @@ -251,7 +250,7 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]: # The listening port, public key, and local IP address of the WireGuard interface. """ logger.info("Reading data from interface %s.", wg_interface) - with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb: + with pyroute2.WireGuard() as wg, pyroute2.IPRoute() as ipr: msgs = wg.info(wg_interface) logger.debug("Got infos for interface data: %s.", msgs) if len(msgs) > 1: @@ -262,7 +261,11 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]: port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT")) public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii") - link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address") + + # Get link address using IPRoute + ipr_link = ipr.link_lookup(ifname=wg_interface)[0] + msgs = ipr.get_addr(index=ipr_link) + link_address = msgs[0].get_attr("IFA_ADDRESS") logger.debug( "Interface data: port '%s', public key '%s', link address '%s",