From 62680feeb17e4ff8fccfe2a66754fff1d2739861 Mon Sep 17 00:00:00 2001 From: Eloff Date: Fri, 7 Feb 2025 11:55:16 -0700 Subject: [PATCH] Core implement ws server listener interface for subs and adverts (#189) Add new callbacks to ServerListener: ``` fn on_subscribe(&self, _channel_id: ChannelId) {} fn on_unsubscribe(&self, _channel_id: ChannelId) {} fn on_client_advertise(&self, _channel: &ClientChannel) {} fn on_client_unadvertise(&self, _channel_id: ClientChannelId) {} ``` This mirrors how the callbacks are defined in the [Python implementation](https://github.com/foxglove/ws-protocol/blob/main/python/src/foxglove_websocket/server/__init__.py#L508), and they function in a similar way (e.g. not firing for duplicate or erroneous requests) I added RecordingServerListener to testutil to allow more easily testing ServerListener. I modified existing tests to verify these callbacks are called at the appropriate times with the expected arguments, and not called for duplicate requests. --- Cargo.lock | 1 + rust/foxglove/Cargo.toml | 1 + .../examples/unstable/client-publish.rs | 7 +- rust/foxglove/src/cow_vec.rs | 6 +- rust/foxglove/src/lib.rs | 3 +- rust/foxglove/src/tests/logging.rs | 5 +- rust/foxglove/src/testutil.rs | 104 +++++ rust/foxglove/src/websocket.rs | 394 ++++++++++++++---- .../foxglove/src/websocket/protocol/client.rs | 2 +- rust/foxglove/src/websocket/tests.rs | 57 ++- rust/foxglove/src/websocket/unstable_tests.rs | 75 ++-- 11 files changed, 525 insertions(+), 130 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8fa9a8bd..1900357e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -507,6 +507,7 @@ dependencies = [ "arc-swap", "assert_matches", "base64", + "bimap", "bytes", "clap", "ctrlc", diff --git a/rust/foxglove/Cargo.toml b/rust/foxglove/Cargo.toml index 2c8948a0..482cf23d 100644 --- a/rust/foxglove/Cargo.toml +++ b/rust/foxglove/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT" unstable = [] [dependencies] +bimap = "0.6.3" schemars = "0.8.21" arc-swap = "1.7.1" base64 = "0.22.1" diff --git a/rust/foxglove/examples/unstable/client-publish.rs b/rust/foxglove/examples/unstable/client-publish.rs index 3020e58c..f07d1fe8 100644 --- a/rust/foxglove/examples/unstable/client-publish.rs +++ b/rust/foxglove/examples/unstable/client-publish.rs @@ -11,7 +11,8 @@ use clap::Parser; use foxglove::schemas::log::Level; use foxglove::schemas::Log; use foxglove::{ - Capability, ClientChannelId, PartialMetadata, ServerListener, TypedChannel, WebSocketServer, + Capability, Client, ClientChannelView, PartialMetadata, ServerListener, TypedChannel, + WebSocketServer, }; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; @@ -29,10 +30,10 @@ struct Cli { struct ExampleCallbackHandler; impl ServerListener for ExampleCallbackHandler { - fn on_message_data(&self, channel_id: ClientChannelId, message: &[u8]) { + fn on_message_data(&self, _client: Client, channel: ClientChannelView, message: &[u8]) { let json: serde_json::Value = serde_json::from_slice(message).expect("Failed to parse message"); - println!("Received message on channel {channel_id}: {json}"); + println!("Received message on channel {}: {json}", channel.id()); } } diff --git a/rust/foxglove/src/cow_vec.rs b/rust/foxglove/src/cow_vec.rs index 1bef8ce9..d8ac27d4 100644 --- a/rust/foxglove/src/cow_vec.rs +++ b/rust/foxglove/src/cow_vec.rs @@ -93,18 +93,16 @@ mod tests { let handle = thread::spawn(move || { vec2.push(4); - vec2.get().to_vec() }); vec.push(5); - let thread_state = handle.join().unwrap(); + handle.join().unwrap(); let final_state = vec.get().to_vec(); // Old snapshot should still be valid and have original length assert_eq!(old_snapshot.len(), 3); - // Both threads should see 5 items in their final state + // There should now be 5 items in the final state assert_eq!(final_state.len(), 5); - assert_eq!(thread_state.len(), 5); } #[test] diff --git a/rust/foxglove/src/lib.rs b/rust/foxglove/src/lib.rs index 86de0856..ef27ae08 100644 --- a/rust/foxglove/src/lib.rs +++ b/rust/foxglove/src/lib.rs @@ -193,8 +193,9 @@ pub use runtime::shutdown_runtime; pub(crate) use time::nanoseconds_since_epoch; #[doc(hidden)] #[cfg(feature = "unstable")] +pub use websocket::{Capability, Parameter, ParameterType, ParameterValue}; pub use websocket::{ - Capability, ClientChannelId, Parameter, ParameterType, ParameterValue, ServerListener, + ChannelView, Client, ClientChannelId, ClientChannelView, ClientId, ServerListener, }; pub use websocket_server::{WebSocketServer, WebSocketServerBlockingHandle, WebSocketServerHandle}; diff --git a/rust/foxglove/src/tests/logging.rs b/rust/foxglove/src/tests/logging.rs index 9087fbeb..a516f5ae 100644 --- a/rust/foxglove/src/tests/logging.rs +++ b/rust/foxglove/src/tests/logging.rs @@ -41,7 +41,7 @@ async fn test_logging_to_file_and_live_sinks() { ws_stream }; - // FG-9877: allow the server to handle the connection before creating the channel + // FG-10395 replace this with something more precise tokio::time::sleep(Duration::from_millis(100)).await; let channel = ChannelBuilder::new("/test-topic") @@ -103,6 +103,7 @@ async fn test_logging_to_file_and_live_sinks() { .expect("Failed to subscribe"); // Let subscription register before publishing + // FG-10395 replace this with something more precise tokio::time::sleep(Duration::from_millis(100)).await; } @@ -122,7 +123,7 @@ async fn test_logging_to_file_and_live_sinks() { channel.log(&msg); - // Ensure message has arrived + // FG-10395 replace this with something more precise tokio::time::sleep(Duration::from_millis(100)).await; let writer = handle.close().expect("Failed to flush log"); diff --git a/rust/foxglove/src/testutil.rs b/rust/foxglove/src/testutil.rs index 0d21bb56..212008da 100644 --- a/rust/foxglove/src/testutil.rs +++ b/rust/foxglove/src/testutil.rs @@ -3,5 +3,109 @@ mod log_context; mod log_sink; +use crate::channel::ChannelId; +use crate::websocket::{ + ChannelView, Client, ClientChannelId, ClientChannelView, ClientId, ServerListener, +}; pub use log_context::GlobalContextTest; pub use log_sink::{ErrorSink, MockSink, RecordingSink}; +use parking_lot::Mutex; + +#[allow(dead_code)] +pub(crate) struct ClientChannelInfo { + pub(crate) id: ClientChannelId, + pub(crate) topic: String, +} + +impl From> for ClientChannelInfo { + fn from(channel: ClientChannelView) -> Self { + Self { + id: channel.id(), + topic: channel.topic().to_string(), + } + } +} + +pub(crate) struct ChannelInfo { + pub(crate) id: ChannelId, + pub(crate) topic: String, +} + +impl From> for ChannelInfo { + fn from(channel: ChannelView) -> Self { + Self { + id: channel.id(), + topic: channel.topic().to_string(), + } + } +} + +pub(crate) struct RecordingServerListener { + message_data: Mutex)>>, + subscribe: Mutex>, + unsubscribe: Mutex>, + client_advertise: Mutex>, + client_unadvertise: Mutex>, +} + +impl RecordingServerListener { + pub fn new() -> Self { + Self { + message_data: Mutex::new(Vec::new()), + subscribe: Mutex::new(Vec::new()), + unsubscribe: Mutex::new(Vec::new()), + client_advertise: Mutex::new(Vec::new()), + client_unadvertise: Mutex::new(Vec::new()), + } + } + + #[allow(dead_code)] + pub fn take_message_data(&self) -> Vec<(ClientId, ClientChannelInfo, Vec)> { + std::mem::take(&mut self.message_data.lock()) + } + + pub fn take_subscribe(&self) -> Vec<(ClientId, ChannelInfo)> { + std::mem::take(&mut self.subscribe.lock()) + } + + pub fn take_unsubscribe(&self) -> Vec<(ClientId, ChannelInfo)> { + std::mem::take(&mut self.unsubscribe.lock()) + } + + #[allow(dead_code)] + pub fn take_client_advertise(&self) -> Vec<(ClientId, ClientChannelInfo)> { + std::mem::take(&mut self.client_advertise.lock()) + } + + #[allow(dead_code)] + pub fn take_client_unadvertise(&self) -> Vec<(ClientId, ClientChannelInfo)> { + std::mem::take(&mut self.client_unadvertise.lock()) + } +} + +impl ServerListener for RecordingServerListener { + fn on_message_data(&self, client: Client, channel: ClientChannelView, payload: &[u8]) { + let mut data = self.message_data.lock(); + data.push((client.id(), channel.into(), payload.to_vec())); + } + + fn on_subscribe(&self, client: Client, channel: ChannelView) { + let mut subs = self.subscribe.lock(); + subs.push((client.id(), channel.into())); + } + + fn on_unsubscribe(&self, client: Client, channel: ChannelView) { + let mut unsubs = self.unsubscribe.lock(); + unsubs.push((client.id(), channel.into())); + } + + fn on_client_advertise(&self, client: Client, channel: ClientChannelView) { + let mut adverts = self.client_advertise.lock(); + adverts.push((client.id(), channel.into())); + } + + fn on_client_unadvertise(&self, client: Client, channel: ClientChannelView) { + let mut unadverts = self.client_unadvertise.lock(); + unadverts.push((client.id(), channel.into())); + } +} diff --git a/rust/foxglove/src/websocket.rs b/rust/foxglove/src/websocket.rs index b853052c..b10c8101 100644 --- a/rust/foxglove/src/websocket.rs +++ b/rust/foxglove/src/websocket.rs @@ -1,5 +1,6 @@ use crate::channel::ChannelId; use crate::cow_vec::CowVec; +use crate::websocket::protocol::client::Subscription; pub use crate::websocket::protocol::client::{ ClientChannel, ClientChannelId, ClientMessage, SubscriptionId, }; @@ -7,12 +8,14 @@ pub use crate::websocket::protocol::server::Capability; #[cfg(feature = "unstable")] pub use crate::websocket::protocol::server::{Parameter, ParameterType, ParameterValue}; use crate::{get_runtime_handle, Channel, FoxgloveError, LogSink, Metadata}; +use bimap::BiHashMap; use bytes::{BufMut, BytesMut}; use flume::TrySendError; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; +use std::collections::hash_map::Entry; use std::collections::HashSet; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::{AcqRel, Acquire}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicU32}; use std::sync::Weak; use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use thiserror::Error; @@ -33,6 +36,59 @@ mod tests; #[cfg(all(test, feature = "unstable"))] mod unstable_tests; +/// Identifies a client connection. Unique for the duration of the server's lifetime. +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct ClientId(u32); + +/// A connected client session with the websocket server. +#[derive(Debug)] +pub struct Client<'a>(&'a ConnectedClient); + +impl Client<'_> { + /// Returns the client ID. + pub fn id(&self) -> ClientId { + self.0.id + } +} + +/// Information about a client channel. +#[derive(Debug)] +pub struct ClientChannelView<'a> { + id: ClientChannelId, + topic: &'a str, +} + +impl ClientChannelView<'_> { + /// Returns the client channel ID. + pub fn id(&self) -> ClientChannelId { + self.id + } + + /// Returns the topic of the client channel. + pub fn topic(&self) -> &str { + self.topic + } +} + +/// Information about a channel. +#[derive(Debug)] +pub struct ChannelView<'a> { + id: ChannelId, + topic: &'a str, +} + +impl ChannelView<'_> { + /// Returns the channel ID. + pub fn id(&self) -> ChannelId { + self.id + } + + /// Returns the topic of the channel. + pub fn topic(&self) -> &str { + self.topic + } +} + pub(crate) const SUBPROTOCOL: &str = "foxglove.sdk.v1"; type WebsocketSender = SplitSink, Message>; @@ -98,11 +154,28 @@ pub(crate) struct Server { /// handling client message events. pub trait ServerListener: Send + Sync { /// Callback invoked when a client message is received. - fn on_message_data(&self, channel_id: ClientChannelId, payload: &[u8]); + fn on_message_data( + &self, + _client: Client, + _client_channel: ClientChannelView, + _payload: &[u8], + ) { + } + /// Callback invoked when a client subscribes to a channel. + /// Only invoked if the channel is associated with the server and isn't already subscribed to by the client. + fn on_subscribe(&self, _client: Client, _channel: ChannelView) {} + /// Callback invoked when a client unsubscribes from a channel. + /// Only invoked for channels that had an active subscription from the client. + fn on_unsubscribe(&self, _client: Client, _channel: ChannelView) {} + /// Callback invoked when a client advertises a client channel. Requires the "clientPublish" capability. + fn on_client_advertise(&self, _client: Client, _channel: ClientChannelView) {} + /// Callback invoked when a client unadvertises a client channel. Requires the "clientPublish" capability. + fn on_client_unadvertise(&self, _client: Client, _channel: ClientChannelView) {} } -/// State for a client maintained by the server +/// A connected client session with the websocket server. struct ConnectedClient { + id: ClientId, addr: SocketAddr, /// Write side of a WS stream sender: Mutex, @@ -111,9 +184,9 @@ struct ConnectedClient { control_plane_tx: flume::Sender, control_plane_rx: flume::Receiver, /// Subscriptions from this client - subscriptions_by_channel: parking_lot::Mutex>, + subscriptions: parking_lot::Mutex>, /// Channels advertised by this client - advertised_channels: parking_lot::Mutex>, + advertised_channels: parking_lot::Mutex>>, /// Optional callback handler for a server implementation server_listener: Option>, server: Weak, @@ -137,60 +210,23 @@ impl ConnectedClient { self.send_error(format!("Invalid message: {message}")); return; }; + let Some(server) = self.server.upgrade() else { + tracing::error!("Server closed"); + return; + }; match serde_json::from_str::(message) { Ok(ClientMessage::Subscribe { subscriptions }) => { - let Some(server) = self.server.upgrade() else { - tracing::error!("Server closed"); - return; - }; - let channels = server.channels.read(); - for subscription in subscriptions { - if !channels.contains_key(&subscription.channel_id) { - tracing::error!( - "Client {} attempted to subscribe to unknown channel: {}", - self.addr, - subscription.channel_id - ); - self.send_error(format!("Unknown channel ID: {}", subscription.channel_id)); - continue; - } - - let mut subscriptions = self.subscriptions_by_channel.lock(); - subscriptions.insert(subscription.channel_id, subscription.id); - tracing::info!( - "Client {} subscribed to channel {} with subscription id {}", - self.addr, - subscription.channel_id, - subscription.id - ); - } + self.on_subscribe(server, subscriptions); } Ok(ClientMessage::Unsubscribe { subscription_ids }) => { - let mut subscriptions = self.subscriptions_by_channel.lock(); - subscriptions - .retain(|_, subscription_id| !subscription_ids.contains(subscription_id)) + self.on_unsubscribe(server, subscription_ids); } Ok(ClientMessage::Advertise { channels }) => { - let Some(server) = self.server.upgrade() else { - tracing::info!("Server closed; ignoring client advertisement"); - return; - }; - if !server.capabilities.contains(&Capability::ClientPublish) { - self.send_error("Server does not support clientPublish capability".to_string()); - return; - } - - let mut advertised_channels = self.advertised_channels.lock(); - for channel in channels { - advertised_channels.insert(channel.id, channel); - } + self.on_advertise(server, channels); } Ok(ClientMessage::Unadvertise { channel_ids }) => { - let mut advertised_channels = self.advertised_channels.lock(); - for id in channel_ids { - advertised_channels.remove(&id); - } + self.on_unadvertise(channel_ids); } _ => { tracing::error!("Unsupported message from {}: {message}", self.addr); @@ -199,20 +235,6 @@ impl ConnectedClient { } } - /// Send an ad hoc error status message to the client, with the given message. - fn send_error(&self, message: String) { - let status = protocol::server::Status { - level: protocol::server::StatusLevel::Error, - message: message.to_string(), - id: None, - }; - let message = Message::text(serde_json::to_string(&status).unwrap()); - // If the message can't be sent, or the outbox is full, log a warning and continue. - self.data_plane_tx.try_send(message).unwrap_or_else(|err| { - tracing::warn!("Failed to send status to client {}: {err}", self.addr) - }); - } - fn handle_binary_message(&self, message: Message) { if message.is_empty() { tracing::debug!("Received empty binary message from {}", self.addr); @@ -225,9 +247,9 @@ impl ConnectedClient { Some(protocol::client::BinaryOpcode::MessageData) => { match protocol::client::parse_binary_message(&msg_bytes) { Ok((channel_id, payload)) => { - { + let client_channel = { let advertised_channels = self.advertised_channels.lock(); - if !advertised_channels.contains_key(&channel_id) { + let Some(channel) = advertised_channels.get(&channel_id) else { tracing::error!( "Received message for unknown channel: {}", channel_id @@ -235,10 +257,19 @@ impl ConnectedClient { self.send_error(format!("Unknown channel ID: {}", channel_id)); // Do not forward to server listener return; - } - } + }; + channel.clone() + }; + // Call the handler after releasing the advertised_channels lock if let Some(handler) = self.server_listener.as_ref() { - handler.on_message_data(channel_id, payload); + handler.on_message_data( + Client(self), + ClientChannelView { + id: client_channel.id, + topic: &client_channel.topic, + }, + payload, + ); } } Err(err) => { @@ -256,6 +287,219 @@ impl ConnectedClient { } } } + + fn on_unadvertise(&self, mut channel_ids: Vec) { + let mut client_channels = Vec::with_capacity(channel_ids.len()); + // Using a limited scope and iterating twice to avoid holding the lock on advertised_channels while calling on_client_unadvertise + { + let mut advertised_channels = self.advertised_channels.lock(); + let mut i = 0; + while i < channel_ids.len() { + let id = channel_ids[i]; + let Some(channel) = advertised_channels.remove(&id) else { + // Remove the channel ID from the list so we don't invoke the on_client_unadvertise callback + channel_ids.swap_remove(i); + self.send_warning(format!( + "Client is not advertising channel: {}; ignoring unadvertisement", + id + )); + continue; + }; + client_channels.push(channel.clone()); + i += 1; + } + } + // Call the handler after releasing the advertised_channels lock + if let Some(handler) = self.server_listener.as_ref() { + for (id, client_channel) in channel_ids.iter().cloned().zip(client_channels) { + handler.on_client_unadvertise( + Client(self), + ClientChannelView { + id, + topic: &client_channel.topic, + }, + ); + } + } + } + + fn on_advertise(&self, server: Arc, channels: Vec) { + if !server.capabilities.contains(&Capability::ClientPublish) { + self.send_error("Server does not support clientPublish capability".to_string()); + return; + } + + for channel in channels { + // Using a limited scope here to avoid holding the lock on advertised_channels while calling on_client_advertise + let client_channel = { + match self.advertised_channels.lock().entry(channel.id) { + Entry::Occupied(_) => { + self.send_warning(format!( + "Client is already advertising channel: {}; ignoring advertisement", + channel.id + )); + continue; + } + Entry::Vacant(entry) => { + let client_channel = Arc::new(channel); + entry.insert(client_channel.clone()); + client_channel + } + } + }; + + // Call the handler after releasing the advertised_channels lock + if let Some(handler) = self.server_listener.as_ref() { + handler.on_client_advertise( + Client(self), + ClientChannelView { + id: client_channel.id, + topic: &client_channel.topic, + }, + ); + } + } + } + + fn on_unsubscribe(&self, server: Arc, subscription_ids: Vec) { + let mut unsubscribed_channel_ids = Vec::with_capacity(subscription_ids.len()); + // First gather the unsubscribed channel ids while holding the subscriptions lock + { + let mut subscriptions = self.subscriptions.lock(); + for subscription_id in subscription_ids { + if let Some((channel_id, _)) = subscriptions.remove_by_right(&subscription_id) { + unsubscribed_channel_ids.push(channel_id); + } + } + } + + // If we don't have a ServerListener, we're done. + let Some(handler) = self.server_listener.as_ref() else { + return; + }; + + // Then gather the actual channel references while holding the channels lock + let mut unsubscribed_channels = Vec::with_capacity(unsubscribed_channel_ids.len()); + { + let channels = server.channels.read(); + for channel_id in unsubscribed_channel_ids { + if let Some(channel) = channels.get(&channel_id) { + unsubscribed_channels.push(channel.clone()); + } + } + } + + // Finally call the handler for each channel + for channel in unsubscribed_channels { + handler.on_unsubscribe( + Client(self), + ChannelView { + id: channel.id, + topic: &channel.topic, + }, + ); + } + } + + fn on_subscribe(&self, server: Arc, mut subscriptions: Vec) { + // First prune out any subscriptions for channels not in the channel map, + // limiting how long we need to hold the lock. + let mut subscribed_channels = Vec::with_capacity(subscriptions.len()); + { + let channels = server.channels.read(); + let mut i = 0; + while i < subscriptions.len() { + let subscription = &subscriptions[i]; + let Some(channel) = channels.get(&subscription.channel_id) else { + tracing::error!( + "Client {} attempted to subscribe to unknown channel: {}", + self.addr, + subscription.channel_id + ); + self.send_error(format!("Unknown channel ID: {}", subscription.channel_id)); + // Remove the subscription from the list so we don't invoke the on_subscribe callback for it + subscriptions.swap_remove(i); + continue; + }; + subscribed_channels.push(channel.clone()); + i += 1 + } + } + + for (subscription, channel) in subscriptions.into_iter().zip(subscribed_channels) { + // Using a limited scope here to avoid holding the lock on subscriptions while calling on_subscribe + { + let mut subscriptions = self.subscriptions.lock(); + if subscriptions + .insert_no_overwrite(subscription.channel_id, subscription.id) + .is_err() + { + if subscriptions.contains_left(&subscription.channel_id) { + self.send_warning(format!( + "Client is already subscribed to channel: {}; ignoring subscription", + subscription.channel_id + )); + } else { + assert!(subscriptions.contains_right(&subscription.id)); + self.send_error(format!( + "Subscription ID was already used: {}; ignoring subscription", + subscription.id + )); + } + continue; + } + } + + tracing::info!( + "Client {} subscribed to channel {} with subscription id {}", + self.addr, + subscription.channel_id, + subscription.id + ); + if let Some(handler) = self.server_listener.as_ref() { + handler.on_subscribe( + Client(self), + ChannelView { + id: channel.id, + topic: &channel.topic, + }, + ); + } + } + } + + /// Send an ad hoc error status message to the client, with the given message. + fn send_error(&self, message: String) { + self.send_status(protocol::server::StatusLevel::Error, message, None); + } + + /// Send an ad hoc warning status message to the client, with the given message. + fn send_warning(&self, message: String) { + self.send_status(protocol::server::StatusLevel::Warning, message, None); + } + + fn send_status( + &self, + level: protocol::server::StatusLevel, + message: String, + id: Option, + ) { + let status = protocol::server::Status { level, message, id }; + let message = Message::text(serde_json::to_string(&status).unwrap()); + // If the message can't be sent, or the outbox is full, log a warning and continue. + self.data_plane_tx.try_send(message).unwrap_or_else(|err| { + tracing::warn!("Failed to send status to client {}: {err}", self.addr) + }); + } +} + +impl std::fmt::Debug for ConnectedClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("id", &self.id) + .field("address", &self.addr) + .finish() + } } // A websocket server that implements the Foxglove WebSocket Protocol @@ -490,17 +734,21 @@ impl Server { return; } + static CLIENT_ID: AtomicU32 = AtomicU32::new(1); + let id = ClientId(CLIENT_ID.fetch_add(1, Relaxed)); + let (data_tx, data_rx) = flume::bounded(self.message_backlog_size as usize); let (ctrl_tx, ctrl_rx) = flume::bounded(DEFAULT_CONTROL_PLANE_BACKLOG_SIZE); let new_client = Arc::new(ConnectedClient { + id, addr, sender: Mutex::new(ws_sender), data_plane_tx: data_tx, data_plane_rx: data_rx, control_plane_tx: ctrl_tx, control_plane_rx: ctrl_rx, - subscriptions_by_channel: parking_lot::Mutex::new(HashMap::new()), + subscriptions: parking_lot::Mutex::new(BiHashMap::new()), advertised_channels: parking_lot::Mutex::new(HashMap::new()), server_listener: self.listener.clone(), server: self.weak_self.clone(), @@ -665,8 +913,8 @@ impl LogSink for Server { ) -> Result<(), FoxgloveError> { let clients = self.clients.get(); for client in clients.iter() { - let subscriptions = client.subscriptions_by_channel.lock(); - let Some(subscription_id) = subscriptions.get(&channel.id).cloned() else { + let subscriptions = client.subscriptions.lock(); + let Some(subscription_id) = subscriptions.get_by_left(&channel.id).cloned() else { continue; }; diff --git a/rust/foxglove/src/websocket/protocol/client.rs b/rust/foxglove/src/websocket/protocol/client.rs index 20bc8c22..b898b7f5 100644 --- a/rust/foxglove/src/websocket/protocol/client.rs +++ b/rust/foxglove/src/websocket/protocol/client.rs @@ -85,7 +85,7 @@ pub struct Subscription { } #[doc(hidden)] -#[derive(Debug, Deserialize, PartialEq, Serialize)] +#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)] #[serde(rename_all = "camelCase")] pub struct ClientChannel { pub id: ClientChannelId, diff --git a/rust/foxglove/src/websocket/tests.rs b/rust/foxglove/src/websocket/tests.rs index 408d49e7..7eb8e948 100644 --- a/rust/foxglove/src/websocket/tests.rs +++ b/rust/foxglove/src/websocket/tests.rs @@ -9,6 +9,7 @@ use super::{ create_server, send_lossy, ClientMessage, SendLossyResult, ServerOptions, SubscriptionId, SUBPROTOCOL, }; +use crate::testutil::RecordingServerListener; use crate::{collection, Channel, ChannelBuilder, LogContext, LogSink, Metadata, Schema}; fn make_message(id: usize) -> Message { @@ -169,7 +170,12 @@ async fn test_handshake_with_multiple_subprotocols() { #[tokio::test] async fn test_advertise_to_client() { - let server = create_server(ServerOptions::default()); + let recording_listener = Arc::new(RecordingServerListener::new()); + + let server = create_server(ServerOptions { + listener: Some(recording_listener.clone()), + ..Default::default() + }); let ctx = LogContext::new(); ctx.add_sink(server.clone()); @@ -210,13 +216,18 @@ async fn test_advertise_to_client() { } ] }); + client_sender + .send(Message::text(subscribe.to_string())) + .await + .expect("Failed to send"); + // Send a duplicate subscribe message (ignored) client_sender .send(Message::text(subscribe.to_string())) .await .expect("Failed to send"); // Allow the server to process the subscription - // FG-9723: replace this with an on_subscribe callback + // FG-10395 replace this with something more precise tokio::time::sleep(std::time::Duration::from_millis(100)).await; server.log(&ch, b"{\"a\":1}", &metadata).unwrap(); @@ -224,6 +235,16 @@ async fn test_advertise_to_client() { let result = client_receiver.next().await.unwrap(); let msg = result.expect("Failed to parse message"); let data = msg.into_data(); + let data_str = std::str::from_utf8(&data).unwrap(); + println!("data_str: {data_str}"); + assert!(data_str.contains("Client is already subscribed to channel")); + + let msg = client_receiver + .next() + .await + .unwrap() + .expect("Failed to parse message"); + let data = msg.into_data(); assert_eq!(data[0], 0x01); // message data opcode assert_eq!( @@ -231,12 +252,22 @@ async fn test_advertise_to_client() { subscription_id ); + let subscriptions = recording_listener.take_subscribe(); + assert_eq!(subscriptions.len(), 1); + assert_eq!(subscriptions[0].1.id, ch.id); + assert_eq!(subscriptions[0].1.topic, ch.topic); + server.stop().await; } #[tokio::test] async fn test_log_only_to_subscribers() { - let server = create_server(ServerOptions::default()); + let recording_listener = Arc::new(RecordingServerListener::new()); + + let server = create_server(ServerOptions { + listener: Some(recording_listener.clone()), + ..Default::default() + }); let ctx = LogContext::new(); @@ -325,9 +356,27 @@ async fn test_log_only_to_subscribers() { .expect("Failed to send"); // Allow the server to process the subscription - // FG-9723: replace this with an on_subscribe callback + // FG-10395 replace this with something more precise tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let subscriptions = recording_listener.take_subscribe(); + assert_eq!(subscriptions.len(), 4); + assert_eq!(subscriptions[0].1.id, ch1.id); + assert_eq!(subscriptions[1].1.id, ch2.id); + assert_eq!(subscriptions[2].1.id, ch1.id); + assert_eq!(subscriptions[3].1.id, ch2.id); + assert_eq!(subscriptions[0].1.topic, ch1.topic); + assert_eq!(subscriptions[1].1.topic, ch2.topic); + assert_eq!(subscriptions[2].1.topic, ch1.topic); + assert_eq!(subscriptions[3].1.topic, ch2.topic); + + let unsubscriptions = recording_listener.take_unsubscribe(); + assert_eq!(unsubscriptions.len(), 2); + assert_eq!(unsubscriptions[0].1.id, ch1.id); + assert_eq!(unsubscriptions[1].1.id, ch2.id); + assert_eq!(unsubscriptions[0].1.topic, ch1.topic); + assert_eq!(unsubscriptions[1].1.topic, ch2.topic); + let metadata = Metadata { log_time: 123456, ..Metadata::default() diff --git a/rust/foxglove/src/websocket/unstable_tests.rs b/rust/foxglove/src/websocket/unstable_tests.rs index befd2ff1..e7ee3c53 100644 --- a/rust/foxglove/src/websocket/unstable_tests.rs +++ b/rust/foxglove/src/websocket/unstable_tests.rs @@ -1,53 +1,24 @@ use std::{collections::HashSet, sync::Arc}; -use bytes::{Buf, BufMut, BytesMut}; -use futures_util::{SinkExt, StreamExt}; -use serde_json::{json, Value}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; -use tokio_tungstenite::tungstenite::Message; - use super::tests::connect_client; use super::{ create_server, protocol, Capability, ClientChannelId, Parameter, ParameterType, ParameterValue, - ServerListener, ServerOptions, + ServerOptions, }; - -struct ClientMessageData { - channel_id: ClientChannelId, - payload: Vec, -} - -struct MPSCServerListener(UnboundedSender); - -impl MPSCServerListener { - fn create() -> ( - Arc, - UnboundedReceiver, - ) { - let (tx, rx) = unbounded_channel(); - (Arc::new(Self(tx)), rx) - } -} - -impl ServerListener for MPSCServerListener { - fn on_message_data(&self, channel_id: ClientChannelId, message: &[u8]) { - self.0 - .send(ClientMessageData { - channel_id, - payload: message.to_vec(), - }) - .expect("MPSC queue closed"); - } -} +use crate::testutil::RecordingServerListener; +use bytes::{Buf, BufMut, BytesMut}; +use futures_util::{SinkExt, StreamExt}; +use serde_json::{json, Value}; +use tokio_tungstenite::tungstenite::Message; #[tokio::test] async fn test_client_advertising() { - let (listener, mut chan_rx) = MPSCServerListener::create(); + let recording_listener = Arc::new(RecordingServerListener::new()); let server = create_server(ServerOptions { capabilities: Some(HashSet::from([Capability::ClientPublish])), supported_encodings: Some(HashSet::from(["json".to_string()])), - listener: Some(listener), + listener: Some(recording_listener.clone()), ..Default::default() }); @@ -73,7 +44,7 @@ async fn test_client_advertising() { .await .expect("Failed to send binary message"); // No message sent to listener - assert!(chan_rx.try_recv().is_err()); + assert!(recording_listener.take_message_data().is_empty()); let advertise = json!({ "op": "advertise", @@ -121,10 +92,30 @@ async fn test_client_advertising() { .await .expect("Failed to send unadvertise"); + // Should be ignored + ws_client + .send(Message::text(unadvertise.to_string())) + .await + .expect("Failed to send unadvertise"); + + // FG-10395 replace this with something more precise + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + // Server should have received one message - let received = chan_rx.recv().await.expect("No message received"); - assert_eq!(received.channel_id, ClientChannelId::new(1)); - assert_eq!(received.payload, b"{\"a\":1}"); + let mut received = recording_listener.take_message_data(); + let (_, channel_info, payload) = received.pop().expect("No message received"); + assert_eq!(channel_info.id, ClientChannelId::new(1)); + assert_eq!(payload, b"{\"a\":1}"); + + // Server should have ignored the duplicate advertisement + let advertisements = recording_listener.take_client_advertise(); + assert_eq!(advertisements.len(), 1); + assert_eq!(advertisements[0].1.id, channel_info.id); + + // Server should have received one unadvertise (and ignored the duplicate) + let unadvertises = recording_listener.take_client_unadvertise(); + assert_eq!(unadvertises.len(), 1); + assert_eq!(unadvertises[0].1.id, channel_info.id); ws_client.close(None).await.unwrap(); server.stop().await; @@ -179,7 +170,7 @@ async fn test_parameter_values() { }; server.publish_parameter_values(vec![parameter], None).await; - // Allow the server to process the parameter values + // FG-10395 replace this with something more precise std::thread::sleep(std::time::Duration::from_millis(100)); let msg = ws_client.next().await.expect("No message received");