From 15bfa6337bc5f5492c67336f6afcbdfed10fa380 Mon Sep 17 00:00:00 2001 From: banocean <47253870+banocean@users.noreply.github.com> Date: Fri, 26 Jul 2024 14:58:01 +0200 Subject: [PATCH] store unsaved changes --- Cargo.toml | 4 +- src/server/guild/editing.rs | 65 ++++++++++++ src/server/guild/ws.rs | 179 ++++++++++++++++++++++++++++++++ src/server/mod.rs | 6 ++ src/server/routes/guilds/_id.rs | 104 ++++--------------- src/server/routes/guilds/mod.rs | 2 +- 6 files changed, 277 insertions(+), 83 deletions(-) create mode 100644 src/server/guild/editing.rs create mode 100644 src/server/guild/ws.rs diff --git a/Cargo.toml b/Cargo.toml index 8361f19..409ac1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ serde = "1.0" async-trait = "0.1.57" futures-util = "0.3.19" tokio = "1.16.1" +tokio-stream = "0.1" mongodb = "2.1.0" redis = { version = "0.24", features = ["aio", "tokio-comp"] } @@ -37,6 +38,7 @@ hex = { version = "0.4", optional = true } warp = { version = "0.3", optional = true } rusty_paseto = { version = "0.6", features = ["core", "v4_local"], optional = true } anyhow = { version = "1.0", optional = true } +json-patch = { version = "2.0", optional = true } [features] all = ["custom-clients", "tasks", "http-interactions", "api", "gateway"] @@ -44,4 +46,4 @@ custom-clients = [] tasks = [] http-interactions = ["dep:warp", "dep:hex", "dep:anyhow", "dep:ed25519-dalek"] gateway = ["dep:regex", "dep:twilight-gateway"] -api = ["dep:warp", "dep:rusty_paseto", "dep:serde_urlencoded", "dep:anyhow", "reqwest/json"] +api = ["dep:warp", "dep:rusty_paseto", "dep:serde_urlencoded", "dep:anyhow", "dep:json-patch", "reqwest/json"] diff --git a/src/server/guild/editing.rs b/src/server/guild/editing.rs new file mode 100644 index 0000000..4b3e725 --- /dev/null +++ b/src/server/guild/editing.rs @@ -0,0 +1,65 @@ +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; +use mongodb::bson::oid::ObjectId; +use serde_json::Value; +use tokio::sync::{Mutex, RwLock}; +use twilight_model::id::Id; +use twilight_model::id::marker::{GuildMarker, UserMarker}; +use crate::context::Context; +use crate::server::guild::ws::{Connection, OutboundAction, OutboundMessage}; + +struct GuildEditingState { + pub connections: Vec>, + pub changes: Value, + pub edited_by: BTreeSet> +} + +#[derive(Default)] +pub struct GuildsEditing(RwLock, Arc>>>); + +impl GuildsEditing { + pub async fn add_connection(&self, guild_id: Id, connection_data: Connection) { + todo!() + } + + pub async fn remove_connection(&self, guild_id: Id, session_id: ObjectId) { + todo!() + } + + pub async fn marge_changes( + &self, + author: Id, + guild_id: Id, + changes: Value + ) -> Option<()> { + let guild = self.get_guild(guild_id).await?; + let mut guild_lock = guild.lock().await; + json_patch::merge(&mut guild_lock.changes, &changes); + guild_lock.edited_by.insert(author); + Some(()) + } + + async fn get_guild(&self, guild_id: Id) -> Option>> { + let list_lock = self.0.read().await; + list_lock.get(&guild_id).cloned() + } + + pub async fn broadcast_changes(&self, context: &Arc, guild_id: Id) -> Option<()> { + let config = context.mongodb.get_config(guild_id).await.ok()?; + let guild = self.get_guild(guild_id).await?; + let guild_lock = guild.lock().await; + let users = guild_lock.connections + .iter().map(|connection| connection.user_id) + .collect::>>(); + + for connection in &guild_lock.connections { + let _ = connection.tx.send(OutboundAction::Message(OutboundMessage::UpdateConfigurationData { + saved_config: config.to_owned(), + changes: guild_lock.changes.to_owned(), + users: users.to_owned(), + })); + } + + Some(()) + } +} diff --git a/src/server/guild/ws.rs b/src/server/guild/ws.rs new file mode 100644 index 0000000..9e00522 --- /dev/null +++ b/src/server/guild/ws.rs @@ -0,0 +1,179 @@ +use std::borrow::Cow; +use std::sync::Arc; +use futures_util::{SinkExt, StreamExt}; +use mongodb::bson::oid::ObjectId; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::sync::mpsc::UnboundedSender; +use tokio_stream::wrappers::UnboundedReceiverStream; +use twilight_model::id::Id; +use twilight_model::id::marker::UserMarker; +use twilight_model::user::CurrentUserGuild; +use warp::ws::{Message, WebSocket}; +use crate::context::Context; +use crate::database::redis::PartialGuild; +use crate::models::config::GuildConfig; +use crate::ok_or_return; +use crate::server::guild::editing::GuildsEditing; +use crate::server::session::AuthorizationInformation; + +macro_rules! close { + ($tx: expr, $reason: expr) => { + let _ = $tx.send(OutboundAction::Close($reason)); + }; +} + +macro_rules! unwrap_or_close_and_return { + ($target: expr, $tx: expr, $reason: expr) => { + match $target { + Ok(value) => value, + Err(_) => { + close!($tx, $reason); + return + } + } + }; +} + +pub enum CloseReason { + MessageIsNotString, + CannotParseJSON +} + +impl CloseReason { + pub fn code(&self) -> u16 { + match self { + CloseReason::MessageIsNotString => 4001, + CloseReason::CannotParseJSON => 4002, + } + } + + pub fn text(&self) -> impl Into> { + match self { + CloseReason::MessageIsNotString => "Message is not UTF-8 string", + CloseReason::CannotParseJSON => "Cannot parse JSON message" + } + } +} + +pub struct Connection { + pub user_id: Id, + pub session_id: ObjectId, + pub tx: UnboundedSender +} + +pub async fn handle_connection( + context: Arc, + ws: WebSocket, + info: Arc, + guild: CurrentUserGuild, + guilds_editing: Arc +) { + let (mut ws_tx, mut ws_rx) = ws.split(); + + let (tx, rx) = + tokio::sync::mpsc::unbounded_channel(); + + let mut rx = UnboundedReceiverStream::new(rx); + + tokio::spawn(async move { + while let Some(message) = rx.next().await { + match message { + OutboundAction::Message(msg) => { + if let Ok(data) = serde_json::to_string(&msg) { + let _ = ws_tx.send(Message::text(data)).await; + } + } + OutboundAction::Close(reason) => { + let _ = ws_tx.send( + Message::close_with(reason.code(), reason.text()) + ).await; + } + } + } + let _ = ws_tx.close().await; + }); + + let session_id = ObjectId::new(); + let guild_id = guild.id; + + guilds_editing.add_connection(guild_id, Connection { + user_id: info.user.id, + session_id, + tx: tx.to_owned(), + }).await; + + let _ = tx.send(OutboundAction::Message(OutboundMessage::Initialization { + cached: ok_or_return!(context.redis.get_guild(guild.id).await, Ok), + oauth2: guild.to_owned(), + session_id + })); + + while let Some(result) = ws_rx.next().await { + let message = match result { + Ok(message) => message, + Err(_) => { + break + } + }; + + if !message.is_text() { + break + } + + on_message(message, &info, &guild, &tx, &guilds_editing, &context).await; + } + + guilds_editing.remove_connection(guild_id, session_id).await; +} +#[derive(Debug, Deserialize)] +#[serde(tag = "action", content = "data")] +enum InboundMessage { + GuildConfigUpdate(Value), + ApplyChanges +} + +#[derive(Debug, Serialize)] +#[serde(tag = "action", content = "data")] +pub enum OutboundMessage { + Initialization { + oauth2: CurrentUserGuild, + cached: PartialGuild, + session_id: ObjectId + }, + UpdateConfigurationData { + saved_config: GuildConfig, + changes: Value, + users: Vec> + } +} + +pub enum OutboundAction { + Message(OutboundMessage), + Close(CloseReason) +} + +async fn on_message( + message: Message, + info: &Arc, + guild: &CurrentUserGuild, + tx: &UnboundedSender, + guilds_editing: &Arc, + context: &Arc +) { + let message = unwrap_or_close_and_return!( + message.to_str(), tx, CloseReason::MessageIsNotString + ); + + let message: InboundMessage = unwrap_or_close_and_return!( + serde_json::from_str(message), tx, CloseReason::CannotParseJSON + ); + + match message { + InboundMessage::GuildConfigUpdate(changes) => { + let _ = guilds_editing.marge_changes(info.user.id, guild.id, changes).await; + guilds_editing.broadcast_changes(context, guild.id).await; + } + InboundMessage::ApplyChanges => {} + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 78fa73f..e7e43a6 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -23,6 +23,12 @@ mod session; #[cfg(feature = "http-interactions")] pub mod authorize; +#[cfg(feature = "api")] +pub mod guild { + pub mod editing; + pub mod ws; +} + mod http_server { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; diff --git a/src/server/routes/guilds/_id.rs b/src/server/routes/guilds/_id.rs index 26fb632..fe27412 100644 --- a/src/server/routes/guilds/_id.rs +++ b/src/server/routes/guilds/_id.rs @@ -1,19 +1,18 @@ -use std::collections::HashMap; use std::sync::Arc; -use futures_util::{SinkExt, StreamExt, TryFutureExt}; + +use futures_util::StreamExt; use twilight_model::guild::Permissions; use twilight_model::id::Id; use twilight_model::id::marker::GuildMarker; -use warp::Filter; -use futures_util::FutureExt; -use futures_util::stream::SplitSink; -use serde::{Deserialize, Deserializer}; -use serde_json::Value; use twilight_model::user::CurrentUserGuild; -use warp::ws::{Message, WebSocket, Ws}; -use crate::context::Context; +use warp::Filter; +use warp::ws::Ws; + use crate::{response_type, with_value}; +use crate::context::Context; use crate::server::error::{MapErrorIntoInternalRejection, Rejection}; +use crate::server::guild::editing::GuildsEditing; +use crate::server::guild::ws::handle_connection; use crate::server::session::{Authenticator, AuthorizationInformation, authorize_user, Sessions}; type GuildId = Id; @@ -25,14 +24,26 @@ pub fn run( ) -> response_type!() { let with_context = with_value!(context); + let guilds_editing = Arc::new(GuildsEditing::default()); + let with_guilds_editing = with_value!(guilds_editing); + warp::path!("guilds" / GuildId) .and(authorize_user(authenticator, sessions)) - .and(with_context) + .and(with_context.clone()) .and_then(check_guild) .and(warp::ws()) - .map(|(info, guild): (Arc, CurrentUserGuild), ws: Ws| { + .and(with_context) + .and(with_guilds_editing) + .map(| + (info, guild): (Arc, CurrentUserGuild), + ws: Ws, + context: Arc, + guilds_editing: Arc + | { + let context = context.clone(); + let guilds_editing = guilds_editing.clone(); ws.on_upgrade(move |ws| { - handle_connection(ws, info, guild) + handle_connection(context, ws, info, guild, guilds_editing) }) }) } @@ -55,72 +66,3 @@ async fn check_guild( Ok((info, guild)) } - -macro_rules! close { - ($tx: expr) => { - let _ = $tx.close().await; - }; -} - -macro_rules! unwrap_or_close_and_return { - ($target: expr, $tx: expr) => { - match $target { - Ok(value) => value, - Err(_) => { - close!($tx); - return - } - } - }; -} - -async fn handle_connection( - ws: WebSocket, - info: Arc, - guild: CurrentUserGuild -) { - let (mut tx, mut rx) = ws.split(); - - while let Some(result) = rx.next().await { - let message = match result { - Ok(message) => message, - Err(_) => { - close!(tx); - break - } - }; - - if !message.is_text() { - close!(tx); - break - } - - on_message(message, &info, &guild, &mut tx).await; - } -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "action", content = "data")] -enum InboundMessage { - GuildConfigUpdate(HashMap), - ApplyChanges -} - -async fn on_message( - message: Message, - _info: &Arc, - _guild: &CurrentUserGuild, - tx: &mut SplitSink -) { - let message = unwrap_or_close_and_return!(message.to_str(), tx); - let message: InboundMessage = unwrap_or_close_and_return!( - serde_json::from_str(message), tx - ); - - match message { - InboundMessage::GuildConfigUpdate(_) => {} - InboundMessage::ApplyChanges => {} - } - - () -} diff --git a/src/server/routes/guilds/mod.rs b/src/server/routes/guilds/mod.rs index c7fb6d6..3fe4d88 100644 --- a/src/server/routes/guilds/mod.rs +++ b/src/server/routes/guilds/mod.rs @@ -37,7 +37,7 @@ async fn run( let ids: Vec> = guilds .iter().map(|guild| guild.id).collect(); - let mut pipe = redis::pipe().atomic().to_owned(); + let mut pipe = redis::pipe().to_owned(); ids.iter().for_each(|id| { pipe.exists(format!("guilds.{id}")); }); let result: Vec = pipe.query_async(