From b6be1aaacfcd420feb2c2a184f18b52b931abce6 Mon Sep 17 00:00:00 2001 From: Eloff Date: Tue, 4 Feb 2025 18:38:13 -0700 Subject: [PATCH 01/11] add on_subscribe and on_unsubscribe to ServerListener --- Cargo.lock | 1 + rust/foxglove/Cargo.toml | 1 + rust/foxglove/src/websocket.rs | 123 ++++++++++++++++++++++----------- 3 files changed, 85 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 927dcdf4..f51909af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -491,6 +491,7 @@ dependencies = [ "arc-swap", "assert_matches", "base64", + "bimap", "bytes", "clap", "env_logger", diff --git a/rust/foxglove/Cargo.toml b/rust/foxglove/Cargo.toml index 04469185..60fc0abc 100644 --- a/rust/foxglove/Cargo.toml +++ b/rust/foxglove/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT" unstable = [] [dependencies] +bimap = "0.6.3" schemars = "0.8.21" arc-swap = "1.7.1" base64 = "0.22.1" diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index a10f8cfa..e7955a8c 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -1,5 +1,6 @@ use crate::channel::ChannelId; use crate::cow_vec::CowVec; +use crate::websocket::protocol::client::Subscription; pub use crate::websocket::protocol::client::{ ClientChannel, ClientChannelId, ClientMessage, SubscriptionId, }; @@ -7,6 +8,7 @@ pub use crate::websocket::protocol::server::Capability; #[cfg(feature = "unstable")] pub use crate::websocket::protocol::server::{Parameter, ParameterType, ParameterValue}; use crate::{Channel, FoxgloveError, LogSink, Metadata}; +use bimap::BiHashMap; use bytes::{BufMut, BytesMut}; use flume::TrySendError; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; @@ -132,6 +134,11 @@ impl Default for Server { /// handling client message events. pub trait ServerListener: Send + Sync { fn on_message_data(&self, channel_id: ClientChannelId, payload: &[u8]); + fn on_subscribe(&self, channel_id: ChannelId); + fn on_unsubscribe(&self, channel_id: ChannelId); + // Whole Channel? + fn on_client_advertise(&self, channel_id: ClientChannelId); + fn on_client_unadvertise(&self, channel_id: ClientChannelId); } /// State for a client maintained by the server @@ -144,7 +151,7 @@ struct ConnectedClient { control_plane_tx: flume::Sender, control_plane_rx: flume::Receiver, /// Subscriptions from this client - subscriptions_by_channel: parking_lot::Mutex>, + subscriptions: parking_lot::Mutex>, /// Channels advertised by this client advertised_channels: parking_lot::Mutex>, /// Optional callback handler for a server implementation @@ -170,45 +177,26 @@ impl ConnectedClient { self.send_error(format!("Invalid message: {message}")); return; }; + let Some(server) = self.server.upgrade() else { + tracing::error!("Server closed"); + return; + }; match serde_json::from_str::(message) { Ok(ClientMessage::Subscribe { subscriptions }) => { - let Some(server) = self.server.upgrade() else { - tracing::error!("Server closed"); - return; - }; - let channels = server.channels.read(); - for subscription in subscriptions { - if !channels.contains_key(&subscription.channel_id) { - tracing::error!( - "Client {} attempted to subscribe to unknown channel: {}", - self.addr, - subscription.channel_id - ); - self.send_error(format!("Unknown channel ID: {}", subscription.channel_id)); - continue; - } - - let mut subscriptions = self.subscriptions_by_channel.lock(); - subscriptions.insert(subscription.channel_id, subscription.id); - tracing::info!( - "Client {} subscribed to channel {} with subscription id {}", - self.addr, - subscription.channel_id, - subscription.id - ); - } + self.on_subscribe(server, subscriptions); } Ok(ClientMessage::Unsubscribe { subscription_ids }) => { - let mut subscriptions = self.subscriptions_by_channel.lock(); - subscriptions - .retain(|_, subscription_id| !subscription_ids.contains(subscription_id)) + let mut subscriptions = self.subscriptions.lock(); + for subscription_id in subscription_ids { + if let Some((channel_id, _)) = subscriptions.remove_by_right(&subscription_id) { + if let Some(handler) = self.server_listener.as_ref() { + handler.on_unsubscribe(channel_id); + } + } + } } Ok(ClientMessage::Advertise { channels }) => { - let Some(server) = self.server.upgrade() else { - tracing::info!("Server closed; ignoring client advertisement"); - return; - }; if !server.capabilities.contains(&Capability::ClientPublish) { self.send_error("Server does not support clientPublish capability".to_string()); return; @@ -232,13 +220,68 @@ impl ConnectedClient { } } + fn on_subscribe(&self, server: Arc, subscriptions: Vec) { + let channels = server.channels.read(); + for subscription in subscriptions { + if !channels.contains_key(&subscription.channel_id) { + tracing::error!( + "Client {} attempted to subscribe to unknown channel: {}", + self.addr, + subscription.channel_id + ); + self.send_error(format!("Unknown channel ID: {}", subscription.channel_id)); + continue; + } + + let mut subscriptions = self.subscriptions.lock(); + if subscriptions + .insert_no_overwrite(subscription.channel_id, subscription.id) + .is_err() + { + if subscriptions.contains_left(&subscription.channel_id) { + self.send_warning(format!( + "Client is already subscribed to channel: {}; ignoring subscription", + subscription.channel_id + )); + } else { + assert!(subscriptions.contains_right(&subscription.id)); + self.send_error(format!( + "Subscription ID was already used: {}; ignoring subscription", + subscription.id + )); + } + continue; + } + + tracing::info!( + "Client {} subscribed to channel {} with subscription id {}", + self.addr, + subscription.channel_id, + subscription.id + ); + if let Some(handler) = self.server_listener.as_ref() { + handler.on_subscribe(subscription.channel_id); + } + } + } + /// Send an ad hoc error status message to the client, with the given message. fn send_error(&self, message: String) { - let status = protocol::server::Status { - level: protocol::server::StatusLevel::Error, - message: message.to_string(), - id: None, - }; + self.send_status(protocol::server::StatusLevel::Error, message, None); + } + + /// Send an ad hoc warning status message to the client, with the given message. + fn send_warning(&self, message: String) { + self.send_status(protocol::server::StatusLevel::Warning, message, None); + } + + fn send_status( + &self, + level: protocol::server::StatusLevel, + message: String, + id: Option, + ) { + let status = protocol::server::Status { level, message, id }; let message = Message::text(serde_json::to_string(&status).unwrap()); // If the message can't be sent, or the outbox is full, log a warning and continue. self.data_plane_tx.try_send(message).unwrap_or_else(|err| { @@ -508,7 +551,7 @@ impl Server { data_plane_rx: data_rx, control_plane_tx: ctrl_tx, control_plane_rx: ctrl_rx, - subscriptions_by_channel: parking_lot::Mutex::new(HashMap::new()), + subscriptions: parking_lot::Mutex::new(BiHashMap::new()), advertised_channels: parking_lot::Mutex::new(HashMap::new()), server_listener: self.listener.clone(), server: self.weak_self.clone(), @@ -673,7 +716,7 @@ impl LogSink for Server { ) -> Result<(), FoxgloveError> { let clients = self.clients.get(); for client in clients.iter() { - let subscriptions = client.subscriptions_by_channel.lock(); + let subscriptions = client.subscriptions.lock(); let Some(subscription_id) = subscriptions.get(&channel.id).cloned() else { continue; }; From bdd83e392f2dac0a9fa260932efac449cbffa163 Mon Sep 17 00:00:00 2001 From: Eloff Date: Wed, 5 Feb 2025 12:40:22 -0700 Subject: [PATCH 02/11] implement on_client_advertise and on_client_unadvertise --- rust/foxglove/src/websocket.rs | 39 +++++++++++++++---- .../foxglove/src/websocket/protocol/client.rs | 2 +- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index e7955a8c..dba498fb 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -12,6 +12,7 @@ use bimap::BiHashMap; use bytes::{BufMut, BytesMut}; use flume::TrySendError; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; +use std::collections::hash_map::Entry; use std::collections::HashSet; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::{AcqRel, Acquire}; @@ -136,8 +137,7 @@ pub trait ServerListener: Send + Sync { fn on_message_data(&self, channel_id: ClientChannelId, payload: &[u8]); fn on_subscribe(&self, channel_id: ChannelId); fn on_unsubscribe(&self, channel_id: ChannelId); - // Whole Channel? - fn on_client_advertise(&self, channel_id: ClientChannelId); + fn on_client_advertise(&self, channel: &ClientChannel); fn on_client_unadvertise(&self, channel_id: ClientChannelId); } @@ -202,15 +202,40 @@ impl ConnectedClient { return; } - let mut advertised_channels = self.advertised_channels.lock(); for channel in channels { - advertised_channels.insert(channel.id, channel); + // Using a limited scope here to avoid holding the lock on advertised_channels while calling on_client_advertise + { + match self.advertised_channels.lock().entry(channel.id) { + Entry::Occupied(_) => { + self.send_warning(format!( + "Client is already advertising channel: {}; ignoring advertisement", + channel.id + )); + continue; + } + Entry::Vacant(entry) => { + entry.insert(channel.clone()); + } + } + } + + if let Some(handler) = self.server_listener.as_ref() { + handler.on_client_advertise(&channel); + } } } Ok(ClientMessage::Unadvertise { channel_ids }) => { - let mut advertised_channels = self.advertised_channels.lock(); - for id in channel_ids { - advertised_channels.remove(&id); + // Using a limited scope and iterating twice to avoid holding the lock on advertised_channels while calling on_client_unadvertise + { + let mut advertised_channels = self.advertised_channels.lock(); + for id in channel_ids.iter() { + advertised_channels.remove(&id); + } + } + if let Some(handler) = self.server_listener.as_ref() { + for id in channel_ids { + handler.on_client_unadvertise(id); + } } } _ => { diff --git a/rust/foxglove/src/websocket/protocol/client.rs b/rust/foxglove/src/websocket/protocol/client.rs index 690f0e28..cfde3412 100644 --- a/rust/foxglove/src/websocket/protocol/client.rs +++ b/rust/foxglove/src/websocket/protocol/client.rs @@ -83,7 +83,7 @@ pub struct Subscription { } #[doc(hidden)] -#[derive(Debug, Deserialize, PartialEq, Serialize)] +#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)] #[serde(rename_all = "camelCase")] pub struct ClientChannel { pub id: ClientChannelId, From 338b6c2ebdebd5a21969f33eb33fc23ecc127389 Mon Sep 17 00:00:00 2001 From: Eloff Date: Wed, 5 Feb 2025 12:49:57 -0700 Subject: [PATCH 03/11] don't call on_unadvertise if channel wasn't advertised --- rust/foxglove/src/websocket.rs | 89 ++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 37 deletions(-) diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index dba498fb..e1a39898 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -197,50 +197,65 @@ impl ConnectedClient { } } Ok(ClientMessage::Advertise { channels }) => { - if !server.capabilities.contains(&Capability::ClientPublish) { - self.send_error("Server does not support clientPublish capability".to_string()); - return; - } - - for channel in channels { - // Using a limited scope here to avoid holding the lock on advertised_channels while calling on_client_advertise - { - match self.advertised_channels.lock().entry(channel.id) { - Entry::Occupied(_) => { - self.send_warning(format!( - "Client is already advertising channel: {}; ignoring advertisement", - channel.id - )); - continue; - } - Entry::Vacant(entry) => { - entry.insert(channel.clone()); - } - } - } + self.on_advertise(server, channels); + } + Ok(ClientMessage::Unadvertise { channel_ids }) => { + self.on_unadvertise(channel_ids); + } + _ => { + tracing::error!("Unsupported message from {}: {message}", self.addr); + self.send_error(format!("Unsupported message: {message}")); + } + } + } - if let Some(handler) = self.server_listener.as_ref() { - handler.on_client_advertise(&channel); - } + fn on_unadvertise(&self, channel_ids: Vec) { + // Using a limited scope and iterating twice to avoid holding the lock on advertised_channels while calling on_client_unadvertise + let mut channels_not_found = Vec::new(); + { + let mut advertised_channels = self.advertised_channels.lock(); + for id in channel_ids.iter() { + if advertised_channels.remove(&id).is_none() { + channels_not_found.push(id); + self.send_warning(format!( + "Client is not advertising channel: {}; ignoring unadvertisement", + id + )); } } - Ok(ClientMessage::Unadvertise { channel_ids }) => { - // Using a limited scope and iterating twice to avoid holding the lock on advertised_channels while calling on_client_unadvertise - { - let mut advertised_channels = self.advertised_channels.lock(); - for id in channel_ids.iter() { - advertised_channels.remove(&id); + } + if let Some(handler) = self.server_listener.as_ref() { + for id in channel_ids.filter(|id| !channels_not_found.contains(id)) { + handler.on_client_unadvertise(id); + } + } + } + + fn on_advertise(&self, server: Arc, channels: Vec) { + if !server.capabilities.contains(&Capability::ClientPublish) { + self.send_error("Server does not support clientPublish capability".to_string()); + return; + } + + for channel in channels { + // Using a limited scope here to avoid holding the lock on advertised_channels while calling on_client_advertise + { + match self.advertised_channels.lock().entry(channel.id) { + Entry::Occupied(_) => { + self.send_warning(format!( + "Client is already advertising channel: {}; ignoring advertisement", + channel.id + )); + continue; } - } - if let Some(handler) = self.server_listener.as_ref() { - for id in channel_ids { - handler.on_client_unadvertise(id); + Entry::Vacant(entry) => { + entry.insert(channel.clone()); } } } - _ => { - tracing::error!("Unsupported message from {}: {message}", self.addr); - self.send_error(format!("Unsupported message: {message}")); + + if let Some(handler) = self.server_listener.as_ref() { + handler.on_client_advertise(&channel); } } } From 4cc637c93b9f44404b43e53fd4f806567bb9b876 Mon Sep 17 00:00:00 2001 From: Eloff Date: Wed, 5 Feb 2025 14:38:57 -0700 Subject: [PATCH 04/11] update test_client_advertising to check the listener callbacks were invoked --- rust/foxglove/src/tests/websocket_unstable.rs | 72 ++++++++--------- rust/foxglove/src/testutil.rs | 77 +++++++++++++++++++ rust/foxglove/src/websocket.rs | 26 ++++--- rust/foxglove/src/websocket/tests.rs | 3 +- 4 files changed, 125 insertions(+), 53 deletions(-) diff --git a/rust/foxglove/src/tests/websocket_unstable.rs b/rust/foxglove/src/tests/websocket_unstable.rs index 378194a9..b24d7cfc 100644 --- a/rust/foxglove/src/tests/websocket_unstable.rs +++ b/rust/foxglove/src/tests/websocket_unstable.rs @@ -1,52 +1,24 @@ use std::{collections::HashSet, sync::Arc}; +use crate::testutil::RecordingServerListener; +use crate::websocket::{ + create_server, Capability, ClientChannelId, Parameter, ParameterType, ParameterValue, + ServerOptions, SUBPROTOCOL, +}; use bytes::{BufMut, BytesMut}; use futures_util::{SinkExt, StreamExt}; use serde_json::{json, Value}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::time::sleep; use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue, Message}; -use crate::websocket::{ - create_server, Capability, ClientChannelId, Parameter, ParameterType, ParameterValue, - ServerListener, ServerOptions, SUBPROTOCOL, -}; - -struct ClientMessageData { - channel_id: ClientChannelId, - payload: Vec, -} - -struct MPSCServerListener(UnboundedSender); - -impl MPSCServerListener { - fn create() -> ( - Arc, - UnboundedReceiver, - ) { - let (tx, rx) = unbounded_channel(); - (Arc::new(Self(tx)), rx) - } -} - -impl ServerListener for MPSCServerListener { - fn on_message_data(&self, channel_id: ClientChannelId, message: &[u8]) { - self.0 - .send(ClientMessageData { - channel_id, - payload: message.to_vec(), - }) - .expect("MPSC queue closed"); - } -} - #[tokio::test] async fn test_client_advertising() { - let (listener, mut chan_rx) = MPSCServerListener::create(); + let recording_listener = Arc::new(RecordingServerListener::new()); let server = create_server(ServerOptions { capabilities: Some(HashSet::from([Capability::ClientPublish])), supported_encodings: Some(HashSet::from(["json".to_string()])), - listener: Some(listener), + listener: Some(recording_listener.clone()), session_id: None, name: None, message_backlog_size: None, @@ -74,7 +46,7 @@ async fn test_client_advertising() { .await .expect("Failed to send binary message"); // No message sent to listener - assert!(chan_rx.try_recv().is_err()); + assert!(recording_listener.message_data().is_empty()); let advertise = json!({ "op": "advertise", @@ -122,10 +94,30 @@ async fn test_client_advertising() { .await .expect("Failed to send unadvertise"); + // Should be ignored + ws_client + .send(Message::text(unadvertise.to_string())) + .await + .expect("Failed to send unadvertise"); + + // Give the server time to process the messages + sleep(std::time::Duration::from_millis(10)).await; + // Server should have received one message - let received = chan_rx.recv().await.expect("No message received"); - assert_eq!(received.channel_id, ClientChannelId::new(1)); - assert_eq!(received.payload, b"{\"a\":1}"); + let mut received = recording_listener.message_data(); + let (channel_id, payload) = received.pop().expect("No message received"); + assert_eq!(channel_id, ClientChannelId::new(1)); + assert_eq!(payload, b"{\"a\":1}"); + + // Server should have ignored the duplicate advertisement + let advertisements = recording_listener.client_advertise(); + assert_eq!(advertisements.len(), 1); + assert_eq!(advertisements[0].id, channel_id); + + // Server should have received one unadvertise (and ignored the duplicate) + let unadvertises = recording_listener.client_unadvertise(); + assert_eq!(unadvertises.len(), 1); + assert_eq!(unadvertises[0], channel_id); ws_client.close(None).await.unwrap(); server.stop().await; diff --git a/rust/foxglove/src/testutil.rs b/rust/foxglove/src/testutil.rs index 0d21bb56..1d6d2e20 100644 --- a/rust/foxglove/src/testutil.rs +++ b/rust/foxglove/src/testutil.rs @@ -3,5 +3,82 @@ mod log_context; mod log_sink; +use crate::channel::ChannelId; +use crate::websocket::{ClientChannel, ClientChannelId, ServerListener}; pub use log_context::GlobalContextTest; pub use log_sink::{ErrorSink, MockSink, RecordingSink}; +use parking_lot::Mutex; + +#[allow(dead_code)] +pub(crate) struct RecordingServerListener { + message_data: Mutex)>>, + subscribe: Mutex>, + unsubscribe: Mutex>, + client_advertise: Mutex>, + client_unadvertise: Mutex>, +} + +impl RecordingServerListener { + #[allow(dead_code)] + pub fn new() -> Self { + Self { + message_data: Mutex::new(Vec::new()), + subscribe: Mutex::new(Vec::new()), + unsubscribe: Mutex::new(Vec::new()), + client_advertise: Mutex::new(Vec::new()), + client_unadvertise: Mutex::new(Vec::new()), + } + } + + #[allow(dead_code)] + pub fn message_data(&self) -> Vec<(ClientChannelId, Vec)> { + std::mem::take(&mut self.message_data.lock()) + } + + #[allow(dead_code)] + pub fn subscribe(&self) -> Vec { + std::mem::take(&mut self.subscribe.lock()) + } + + #[allow(dead_code)] + pub fn unsubscribe(&self) -> Vec { + std::mem::take(&mut self.unsubscribe.lock()) + } + + #[allow(dead_code)] + pub fn client_advertise(&self) -> Vec { + std::mem::take(&mut self.client_advertise.lock()) + } + + #[allow(dead_code)] + pub fn client_unadvertise(&self) -> Vec { + std::mem::take(&mut self.client_unadvertise.lock()) + } +} + +impl ServerListener for RecordingServerListener { + fn on_message_data(&self, channel_id: ClientChannelId, payload: &[u8]) { + let mut data = self.message_data.lock(); + data.push((channel_id, payload.to_vec())); + } + + fn on_subscribe(&self, channel_id: ChannelId) { + let mut subs = self.subscribe.lock(); + subs.push(channel_id); + } + + fn on_unsubscribe(&self, channel_id: ChannelId) { + let mut unsubs = self.unsubscribe.lock(); + unsubs.push(channel_id); + } + + fn on_client_advertise(&self, channel: &ClientChannel) { + let mut adverts = self.client_advertise.lock(); + adverts.push(channel.clone()); + } + + fn on_client_unadvertise(&self, channel_id: ClientChannelId) { + let mut unadverts = self.client_unadvertise.lock(); + unadverts.push(channel_id); + } +} diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index e1a39898..934015f9 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -96,7 +96,7 @@ pub(crate) struct Server { weak_self: Weak, started: AtomicBool, message_backlog_size: u32, - pub runtime_handle: Handle, + pub(crate) runtime_handle: Handle, /// May be provided by the caller session_id: String, name: String, @@ -131,14 +131,14 @@ impl Default for Server { } } -/// Provides a mechanism for registering callbacks for -/// handling client message events. +/// Provides a mechanism for registering callbacks for handling client message events. +/// All methods are optional. pub trait ServerListener: Send + Sync { - fn on_message_data(&self, channel_id: ClientChannelId, payload: &[u8]); - fn on_subscribe(&self, channel_id: ChannelId); - fn on_unsubscribe(&self, channel_id: ChannelId); - fn on_client_advertise(&self, channel: &ClientChannel); - fn on_client_unadvertise(&self, channel_id: ClientChannelId); + fn on_message_data(&self, _channel_id: ClientChannelId, _payload: &[u8]) {} + fn on_subscribe(&self, _channel_id: ChannelId) {} + fn on_unsubscribe(&self, _channel_id: ChannelId) {} + fn on_client_advertise(&self, _channel: &ClientChannel) {} + fn on_client_unadvertise(&self, _channel_id: ClientChannelId) {} } /// State for a client maintained by the server @@ -214,7 +214,7 @@ impl ConnectedClient { let mut channels_not_found = Vec::new(); { let mut advertised_channels = self.advertised_channels.lock(); - for id in channel_ids.iter() { + for &id in channel_ids.iter() { if advertised_channels.remove(&id).is_none() { channels_not_found.push(id); self.send_warning(format!( @@ -225,7 +225,11 @@ impl ConnectedClient { } } if let Some(handler) = self.server_listener.as_ref() { - for id in channel_ids.filter(|id| !channels_not_found.contains(id)) { + for id in channel_ids + .iter() + .cloned() + .filter(|id| !channels_not_found.contains(id)) + { handler.on_client_unadvertise(id); } } @@ -757,7 +761,7 @@ impl LogSink for Server { let clients = self.clients.get(); for client in clients.iter() { let subscriptions = client.subscriptions.lock(); - let Some(subscription_id) = subscriptions.get(&channel.id).cloned() else { + let Some(subscription_id) = subscriptions.get_by_left(&channel.id).cloned() else { continue; }; diff --git a/rust/foxglove/src/websocket/tests.rs b/rust/foxglove/src/websocket/tests.rs index 6b2c4f82..a7ad5528 100644 --- a/rust/foxglove/src/websocket/tests.rs +++ b/rust/foxglove/src/websocket/tests.rs @@ -1,8 +1,7 @@ +use super::{send_lossy, SendLossyResult}; use assert_matches::assert_matches; use tokio_tungstenite::tungstenite::Message; -use super::{send_lossy, SendLossyResult}; - fn make_message(id: usize) -> Message { Message::Text(format!("{id}").into()) } From 40f16c37baf2238ac2fffd1c300696f2e6f3afef Mon Sep 17 00:00:00 2001 From: Eloff Date: Wed, 5 Feb 2025 15:38:41 -0700 Subject: [PATCH 05/11] modify tests to check for subscribe and unsubscribe callbacks --- rust/foxglove/src/tests/websocket.rs | 57 +++++++++++++++++-- rust/foxglove/src/tests/websocket_unstable.rs | 12 ++-- rust/foxglove/src/testutil.rs | 10 ++-- rust/foxglove/src/websocket.rs | 12 ++++ 4 files changed, 74 insertions(+), 17 deletions(-) diff --git a/rust/foxglove/src/tests/websocket.rs b/rust/foxglove/src/tests/websocket.rs index e5429b10..37641721 100644 --- a/rust/foxglove/src/tests/websocket.rs +++ b/rust/foxglove/src/tests/websocket.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use tokio_tungstenite::tungstenite::{self, http::HeaderValue, Message}; use tungstenite::client::IntoClientRequest; +use crate::testutil::RecordingServerListener; use crate::websocket::{create_server, ClientMessage, ServerOptions, SubscriptionId, SUBPROTOCOL}; use crate::{collection, Channel, ChannelBuilder, LogContext, LogSink, Metadata, Schema}; @@ -121,7 +122,12 @@ async fn test_handshake_with_multiple_subprotocols() { #[tokio::test] async fn test_advertise_to_client() { - let server = create_server(ServerOptions::default()); + let recording_listener = Arc::new(RecordingServerListener::new()); + + let server = create_server(ServerOptions { + listener: Some(recording_listener.clone()), + ..Default::default() + }); let ctx = LogContext::new(); ctx.add_sink(server.clone()); @@ -162,6 +168,11 @@ async fn test_advertise_to_client() { } ] }); + client_sender + .send(Message::text(subscribe.to_string())) + .await + .expect("Failed to send"); + // Send a duplicate subscribe message (ignored) client_sender .send(Message::text(subscribe.to_string())) .await @@ -169,12 +180,27 @@ async fn test_advertise_to_client() { // Allow the server to process the subscription // FG-9723: replace this with an on_subscribe callback + // (whoops, that won't work either, unless we do something like polling the recording_listener) tokio::time::sleep(std::time::Duration::from_millis(100)).await; - server.log(&ch, b"{\"a\":1}", &metadata).unwrap(); + server.log(&ch, b"foo bar baz", &metadata).unwrap(); - let result = client_receiver.next().await.unwrap(); - let msg = result.expect("Failed to parse message"); + // Ignore the warning message from the duplicate subscribe + let msg = client_receiver + .next() + .await + .unwrap() + .expect("Failed to parse message"); + let data = msg.into_data(); + let data_str = std::str::from_utf8(&data).unwrap(); + println!("data: {:?}", data_str); + assert!(data_str.contains("Client is already subscribed to channel")); + + let msg = client_receiver + .next() + .await + .unwrap() + .expect("Failed to parse message"); let data = msg.into_data(); assert_eq!(data[0], 0x01); // message data opcode @@ -183,12 +209,21 @@ async fn test_advertise_to_client() { subscription_id ); + let subscriptions = recording_listener.take_subscribe(); + assert_eq!(subscriptions.len(), 1); + assert_eq!(subscriptions[0], ch.id()); + server.stop().await; } #[tokio::test] async fn test_log_only_to_subscribers() { - let server = create_server(ServerOptions::default()); + let recording_listener = Arc::new(RecordingServerListener::new()); + + let server = create_server(ServerOptions { + listener: Some(recording_listener.clone()), + ..Default::default() + }); let ctx = LogContext::new(); @@ -280,6 +315,18 @@ async fn test_log_only_to_subscribers() { // FG-9723: replace this with an on_subscribe callback tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let subscriptions = recording_listener.take_subscribe(); + assert_eq!(subscriptions.len(), 4); + assert_eq!(subscriptions[0], ch1.id()); + assert_eq!(subscriptions[1], ch2.id()); + assert_eq!(subscriptions[2], ch1.id()); + assert_eq!(subscriptions[3], ch2.id()); + + let unsubscriptions = recording_listener.take_unsubscribe(); + assert_eq!(unsubscriptions.len(), 2); + assert_eq!(unsubscriptions[0], ch1.id()); + assert_eq!(unsubscriptions[1], ch2.id()); + let metadata = Metadata { log_time: 123456, ..Metadata::default() diff --git a/rust/foxglove/src/tests/websocket_unstable.rs b/rust/foxglove/src/tests/websocket_unstable.rs index b24d7cfc..3abd9352 100644 --- a/rust/foxglove/src/tests/websocket_unstable.rs +++ b/rust/foxglove/src/tests/websocket_unstable.rs @@ -19,9 +19,7 @@ async fn test_client_advertising() { capabilities: Some(HashSet::from([Capability::ClientPublish])), supported_encodings: Some(HashSet::from(["json".to_string()])), listener: Some(recording_listener.clone()), - session_id: None, - name: None, - message_backlog_size: None, + ..Default::default() }); let addr = server @@ -46,7 +44,7 @@ async fn test_client_advertising() { .await .expect("Failed to send binary message"); // No message sent to listener - assert!(recording_listener.message_data().is_empty()); + assert!(recording_listener.take_message_data().is_empty()); let advertise = json!({ "op": "advertise", @@ -104,18 +102,18 @@ async fn test_client_advertising() { sleep(std::time::Duration::from_millis(10)).await; // Server should have received one message - let mut received = recording_listener.message_data(); + let mut received = recording_listener.take_message_data(); let (channel_id, payload) = received.pop().expect("No message received"); assert_eq!(channel_id, ClientChannelId::new(1)); assert_eq!(payload, b"{\"a\":1}"); // Server should have ignored the duplicate advertisement - let advertisements = recording_listener.client_advertise(); + let advertisements = recording_listener.take_client_advertise(); assert_eq!(advertisements.len(), 1); assert_eq!(advertisements[0].id, channel_id); // Server should have received one unadvertise (and ignored the duplicate) - let unadvertises = recording_listener.client_unadvertise(); + let unadvertises = recording_listener.take_client_unadvertise(); assert_eq!(unadvertises.len(), 1); assert_eq!(unadvertises[0], channel_id); diff --git a/rust/foxglove/src/testutil.rs b/rust/foxglove/src/testutil.rs index 1d6d2e20..de09b4b7 100644 --- a/rust/foxglove/src/testutil.rs +++ b/rust/foxglove/src/testutil.rs @@ -31,27 +31,27 @@ impl RecordingServerListener { } #[allow(dead_code)] - pub fn message_data(&self) -> Vec<(ClientChannelId, Vec)> { + pub fn take_message_data(&self) -> Vec<(ClientChannelId, Vec)> { std::mem::take(&mut self.message_data.lock()) } #[allow(dead_code)] - pub fn subscribe(&self) -> Vec { + pub fn take_subscribe(&self) -> Vec { std::mem::take(&mut self.subscribe.lock()) } #[allow(dead_code)] - pub fn unsubscribe(&self) -> Vec { + pub fn take_unsubscribe(&self) -> Vec { std::mem::take(&mut self.unsubscribe.lock()) } #[allow(dead_code)] - pub fn client_advertise(&self) -> Vec { + pub fn take_client_advertise(&self) -> Vec { std::mem::take(&mut self.client_advertise.lock()) } #[allow(dead_code)] - pub fn client_unadvertise(&self) -> Vec { + pub fn take_client_unadvertise(&self) -> Vec { std::mem::take(&mut self.client_unadvertise.lock()) } } diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index 934015f9..33f648aa 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -277,6 +277,10 @@ impl ConnectedClient { continue; } + println!( + "Client {} subscribed to channel {} with subscription id {}", + self.addr, subscription.channel_id, subscription.id + ); let mut subscriptions = self.subscriptions.lock(); if subscriptions .insert_no_overwrite(subscription.channel_id, subscription.id) @@ -297,6 +301,10 @@ impl ConnectedClient { continue; } + println!( + "Client {} subscribed to channel {} with subscription id {}", + self.addr, subscription.channel_id, subscription.id + ); tracing::info!( "Client {} subscribed to channel {} with subscription id {}", self.addr, @@ -761,10 +769,14 @@ impl LogSink for Server { let clients = self.clients.get(); for client in clients.iter() { let subscriptions = client.subscriptions.lock(); + println!("Subscriptions active: {}", subscriptions.len()); let Some(subscription_id) = subscriptions.get_by_left(&channel.id).cloned() else { + println!("No subscription found for channel {}", channel.id); continue; }; + println!("Sending message to client {}", client.addr); + // https://github.com/foxglove/ws-protocol/blob/main/docs/spec.md#message-data let header_size: usize = 1 + 4 + 8; let mut buf = BytesMut::with_capacity(header_size + msg.len()); From 0bae9c4d132edbc5fcea20993ab6051c4c7ec187 Mon Sep 17 00:00:00 2001 From: Eloff Date: Wed, 5 Feb 2025 15:39:34 -0700 Subject: [PATCH 06/11] remove debug printlns --- rust/foxglove/src/tests/websocket.rs | 1 - rust/foxglove/src/websocket.rs | 12 ------------ 2 files changed, 13 deletions(-) diff --git a/rust/foxglove/src/tests/websocket.rs b/rust/foxglove/src/tests/websocket.rs index 37641721..96371f55 100644 --- a/rust/foxglove/src/tests/websocket.rs +++ b/rust/foxglove/src/tests/websocket.rs @@ -193,7 +193,6 @@ async fn test_advertise_to_client() { .expect("Failed to parse message"); let data = msg.into_data(); let data_str = std::str::from_utf8(&data).unwrap(); - println!("data: {:?}", data_str); assert!(data_str.contains("Client is already subscribed to channel")); let msg = client_receiver diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index 33f648aa..934015f9 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -277,10 +277,6 @@ impl ConnectedClient { continue; } - println!( - "Client {} subscribed to channel {} with subscription id {}", - self.addr, subscription.channel_id, subscription.id - ); let mut subscriptions = self.subscriptions.lock(); if subscriptions .insert_no_overwrite(subscription.channel_id, subscription.id) @@ -301,10 +297,6 @@ impl ConnectedClient { continue; } - println!( - "Client {} subscribed to channel {} with subscription id {}", - self.addr, subscription.channel_id, subscription.id - ); tracing::info!( "Client {} subscribed to channel {} with subscription id {}", self.addr, @@ -769,14 +761,10 @@ impl LogSink for Server { let clients = self.clients.get(); for client in clients.iter() { let subscriptions = client.subscriptions.lock(); - println!("Subscriptions active: {}", subscriptions.len()); let Some(subscription_id) = subscriptions.get_by_left(&channel.id).cloned() else { - println!("No subscription found for channel {}", channel.id); continue; }; - println!("Sending message to client {}", client.addr); - // https://github.com/foxglove/ws-protocol/blob/main/docs/spec.md#message-data let header_size: usize = 1 + 4 + 8; let mut buf = BytesMut::with_capacity(header_size + msg.len()); From 0f4812e6e6e46334533cee54c21c0802fc20b067 Mon Sep 17 00:00:00 2001 From: Eloff Date: Wed, 5 Feb 2025 15:51:38 -0700 Subject: [PATCH 07/11] fix flakey concurrent test --- rust/foxglove/src/cow_vec.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rust/foxglove/src/cow_vec.rs b/rust/foxglove/src/cow_vec.rs index 1bef8ce9..92e9ac87 100644 --- a/rust/foxglove/src/cow_vec.rs +++ b/rust/foxglove/src/cow_vec.rs @@ -93,18 +93,16 @@ mod tests { let handle = thread::spawn(move || { vec2.push(4); - vec2.get().to_vec() }); vec.push(5); - let thread_state = handle.join().unwrap(); + handle.join().unwrap(); let final_state = vec.get().to_vec(); // Old snapshot should still be valid and have original length assert_eq!(old_snapshot.len(), 3); // Both threads should see 5 items in their final state assert_eq!(final_state.len(), 5); - assert_eq!(thread_state.len(), 5); } #[test] From 71f4c47d4f7d66f9f09df7464b8efe90a9c6dd93 Mon Sep 17 00:00:00 2001 From: Eloff Date: Thu, 6 Feb 2025 14:25:39 -0700 Subject: [PATCH 08/11] update comments around sleeps to point at FG-10395 --- rust/foxglove/src/cow_vec.rs | 2 +- rust/foxglove/src/tests/logging.rs | 5 +++-- rust/foxglove/src/tests/websocket.rs | 5 ++--- rust/foxglove/src/tests/websocket_unstable.rs | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rust/foxglove/src/cow_vec.rs b/rust/foxglove/src/cow_vec.rs index 92e9ac87..d8ac27d4 100644 --- a/rust/foxglove/src/cow_vec.rs +++ b/rust/foxglove/src/cow_vec.rs @@ -101,7 +101,7 @@ mod tests { // Old snapshot should still be valid and have original length assert_eq!(old_snapshot.len(), 3); - // Both threads should see 5 items in their final state + // There should now be 5 items in the final state assert_eq!(final_state.len(), 5); } diff --git a/rust/foxglove/src/tests/logging.rs b/rust/foxglove/src/tests/logging.rs index 9087fbeb..a516f5ae 100644 --- a/rust/foxglove/src/tests/logging.rs +++ b/rust/foxglove/src/tests/logging.rs @@ -41,7 +41,7 @@ async fn test_logging_to_file_and_live_sinks() { ws_stream }; - // FG-9877: allow the server to handle the connection before creating the channel + // FG-10395 replace this with something more precise tokio::time::sleep(Duration::from_millis(100)).await; let channel = ChannelBuilder::new("/test-topic") @@ -103,6 +103,7 @@ async fn test_logging_to_file_and_live_sinks() { .expect("Failed to subscribe"); // Let subscription register before publishing + // FG-10395 replace this with something more precise tokio::time::sleep(Duration::from_millis(100)).await; } @@ -122,7 +123,7 @@ async fn test_logging_to_file_and_live_sinks() { channel.log(&msg); - // Ensure message has arrived + // FG-10395 replace this with something more precise tokio::time::sleep(Duration::from_millis(100)).await; let writer = handle.close().expect("Failed to flush log"); diff --git a/rust/foxglove/src/tests/websocket.rs b/rust/foxglove/src/tests/websocket.rs index 96371f55..c2acc04f 100644 --- a/rust/foxglove/src/tests/websocket.rs +++ b/rust/foxglove/src/tests/websocket.rs @@ -179,8 +179,7 @@ async fn test_advertise_to_client() { .expect("Failed to send"); // Allow the server to process the subscription - // FG-9723: replace this with an on_subscribe callback - // (whoops, that won't work either, unless we do something like polling the recording_listener) + // FG-10395 replace this with something more precise tokio::time::sleep(std::time::Duration::from_millis(100)).await; server.log(&ch, b"foo bar baz", &metadata).unwrap(); @@ -311,7 +310,7 @@ async fn test_log_only_to_subscribers() { .expect("Failed to send"); // Allow the server to process the subscription - // FG-9723: replace this with an on_subscribe callback + // FG-10395 replace this with something more precise tokio::time::sleep(std::time::Duration::from_millis(100)).await; let subscriptions = recording_listener.take_subscribe(); diff --git a/rust/foxglove/src/tests/websocket_unstable.rs b/rust/foxglove/src/tests/websocket_unstable.rs index 3abd9352..dea738bb 100644 --- a/rust/foxglove/src/tests/websocket_unstable.rs +++ b/rust/foxglove/src/tests/websocket_unstable.rs @@ -98,7 +98,7 @@ async fn test_client_advertising() { .await .expect("Failed to send unadvertise"); - // Give the server time to process the messages + // FG-10395 replace this with something more precise sleep(std::time::Duration::from_millis(10)).await; // Server should have received one message @@ -147,7 +147,7 @@ async fn test_parameter_values() { }; server.publish_parameter_values(vec![parameter], None).await; - // Allow the server to process the parameter values + // FG-10395 replace this with something more precise std::thread::sleep(std::time::Duration::from_millis(100)); let msg = ws_client.next().await.expect("No message received"); From b2e70f5293f1163269a0f664a73f192222a69aa5 Mon Sep 17 00:00:00 2001 From: Eloff Date: Thu, 6 Feb 2025 15:52:59 -0700 Subject: [PATCH 09/11] rework handler callbacks to accept client id (and add a client id), rework invoking code to not hold locks while invoking callbacks, pass Channel refs instead of channel ids --- rust/foxglove/src/tests/websocket.rs | 14 +- rust/foxglove/src/testutil.rs | 49 +++--- rust/foxglove/src/websocket.rs | 251 ++++++++++++++++----------- 3 files changed, 183 insertions(+), 131 deletions(-) diff --git a/rust/foxglove/src/tests/websocket.rs b/rust/foxglove/src/tests/websocket.rs index c2acc04f..03ac6703 100644 --- a/rust/foxglove/src/tests/websocket.rs +++ b/rust/foxglove/src/tests/websocket.rs @@ -209,7 +209,7 @@ async fn test_advertise_to_client() { let subscriptions = recording_listener.take_subscribe(); assert_eq!(subscriptions.len(), 1); - assert_eq!(subscriptions[0], ch.id()); + assert!(Arc::ptr_eq(&subscriptions[0].1, &ch)); server.stop().await; } @@ -315,15 +315,15 @@ async fn test_log_only_to_subscribers() { let subscriptions = recording_listener.take_subscribe(); assert_eq!(subscriptions.len(), 4); - assert_eq!(subscriptions[0], ch1.id()); - assert_eq!(subscriptions[1], ch2.id()); - assert_eq!(subscriptions[2], ch1.id()); - assert_eq!(subscriptions[3], ch2.id()); + assert!(Arc::ptr_eq(&subscriptions[0].1, &ch1)); + assert!(Arc::ptr_eq(&subscriptions[1].1, &ch2)); + assert!(Arc::ptr_eq(&subscriptions[2].1, &ch1)); + assert!(Arc::ptr_eq(&subscriptions[3].1, &ch2)); let unsubscriptions = recording_listener.take_unsubscribe(); assert_eq!(unsubscriptions.len(), 2); - assert_eq!(unsubscriptions[0], ch1.id()); - assert_eq!(unsubscriptions[1], ch2.id()); + assert!(Arc::ptr_eq(&unsubscriptions[0].1, &ch1)); + assert!(Arc::ptr_eq(&unsubscriptions[1].1, &ch2)); let metadata = Metadata { log_time: 123456, diff --git a/rust/foxglove/src/testutil.rs b/rust/foxglove/src/testutil.rs index de09b4b7..017c91bf 100644 --- a/rust/foxglove/src/testutil.rs +++ b/rust/foxglove/src/testutil.rs @@ -3,23 +3,22 @@ mod log_context; mod log_sink; -use crate::channel::ChannelId; -use crate::websocket::{ClientChannel, ClientChannelId, ServerListener}; +use crate::websocket::{ClientChannel, ClientChannelId, ClientId, ServerListener}; +use crate::Channel; pub use log_context::GlobalContextTest; pub use log_sink::{ErrorSink, MockSink, RecordingSink}; use parking_lot::Mutex; +use std::sync::Arc; -#[allow(dead_code)] pub(crate) struct RecordingServerListener { - message_data: Mutex)>>, - subscribe: Mutex>, - unsubscribe: Mutex>, - client_advertise: Mutex>, - client_unadvertise: Mutex>, + message_data: Mutex)>>, + subscribe: Mutex)>>, + unsubscribe: Mutex)>>, + client_advertise: Mutex>, + client_unadvertise: Mutex>, } impl RecordingServerListener { - #[allow(dead_code)] pub fn new() -> Self { Self { message_data: Mutex::new(Vec::new()), @@ -31,54 +30,52 @@ impl RecordingServerListener { } #[allow(dead_code)] - pub fn take_message_data(&self) -> Vec<(ClientChannelId, Vec)> { + pub fn take_message_data(&self) -> Vec<(ClientId, ClientChannelId, Vec)> { std::mem::take(&mut self.message_data.lock()) } - #[allow(dead_code)] - pub fn take_subscribe(&self) -> Vec { + pub fn take_subscribe(&self) -> Vec<(ClientId, Arc)> { std::mem::take(&mut self.subscribe.lock()) } - #[allow(dead_code)] - pub fn take_unsubscribe(&self) -> Vec { + pub fn take_unsubscribe(&self) -> Vec<(ClientId, Arc)> { std::mem::take(&mut self.unsubscribe.lock()) } #[allow(dead_code)] - pub fn take_client_advertise(&self) -> Vec { + pub fn take_client_advertise(&self) -> Vec<(ClientId, ClientChannel)> { std::mem::take(&mut self.client_advertise.lock()) } #[allow(dead_code)] - pub fn take_client_unadvertise(&self) -> Vec { + pub fn take_client_unadvertise(&self) -> Vec<(ClientId, ClientChannelId)> { std::mem::take(&mut self.client_unadvertise.lock()) } } impl ServerListener for RecordingServerListener { - fn on_message_data(&self, channel_id: ClientChannelId, payload: &[u8]) { + fn on_message_data(&self, client_id: ClientId, channel_id: ClientChannelId, payload: &[u8]) { let mut data = self.message_data.lock(); - data.push((channel_id, payload.to_vec())); + data.push((client_id, channel_id, payload.to_vec())); } - fn on_subscribe(&self, channel_id: ChannelId) { + fn on_subscribe(&self, client_id: ClientId, channel: Arc) { let mut subs = self.subscribe.lock(); - subs.push(channel_id); + subs.push((client_id, channel)); } - fn on_unsubscribe(&self, channel_id: ChannelId) { + fn on_unsubscribe(&self, client_id: ClientId, channel: Arc) { let mut unsubs = self.unsubscribe.lock(); - unsubs.push(channel_id); + unsubs.push((client_id, channel)); } - fn on_client_advertise(&self, channel: &ClientChannel) { + fn on_client_advertise(&self, client_id: ClientId, channel: &ClientChannel) { let mut adverts = self.client_advertise.lock(); - adverts.push(channel.clone()); + adverts.push((client_id, channel.clone())); } - fn on_client_unadvertise(&self, channel_id: ClientChannelId) { + fn on_client_unadvertise(&self, client_id: ClientId, channel_id: ClientChannelId) { let mut unadverts = self.client_unadvertise.lock(); - unadverts.push(channel_id); + unadverts.push((client_id, channel_id)); } } diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index efe73ef6..aff7c6b1 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -14,8 +14,8 @@ use flume::TrySendError; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; use std::collections::hash_map::Entry; use std::collections::HashSet; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::{AcqRel, Acquire}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicU32}; use std::sync::{OnceLock, Weak}; use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use thiserror::Error; @@ -34,6 +34,9 @@ mod protocol; #[cfg(test)] mod tests; +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct ClientId(u32); + pub(crate) const SUBPROTOCOL: &str = "foxglove.sdk.v1"; type WebsocketSender = SplitSink, Message>; @@ -135,19 +138,21 @@ impl Default for Server { /// All methods are optional. pub trait ServerListener: Send + Sync { /// Callback invoked when a client message is received. - fn on_message_data(&self, _channel_id: ClientChannelId, _payload: &[u8]) {} + fn on_message_data(&self, _client_id: ClientId, _channel_id: ClientChannelId, _payload: &[u8]) { + } /// Callback invoked when a client subscribes to a channel. - fn on_subscribe(&self, _channel_id: ChannelId) {} + fn on_subscribe(&self, _client_id: ClientId, _channel_id: Arc) {} /// Callback invoked when a client unsubscribes from a channel. - fn on_unsubscribe(&self, _channel_id: ChannelId) {} + fn on_unsubscribe(&self, _client_id: ClientId, _channel_id: Arc) {} /// Callback invoked when a client advertises a client channel. Requires the "clientPublish" capability. - fn on_client_advertise(&self, _channel: &ClientChannel) {} + fn on_client_advertise(&self, _client_id: ClientId, _channel: &ClientChannel) {} /// Callback invoked when a client unadvertises a client channel. Requires the "clientPublish" capability. - fn on_client_unadvertise(&self, _channel_id: ClientChannelId) {} + fn on_client_unadvertise(&self, _client_id: ClientId, _channel_id: ClientChannelId) {} } /// State for a client maintained by the server struct ConnectedClient { + id: ClientId, addr: SocketAddr, /// Write side of a WS stream sender: Mutex, @@ -192,14 +197,7 @@ impl ConnectedClient { self.on_subscribe(server, subscriptions); } Ok(ClientMessage::Unsubscribe { subscription_ids }) => { - let mut subscriptions = self.subscriptions.lock(); - for subscription_id in subscription_ids { - if let Some((channel_id, _)) = subscriptions.remove_by_right(&subscription_id) { - if let Some(handler) = self.server_listener.as_ref() { - handler.on_unsubscribe(channel_id); - } - } - } + self.on_unsubscribe(server, subscription_ids); } Ok(ClientMessage::Advertise { channels }) => { self.on_advertise(server, channels); @@ -214,28 +212,74 @@ impl ConnectedClient { } } - fn on_unadvertise(&self, channel_ids: Vec) { + fn handle_binary_message(&self, message: Message) { + if message.is_empty() { + tracing::debug!("Received empty binary message from {}", self.addr); + return; + } + + let msg_bytes = message.into_data(); + let opcode = protocol::client::BinaryOpcode::from_repr(msg_bytes[0]); + match opcode { + Some(protocol::client::BinaryOpcode::MessageData) => { + match protocol::client::parse_binary_message(&msg_bytes) { + Ok((channel_id, payload)) => { + { + let advertised_channels = self.advertised_channels.lock(); + if !advertised_channels.contains_key(&channel_id) { + tracing::error!( + "Received message for unknown channel: {}", + channel_id + ); + self.send_error(format!("Unknown channel ID: {}", channel_id)); + // Do not forward to server listener + return; + } + } + // Call the handler after releasing the advertised_channels lock + if let Some(handler) = self.server_listener.as_ref() { + handler.on_message_data(self.id, channel_id, payload); + } + } + Err(err) => { + tracing::error!("Failed to parse binary message: {err}"); + self.send_error(format!("Failed to parse binary message: {err}")); + } + } + } + Some(_) => { + tracing::error!("Opcode not yet implemented: {}", msg_bytes[0]); + } + None => { + tracing::error!("Invalid binary opcode: {}", msg_bytes[0]); + self.send_error(format!("Invalid binary opcode: {}", msg_bytes[0])); + } + } + } + + fn on_unadvertise(&self, mut channel_ids: Vec) { // Using a limited scope and iterating twice to avoid holding the lock on advertised_channels while calling on_client_unadvertise - let mut channels_not_found = Vec::new(); { let mut advertised_channels = self.advertised_channels.lock(); - for &id in channel_ids.iter() { + let mut i = 0; + while i < channel_ids.len() { + let id = channel_ids[i]; if advertised_channels.remove(&id).is_none() { - channels_not_found.push(id); + // Remove the channel ID from the list so we don't invoke the on_client_unadvertise callback + channel_ids.swap_remove(i); self.send_warning(format!( "Client is not advertising channel: {}; ignoring unadvertisement", id )); + continue; } + i += 1; } } + // Call the handler after releasing the advertised_channels lock if let Some(handler) = self.server_listener.as_ref() { - for id in channel_ids - .iter() - .cloned() - .filter(|id| !channels_not_found.contains(id)) - { - handler.on_client_unadvertise(id); + for id in channel_ids { + handler.on_client_unadvertise(self.id, id); } } } @@ -263,43 +307,94 @@ impl ConnectedClient { } } + // Call the handler after releasing the advertised_channels lock if let Some(handler) = self.server_listener.as_ref() { - handler.on_client_advertise(&channel); + handler.on_client_advertise(self.id, &channel); } } } - fn on_subscribe(&self, server: Arc, subscriptions: Vec) { - let channels = server.channels.read(); - for subscription in subscriptions { - if !channels.contains_key(&subscription.channel_id) { - tracing::error!( - "Client {} attempted to subscribe to unknown channel: {}", - self.addr, - subscription.channel_id - ); - self.send_error(format!("Unknown channel ID: {}", subscription.channel_id)); - continue; + fn on_unsubscribe(&self, server: Arc, subscription_ids: Vec) { + let mut unsubscribed_channel_ids = Vec::with_capacity(subscription_ids.len()); + // First gather the unsubscribed channel ids while holding the subscriptions lock + { + let mut subscriptions = self.subscriptions.lock(); + for subscription_id in subscription_ids { + if let Some((channel_id, _)) = subscriptions.remove_by_right(&subscription_id) { + unsubscribed_channel_ids.push(channel_id); + } } + } - let mut subscriptions = self.subscriptions.lock(); - if subscriptions - .insert_no_overwrite(subscription.channel_id, subscription.id) - .is_err() - { - if subscriptions.contains_left(&subscription.channel_id) { - self.send_warning(format!( - "Client is already subscribed to channel: {}; ignoring subscription", + // If we don't have a ServerListener, we're done. + let Some(handler) = self.server_listener.as_ref() else { + return; + }; + + // Then gather the actual channel references while holding the channels lock + let mut unsubscribed_channels = Vec::with_capacity(unsubscribed_channel_ids.len()); + { + let channels = server.channels.read(); + for channel_id in unsubscribed_channel_ids { + if let Some(channel) = channels.get(&channel_id) { + unsubscribed_channels.push(channel.clone()); + } + } + } + + // Finally call the handler for each channel + for channel in unsubscribed_channels { + handler.on_unsubscribe(self.id, channel); + } + } + + fn on_subscribe(&self, server: Arc, mut subscriptions: Vec) { + // First prune out any subscriptions for channels not in the channel map, + // limiting how long we need to hold the lock. + let mut subscribed_channels = Vec::with_capacity(subscriptions.len()); + { + let channels = server.channels.read(); + let mut i = 0; + while i < subscriptions.len() { + let subscription = &subscriptions[i]; + let Some(channel) = channels.get(&subscription.channel_id) else { + tracing::error!( + "Client {} attempted to subscribe to unknown channel: {}", + self.addr, subscription.channel_id - )); - } else { - assert!(subscriptions.contains_right(&subscription.id)); - self.send_error(format!( - "Subscription ID was already used: {}; ignoring subscription", - subscription.id - )); + ); + self.send_error(format!("Unknown channel ID: {}", subscription.channel_id)); + // Remove the subscription from the list so we don't invoke the on_subscribe callback for it + subscriptions.swap_remove(i); + continue; + }; + subscribed_channels.push(channel.clone()); + i += 1 + } + } + + for (subscription, channel) in subscriptions.into_iter().zip(subscribed_channels) { + // Using a limited scope here to avoid holding the lock on subscriptions while calling on_subscribe + { + let mut subscriptions = self.subscriptions.lock(); + if subscriptions + .insert_no_overwrite(subscription.channel_id, subscription.id) + .is_err() + { + if subscriptions.contains_left(&subscription.channel_id) { + self.send_warning(format!( + "Client is already subscribed to channel: {}; ignoring subscription", + subscription.channel_id + )); + } else { + assert!(subscriptions.contains_right(&subscription.id)); + self.send_error(format!( + "Subscription ID was already used: {}; ignoring subscription", + subscription.id + )); + } + continue; } - continue; } tracing::info!( @@ -309,7 +404,7 @@ impl ConnectedClient { subscription.id ); if let Some(handler) = self.server_listener.as_ref() { - handler.on_subscribe(subscription.channel_id); + handler.on_subscribe(self.id, channel); } } } @@ -337,50 +432,6 @@ impl ConnectedClient { tracing::warn!("Failed to send status to client {}: {err}", self.addr) }); } - - fn handle_binary_message(&self, message: Message) { - if message.is_empty() { - tracing::debug!("Received empty binary message from {}", self.addr); - return; - } - - let msg_bytes = message.into_data(); - let opcode = protocol::client::BinaryOpcode::from_repr(msg_bytes[0]); - match opcode { - Some(protocol::client::BinaryOpcode::MessageData) => { - match protocol::client::parse_binary_message(&msg_bytes) { - Ok((channel_id, payload)) => { - { - let advertised_channels = self.advertised_channels.lock(); - if !advertised_channels.contains_key(&channel_id) { - tracing::error!( - "Received message for unknown channel: {}", - channel_id - ); - self.send_error(format!("Unknown channel ID: {}", channel_id)); - // Do not forward to server listener - return; - } - } - if let Some(handler) = self.server_listener.as_ref() { - handler.on_message_data(channel_id, payload); - } - } - Err(err) => { - tracing::error!("Failed to parse binary message: {err}"); - self.send_error(format!("Failed to parse binary message: {err}")); - } - } - } - Some(_) => { - tracing::error!("Opcode not yet implemented: {}", msg_bytes[0]); - } - None => { - tracing::error!("Invalid binary opcode: {}", msg_bytes[0]); - self.send_error(format!("Invalid binary opcode: {}", msg_bytes[0])); - } - } - } } // A websocket server that implements the Foxglove WebSocket Protocol @@ -588,10 +639,14 @@ impl Server { return; } + static CLIENT_ID: AtomicU32 = AtomicU32::new(1); + let id = ClientId(CLIENT_ID.fetch_add(1, Relaxed)); + let (data_tx, data_rx) = flume::bounded(self.message_backlog_size as usize); let (ctrl_tx, ctrl_rx) = flume::bounded(DEFAULT_CONTROL_PLANE_BACKLOG_SIZE); let new_client = Arc::new(ConnectedClient { + id, addr, sender: Mutex::new(ws_sender), data_plane_tx: data_tx, From 5732c56e4654d514df242618d1a9773a9165f18b Mon Sep 17 00:00:00 2001 From: Eloff Date: Thu, 6 Feb 2025 16:27:57 -0700 Subject: [PATCH 10/11] tweak docs a little and fix unstable tests --- rust/foxglove/examples/unstable/client-publish.rs | 5 +++-- rust/foxglove/src/lib.rs | 2 +- rust/foxglove/src/tests/websocket_unstable.rs | 6 +++--- rust/foxglove/src/websocket.rs | 3 +++ 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/rust/foxglove/examples/unstable/client-publish.rs b/rust/foxglove/examples/unstable/client-publish.rs index 3020e58c..c3ebdcc8 100644 --- a/rust/foxglove/examples/unstable/client-publish.rs +++ b/rust/foxglove/examples/unstable/client-publish.rs @@ -11,7 +11,8 @@ use clap::Parser; use foxglove::schemas::log::Level; use foxglove::schemas::Log; use foxglove::{ - Capability, ClientChannelId, PartialMetadata, ServerListener, TypedChannel, WebSocketServer, + Capability, ClientChannelId, ClientId, PartialMetadata, ServerListener, TypedChannel, + WebSocketServer, }; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; @@ -29,7 +30,7 @@ struct Cli { struct ExampleCallbackHandler; impl ServerListener for ExampleCallbackHandler { - fn on_message_data(&self, channel_id: ClientChannelId, message: &[u8]) { + fn on_message_data(&self, _client_id: ClientId, channel_id: ClientChannelId, message: &[u8]) { let json: serde_json::Value = serde_json::from_slice(message).expect("Failed to parse message"); println!("Received message on channel {channel_id}: {json}"); diff --git a/rust/foxglove/src/lib.rs b/rust/foxglove/src/lib.rs index 1cf58bc5..be3e57c1 100644 --- a/rust/foxglove/src/lib.rs +++ b/rust/foxglove/src/lib.rs @@ -194,7 +194,7 @@ pub(crate) use time::nanoseconds_since_epoch; #[doc(hidden)] #[cfg(feature = "unstable")] pub use websocket::{ - Capability, ClientChannelId, Parameter, ParameterType, ParameterValue, ServerListener, + Capability, ClientChannelId, ClientId, Parameter, ParameterType, ParameterValue, ServerListener, }; pub use websocket_server::{WebSocketServer, WebSocketServerHandle}; diff --git a/rust/foxglove/src/tests/websocket_unstable.rs b/rust/foxglove/src/tests/websocket_unstable.rs index dea738bb..653de39d 100644 --- a/rust/foxglove/src/tests/websocket_unstable.rs +++ b/rust/foxglove/src/tests/websocket_unstable.rs @@ -103,19 +103,19 @@ async fn test_client_advertising() { // Server should have received one message let mut received = recording_listener.take_message_data(); - let (channel_id, payload) = received.pop().expect("No message received"); + let (_, channel_id, payload) = received.pop().expect("No message received"); assert_eq!(channel_id, ClientChannelId::new(1)); assert_eq!(payload, b"{\"a\":1}"); // Server should have ignored the duplicate advertisement let advertisements = recording_listener.take_client_advertise(); assert_eq!(advertisements.len(), 1); - assert_eq!(advertisements[0].id, channel_id); + assert_eq!(advertisements[0].1.id, channel_id); // Server should have received one unadvertise (and ignored the duplicate) let unadvertises = recording_listener.take_client_unadvertise(); assert_eq!(unadvertises.len(), 1); - assert_eq!(unadvertises[0], channel_id); + assert_eq!(unadvertises[0].1, channel_id); ws_client.close(None).await.unwrap(); server.stop().await; diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index 3a026615..3e60abbf 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -34,6 +34,7 @@ mod protocol; #[cfg(test)] mod tests; +/// An arbitrary integer unique identifier for a client connection. #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct ClientId(u32); @@ -104,8 +105,10 @@ pub trait ServerListener: Send + Sync { fn on_message_data(&self, _client_id: ClientId, _channel_id: ClientChannelId, _payload: &[u8]) { } /// Callback invoked when a client subscribes to a channel. + /// Only invoked if the channel is associated with the server and isn't already subscribed to by the client. fn on_subscribe(&self, _client_id: ClientId, _channel_id: Arc) {} /// Callback invoked when a client unsubscribes from a channel. + /// Only invoked for channels that had an active subscription from the client. fn on_unsubscribe(&self, _client_id: ClientId, _channel_id: Arc) {} /// Callback invoked when a client advertises a client channel. Requires the "clientPublish" capability. fn on_client_advertise(&self, _client_id: ClientId, _channel: &ClientChannel) {} From aa999f8fe322a05ff9c3c7774d4bb1bb005be0de Mon Sep 17 00:00:00 2001 From: Eloff Date: Fri, 7 Feb 2025 10:48:07 -0700 Subject: [PATCH 11/11] use wrapped view types in the ServerListener callbacks --- .../examples/unstable/client-publish.rs | 6 +- rust/foxglove/src/lib.rs | 3 +- rust/foxglove/src/tests/websocket.rs | 21 ++- rust/foxglove/src/tests/websocket_unstable.rs | 8 +- rust/foxglove/src/testutil.rs | 76 ++++++--- rust/foxglove/src/websocket.rs | 145 +++++++++++++++--- 6 files changed, 198 insertions(+), 61 deletions(-) diff --git a/rust/foxglove/examples/unstable/client-publish.rs b/rust/foxglove/examples/unstable/client-publish.rs index c3ebdcc8..f07d1fe8 100644 --- a/rust/foxglove/examples/unstable/client-publish.rs +++ b/rust/foxglove/examples/unstable/client-publish.rs @@ -11,7 +11,7 @@ use clap::Parser; use foxglove::schemas::log::Level; use foxglove::schemas::Log; use foxglove::{ - Capability, ClientChannelId, ClientId, PartialMetadata, ServerListener, TypedChannel, + Capability, Client, ClientChannelView, PartialMetadata, ServerListener, TypedChannel, WebSocketServer, }; use std::sync::Arc; @@ -30,10 +30,10 @@ struct Cli { struct ExampleCallbackHandler; impl ServerListener for ExampleCallbackHandler { - fn on_message_data(&self, _client_id: ClientId, channel_id: ClientChannelId, message: &[u8]) { + fn on_message_data(&self, _client: Client, channel: ClientChannelView, message: &[u8]) { let json: serde_json::Value = serde_json::from_slice(message).expect("Failed to parse message"); - println!("Received message on channel {channel_id}: {json}"); + println!("Received message on channel {}: {json}", channel.id()); } } diff --git a/rust/foxglove/src/lib.rs b/rust/foxglove/src/lib.rs index d1e2712a..ef27ae08 100644 --- a/rust/foxglove/src/lib.rs +++ b/rust/foxglove/src/lib.rs @@ -193,8 +193,9 @@ pub use runtime::shutdown_runtime; pub(crate) use time::nanoseconds_since_epoch; #[doc(hidden)] #[cfg(feature = "unstable")] +pub use websocket::{Capability, Parameter, ParameterType, ParameterValue}; pub use websocket::{ - Capability, ClientChannelId, ClientId, Parameter, ParameterType, ParameterValue, ServerListener, + ChannelView, Client, ClientChannelId, ClientChannelView, ClientId, ServerListener, }; pub use websocket_server::{WebSocketServer, WebSocketServerBlockingHandle, WebSocketServerHandle}; diff --git a/rust/foxglove/src/tests/websocket.rs b/rust/foxglove/src/tests/websocket.rs index 03ac6703..baec7e1e 100644 --- a/rust/foxglove/src/tests/websocket.rs +++ b/rust/foxglove/src/tests/websocket.rs @@ -209,7 +209,8 @@ async fn test_advertise_to_client() { let subscriptions = recording_listener.take_subscribe(); assert_eq!(subscriptions.len(), 1); - assert!(Arc::ptr_eq(&subscriptions[0].1, &ch)); + assert_eq!(subscriptions[0].1.id, ch.id); + assert_eq!(subscriptions[0].1.topic, ch.topic); server.stop().await; } @@ -315,15 +316,21 @@ async fn test_log_only_to_subscribers() { let subscriptions = recording_listener.take_subscribe(); assert_eq!(subscriptions.len(), 4); - assert!(Arc::ptr_eq(&subscriptions[0].1, &ch1)); - assert!(Arc::ptr_eq(&subscriptions[1].1, &ch2)); - assert!(Arc::ptr_eq(&subscriptions[2].1, &ch1)); - assert!(Arc::ptr_eq(&subscriptions[3].1, &ch2)); + assert_eq!(subscriptions[0].1.id, ch1.id); + assert_eq!(subscriptions[1].1.id, ch2.id); + assert_eq!(subscriptions[2].1.id, ch1.id); + assert_eq!(subscriptions[3].1.id, ch2.id); + assert_eq!(subscriptions[0].1.topic, ch1.topic); + assert_eq!(subscriptions[1].1.topic, ch2.topic); + assert_eq!(subscriptions[2].1.topic, ch1.topic); + assert_eq!(subscriptions[3].1.topic, ch2.topic); let unsubscriptions = recording_listener.take_unsubscribe(); assert_eq!(unsubscriptions.len(), 2); - assert!(Arc::ptr_eq(&unsubscriptions[0].1, &ch1)); - assert!(Arc::ptr_eq(&unsubscriptions[1].1, &ch2)); + assert_eq!(unsubscriptions[0].1.id, ch1.id); + assert_eq!(unsubscriptions[1].1.id, ch2.id); + assert_eq!(unsubscriptions[0].1.topic, ch1.topic); + assert_eq!(unsubscriptions[1].1.topic, ch2.topic); let metadata = Metadata { log_time: 123456, diff --git a/rust/foxglove/src/tests/websocket_unstable.rs b/rust/foxglove/src/tests/websocket_unstable.rs index 104413fc..b4d3758f 100644 --- a/rust/foxglove/src/tests/websocket_unstable.rs +++ b/rust/foxglove/src/tests/websocket_unstable.rs @@ -103,19 +103,19 @@ async fn test_client_advertising() { // Server should have received one message let mut received = recording_listener.take_message_data(); - let (_, channel_id, payload) = received.pop().expect("No message received"); - assert_eq!(channel_id, ClientChannelId::new(1)); + let (_, channel_info, payload) = received.pop().expect("No message received"); + assert_eq!(channel_info.id, ClientChannelId::new(1)); assert_eq!(payload, b"{\"a\":1}"); // Server should have ignored the duplicate advertisement let advertisements = recording_listener.take_client_advertise(); assert_eq!(advertisements.len(), 1); - assert_eq!(advertisements[0].1.id, channel_id); + assert_eq!(advertisements[0].1.id, channel_info.id); // Server should have received one unadvertise (and ignored the duplicate) let unadvertises = recording_listener.take_client_unadvertise(); assert_eq!(unadvertises.len(), 1); - assert_eq!(unadvertises[0].1, channel_id); + assert_eq!(unadvertises[0].1.id, channel_info.id); ws_client.close(None).await.unwrap(); server.stop().await; diff --git a/rust/foxglove/src/testutil.rs b/rust/foxglove/src/testutil.rs index 017c91bf..212008da 100644 --- a/rust/foxglove/src/testutil.rs +++ b/rust/foxglove/src/testutil.rs @@ -3,19 +3,49 @@ mod log_context; mod log_sink; -use crate::websocket::{ClientChannel, ClientChannelId, ClientId, ServerListener}; -use crate::Channel; +use crate::channel::ChannelId; +use crate::websocket::{ + ChannelView, Client, ClientChannelId, ClientChannelView, ClientId, ServerListener, +}; pub use log_context::GlobalContextTest; pub use log_sink::{ErrorSink, MockSink, RecordingSink}; use parking_lot::Mutex; -use std::sync::Arc; + +#[allow(dead_code)] +pub(crate) struct ClientChannelInfo { + pub(crate) id: ClientChannelId, + pub(crate) topic: String, +} + +impl From> for ClientChannelInfo { + fn from(channel: ClientChannelView) -> Self { + Self { + id: channel.id(), + topic: channel.topic().to_string(), + } + } +} + +pub(crate) struct ChannelInfo { + pub(crate) id: ChannelId, + pub(crate) topic: String, +} + +impl From> for ChannelInfo { + fn from(channel: ChannelView) -> Self { + Self { + id: channel.id(), + topic: channel.topic().to_string(), + } + } +} pub(crate) struct RecordingServerListener { - message_data: Mutex)>>, - subscribe: Mutex)>>, - unsubscribe: Mutex)>>, - client_advertise: Mutex>, - client_unadvertise: Mutex>, + message_data: Mutex)>>, + subscribe: Mutex>, + unsubscribe: Mutex>, + client_advertise: Mutex>, + client_unadvertise: Mutex>, } impl RecordingServerListener { @@ -30,52 +60,52 @@ impl RecordingServerListener { } #[allow(dead_code)] - pub fn take_message_data(&self) -> Vec<(ClientId, ClientChannelId, Vec)> { + pub fn take_message_data(&self) -> Vec<(ClientId, ClientChannelInfo, Vec)> { std::mem::take(&mut self.message_data.lock()) } - pub fn take_subscribe(&self) -> Vec<(ClientId, Arc)> { + pub fn take_subscribe(&self) -> Vec<(ClientId, ChannelInfo)> { std::mem::take(&mut self.subscribe.lock()) } - pub fn take_unsubscribe(&self) -> Vec<(ClientId, Arc)> { + pub fn take_unsubscribe(&self) -> Vec<(ClientId, ChannelInfo)> { std::mem::take(&mut self.unsubscribe.lock()) } #[allow(dead_code)] - pub fn take_client_advertise(&self) -> Vec<(ClientId, ClientChannel)> { + pub fn take_client_advertise(&self) -> Vec<(ClientId, ClientChannelInfo)> { std::mem::take(&mut self.client_advertise.lock()) } #[allow(dead_code)] - pub fn take_client_unadvertise(&self) -> Vec<(ClientId, ClientChannelId)> { + pub fn take_client_unadvertise(&self) -> Vec<(ClientId, ClientChannelInfo)> { std::mem::take(&mut self.client_unadvertise.lock()) } } impl ServerListener for RecordingServerListener { - fn on_message_data(&self, client_id: ClientId, channel_id: ClientChannelId, payload: &[u8]) { + fn on_message_data(&self, client: Client, channel: ClientChannelView, payload: &[u8]) { let mut data = self.message_data.lock(); - data.push((client_id, channel_id, payload.to_vec())); + data.push((client.id(), channel.into(), payload.to_vec())); } - fn on_subscribe(&self, client_id: ClientId, channel: Arc) { + fn on_subscribe(&self, client: Client, channel: ChannelView) { let mut subs = self.subscribe.lock(); - subs.push((client_id, channel)); + subs.push((client.id(), channel.into())); } - fn on_unsubscribe(&self, client_id: ClientId, channel: Arc) { + fn on_unsubscribe(&self, client: Client, channel: ChannelView) { let mut unsubs = self.unsubscribe.lock(); - unsubs.push((client_id, channel)); + unsubs.push((client.id(), channel.into())); } - fn on_client_advertise(&self, client_id: ClientId, channel: &ClientChannel) { + fn on_client_advertise(&self, client: Client, channel: ClientChannelView) { let mut adverts = self.client_advertise.lock(); - adverts.push((client_id, channel.clone())); + adverts.push((client.id(), channel.into())); } - fn on_client_unadvertise(&self, client_id: ClientId, channel_id: ClientChannelId) { + fn on_client_unadvertise(&self, client: Client, channel: ClientChannelView) { let mut unadverts = self.client_unadvertise.lock(); - unadverts.push((client_id, channel_id)); + unadverts.push((client.id(), channel.into())); } } diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index c322410b..814ebac4 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -34,10 +34,59 @@ mod protocol; #[cfg(test)] mod tests; -/// An arbitrary integer unique identifier for a client connection. +/// Identifies a client connection. Unique for the duration of the server's lifetime. #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct ClientId(u32); +/// A connected client session with the websocket server. +#[derive(Debug)] +pub struct Client<'a>(&'a ConnectedClient); + +impl Client<'_> { + /// Returns the client ID. + pub fn id(&self) -> ClientId { + self.0.id + } +} + +/// Information about a client channel. +#[derive(Debug)] +pub struct ClientChannelView<'a> { + id: ClientChannelId, + topic: &'a str, +} + +impl ClientChannelView<'_> { + /// Returns the client channel ID. + pub fn id(&self) -> ClientChannelId { + self.id + } + + /// Returns the topic of the client channel. + pub fn topic(&self) -> &str { + self.topic + } +} + +/// Information about a channel. +#[derive(Debug)] +pub struct ChannelView<'a> { + id: ChannelId, + topic: &'a str, +} + +impl ChannelView<'_> { + /// Returns the channel ID. + pub fn id(&self) -> ChannelId { + self.id + } + + /// Returns the topic of the channel. + pub fn topic(&self) -> &str { + self.topic + } +} + pub(crate) const SUBPROTOCOL: &str = "foxglove.sdk.v1"; type WebsocketSender = SplitSink, Message>; @@ -103,21 +152,26 @@ pub(crate) struct Server { /// handling client message events. pub trait ServerListener: Send + Sync { /// Callback invoked when a client message is received. - fn on_message_data(&self, _client_id: ClientId, _channel_id: ClientChannelId, _payload: &[u8]) { + fn on_message_data( + &self, + _client: Client, + _client_channel: ClientChannelView, + _payload: &[u8], + ) { } /// Callback invoked when a client subscribes to a channel. /// Only invoked if the channel is associated with the server and isn't already subscribed to by the client. - fn on_subscribe(&self, _client_id: ClientId, _channel_id: Arc) {} + fn on_subscribe(&self, _client: Client, _channel: ChannelView) {} /// Callback invoked when a client unsubscribes from a channel. /// Only invoked for channels that had an active subscription from the client. - fn on_unsubscribe(&self, _client_id: ClientId, _channel_id: Arc) {} + fn on_unsubscribe(&self, _client: Client, _channel: ChannelView) {} /// Callback invoked when a client advertises a client channel. Requires the "clientPublish" capability. - fn on_client_advertise(&self, _client_id: ClientId, _channel: &ClientChannel) {} + fn on_client_advertise(&self, _client: Client, _channel: ClientChannelView) {} /// Callback invoked when a client unadvertises a client channel. Requires the "clientPublish" capability. - fn on_client_unadvertise(&self, _client_id: ClientId, _channel_id: ClientChannelId) {} + fn on_client_unadvertise(&self, _client: Client, _channel: ClientChannelView) {} } -/// State for a client maintained by the server +/// A connected client session with the websocket server. struct ConnectedClient { id: ClientId, addr: SocketAddr, @@ -130,7 +184,7 @@ struct ConnectedClient { /// Subscriptions from this client subscriptions: parking_lot::Mutex>, /// Channels advertised by this client - advertised_channels: parking_lot::Mutex>, + advertised_channels: parking_lot::Mutex>>, /// Optional callback handler for a server implementation server_listener: Option>, server: Weak, @@ -191,9 +245,9 @@ impl ConnectedClient { Some(protocol::client::BinaryOpcode::MessageData) => { match protocol::client::parse_binary_message(&msg_bytes) { Ok((channel_id, payload)) => { - { + let client_channel = { let advertised_channels = self.advertised_channels.lock(); - if !advertised_channels.contains_key(&channel_id) { + let Some(channel) = advertised_channels.get(&channel_id) else { tracing::error!( "Received message for unknown channel: {}", channel_id @@ -201,11 +255,19 @@ impl ConnectedClient { self.send_error(format!("Unknown channel ID: {}", channel_id)); // Do not forward to server listener return; - } - } + }; + channel.clone() + }; // Call the handler after releasing the advertised_channels lock if let Some(handler) = self.server_listener.as_ref() { - handler.on_message_data(self.id, channel_id, payload); + handler.on_message_data( + Client(self), + ClientChannelView { + id: client_channel.id, + topic: &client_channel.topic, + }, + payload, + ); } } Err(err) => { @@ -225,13 +287,14 @@ impl ConnectedClient { } fn on_unadvertise(&self, mut channel_ids: Vec) { + let mut client_channels = Vec::with_capacity(channel_ids.len()); // Using a limited scope and iterating twice to avoid holding the lock on advertised_channels while calling on_client_unadvertise { let mut advertised_channels = self.advertised_channels.lock(); let mut i = 0; while i < channel_ids.len() { let id = channel_ids[i]; - if advertised_channels.remove(&id).is_none() { + let Some(channel) = advertised_channels.remove(&id) else { // Remove the channel ID from the list so we don't invoke the on_client_unadvertise callback channel_ids.swap_remove(i); self.send_warning(format!( @@ -239,14 +302,21 @@ impl ConnectedClient { id )); continue; - } + }; + client_channels.push(channel.clone()); i += 1; } } // Call the handler after releasing the advertised_channels lock if let Some(handler) = self.server_listener.as_ref() { - for id in channel_ids { - handler.on_client_unadvertise(self.id, id); + for (id, client_channel) in channel_ids.iter().cloned().zip(client_channels) { + handler.on_client_unadvertise( + Client(self), + ClientChannelView { + id, + topic: &client_channel.topic, + }, + ); } } } @@ -259,7 +329,7 @@ impl ConnectedClient { for channel in channels { // Using a limited scope here to avoid holding the lock on advertised_channels while calling on_client_advertise - { + let client_channel = { match self.advertised_channels.lock().entry(channel.id) { Entry::Occupied(_) => { self.send_warning(format!( @@ -269,14 +339,22 @@ impl ConnectedClient { continue; } Entry::Vacant(entry) => { - entry.insert(channel.clone()); + let client_channel = Arc::new(channel); + entry.insert(client_channel.clone()); + client_channel } } - } + }; // Call the handler after releasing the advertised_channels lock if let Some(handler) = self.server_listener.as_ref() { - handler.on_client_advertise(self.id, &channel); + handler.on_client_advertise( + Client(self), + ClientChannelView { + id: client_channel.id, + topic: &client_channel.topic, + }, + ); } } } @@ -311,7 +389,13 @@ impl ConnectedClient { // Finally call the handler for each channel for channel in unsubscribed_channels { - handler.on_unsubscribe(self.id, channel); + handler.on_unsubscribe( + Client(self), + ChannelView { + id: channel.id, + topic: &channel.topic, + }, + ); } } @@ -371,7 +455,13 @@ impl ConnectedClient { subscription.id ); if let Some(handler) = self.server_listener.as_ref() { - handler.on_subscribe(self.id, channel); + handler.on_subscribe( + Client(self), + ChannelView { + id: channel.id, + topic: &channel.topic, + }, + ); } } } @@ -401,6 +491,15 @@ impl ConnectedClient { } } +impl std::fmt::Debug for ConnectedClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("id", &self.id) + .field("address", &self.addr) + .finish() + } +} + // A websocket server that implements the Foxglove WebSocket Protocol impl Server { pub fn new(weak_self: Weak, opts: ServerOptions) -> Self {