-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
277 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
use std::collections::{BTreeSet, HashMap}; | ||
use std::sync::Arc; | ||
use mongodb::bson::oid::ObjectId; | ||
use serde_json::Value; | ||
use tokio::sync::{Mutex, RwLock}; | ||
use twilight_model::id::Id; | ||
use twilight_model::id::marker::{GuildMarker, UserMarker}; | ||
use crate::context::Context; | ||
use crate::server::guild::ws::{Connection, OutboundAction, OutboundMessage}; | ||
|
||
struct GuildEditingState { | ||
pub connections: Vec<Arc<Connection>>, | ||
pub changes: Value, | ||
pub edited_by: BTreeSet<Id<UserMarker>> | ||
} | ||
|
||
#[derive(Default)] | ||
pub struct GuildsEditing(RwLock<HashMap<Id<GuildMarker>, Arc<Mutex<GuildEditingState>>>>); | ||
|
||
impl GuildsEditing { | ||
pub async fn add_connection(&self, guild_id: Id<GuildMarker>, connection_data: Connection) { | ||
todo!() | ||
} | ||
|
||
pub async fn remove_connection(&self, guild_id: Id<GuildMarker>, session_id: ObjectId) { | ||
todo!() | ||
} | ||
|
||
pub async fn marge_changes( | ||
&self, | ||
author: Id<UserMarker>, | ||
guild_id: Id<GuildMarker>, | ||
changes: Value | ||
) -> Option<()> { | ||
let guild = self.get_guild(guild_id).await?; | ||
let mut guild_lock = guild.lock().await; | ||
json_patch::merge(&mut guild_lock.changes, &changes); | ||
guild_lock.edited_by.insert(author); | ||
Some(()) | ||
} | ||
|
||
async fn get_guild(&self, guild_id: Id<GuildMarker>) -> Option<Arc<Mutex<GuildEditingState>>> { | ||
let list_lock = self.0.read().await; | ||
list_lock.get(&guild_id).cloned() | ||
} | ||
|
||
pub async fn broadcast_changes(&self, context: &Arc<Context>, guild_id: Id<GuildMarker>) -> Option<()> { | ||
let config = context.mongodb.get_config(guild_id).await.ok()?; | ||
let guild = self.get_guild(guild_id).await?; | ||
let guild_lock = guild.lock().await; | ||
let users = guild_lock.connections | ||
.iter().map(|connection| connection.user_id) | ||
.collect::<Vec<Id<UserMarker>>>(); | ||
|
||
for connection in &guild_lock.connections { | ||
let _ = connection.tx.send(OutboundAction::Message(OutboundMessage::UpdateConfigurationData { | ||
saved_config: config.to_owned(), | ||
changes: guild_lock.changes.to_owned(), | ||
users: users.to_owned(), | ||
})); | ||
} | ||
|
||
Some(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
use std::borrow::Cow; | ||
use std::sync::Arc; | ||
use futures_util::{SinkExt, StreamExt}; | ||
use mongodb::bson::oid::ObjectId; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::Value; | ||
use tokio::sync::mpsc::UnboundedSender; | ||
use tokio_stream::wrappers::UnboundedReceiverStream; | ||
use twilight_model::id::Id; | ||
use twilight_model::id::marker::UserMarker; | ||
use twilight_model::user::CurrentUserGuild; | ||
use warp::ws::{Message, WebSocket}; | ||
use crate::context::Context; | ||
use crate::database::redis::PartialGuild; | ||
use crate::models::config::GuildConfig; | ||
use crate::ok_or_return; | ||
use crate::server::guild::editing::GuildsEditing; | ||
use crate::server::session::AuthorizationInformation; | ||
|
||
macro_rules! close { | ||
($tx: expr, $reason: expr) => { | ||
let _ = $tx.send(OutboundAction::Close($reason)); | ||
}; | ||
} | ||
|
||
macro_rules! unwrap_or_close_and_return { | ||
($target: expr, $tx: expr, $reason: expr) => { | ||
match $target { | ||
Ok(value) => value, | ||
Err(_) => { | ||
close!($tx, $reason); | ||
return | ||
} | ||
} | ||
}; | ||
} | ||
|
||
pub enum CloseReason { | ||
MessageIsNotString, | ||
CannotParseJSON | ||
} | ||
|
||
impl CloseReason { | ||
pub fn code(&self) -> u16 { | ||
match self { | ||
CloseReason::MessageIsNotString => 4001, | ||
CloseReason::CannotParseJSON => 4002, | ||
} | ||
} | ||
|
||
pub fn text(&self) -> impl Into<Cow<'static, str>> { | ||
match self { | ||
CloseReason::MessageIsNotString => "Message is not UTF-8 string", | ||
CloseReason::CannotParseJSON => "Cannot parse JSON message" | ||
} | ||
} | ||
} | ||
|
||
pub struct Connection { | ||
pub user_id: Id<UserMarker>, | ||
pub session_id: ObjectId, | ||
pub tx: UnboundedSender<OutboundAction> | ||
} | ||
|
||
pub async fn handle_connection( | ||
context: Arc<Context>, | ||
ws: WebSocket, | ||
info: Arc<AuthorizationInformation>, | ||
guild: CurrentUserGuild, | ||
guilds_editing: Arc<GuildsEditing> | ||
) { | ||
let (mut ws_tx, mut ws_rx) = ws.split(); | ||
|
||
let (tx, rx) = | ||
tokio::sync::mpsc::unbounded_channel(); | ||
|
||
let mut rx = UnboundedReceiverStream::new(rx); | ||
|
||
tokio::spawn(async move { | ||
while let Some(message) = rx.next().await { | ||
match message { | ||
OutboundAction::Message(msg) => { | ||
if let Ok(data) = serde_json::to_string(&msg) { | ||
let _ = ws_tx.send(Message::text(data)).await; | ||
} | ||
} | ||
OutboundAction::Close(reason) => { | ||
let _ = ws_tx.send( | ||
Message::close_with(reason.code(), reason.text()) | ||
).await; | ||
} | ||
} | ||
} | ||
let _ = ws_tx.close().await; | ||
}); | ||
|
||
let session_id = ObjectId::new(); | ||
let guild_id = guild.id; | ||
|
||
guilds_editing.add_connection(guild_id, Connection { | ||
user_id: info.user.id, | ||
session_id, | ||
tx: tx.to_owned(), | ||
}).await; | ||
|
||
let _ = tx.send(OutboundAction::Message(OutboundMessage::Initialization { | ||
cached: ok_or_return!(context.redis.get_guild(guild.id).await, Ok), | ||
oauth2: guild.to_owned(), | ||
session_id | ||
})); | ||
|
||
while let Some(result) = ws_rx.next().await { | ||
let message = match result { | ||
Ok(message) => message, | ||
Err(_) => { | ||
break | ||
} | ||
}; | ||
|
||
if !message.is_text() { | ||
break | ||
} | ||
|
||
on_message(message, &info, &guild, &tx, &guilds_editing, &context).await; | ||
} | ||
|
||
guilds_editing.remove_connection(guild_id, session_id).await; | ||
} | ||
#[derive(Debug, Deserialize)] | ||
#[serde(tag = "action", content = "data")] | ||
enum InboundMessage { | ||
GuildConfigUpdate(Value), | ||
ApplyChanges | ||
} | ||
|
||
#[derive(Debug, Serialize)] | ||
#[serde(tag = "action", content = "data")] | ||
pub enum OutboundMessage { | ||
Initialization { | ||
oauth2: CurrentUserGuild, | ||
cached: PartialGuild, | ||
session_id: ObjectId | ||
}, | ||
UpdateConfigurationData { | ||
saved_config: GuildConfig, | ||
changes: Value, | ||
users: Vec<Id<UserMarker>> | ||
} | ||
} | ||
|
||
pub enum OutboundAction { | ||
Message(OutboundMessage), | ||
Close(CloseReason) | ||
} | ||
|
||
async fn on_message( | ||
message: Message, | ||
info: &Arc<AuthorizationInformation>, | ||
guild: &CurrentUserGuild, | ||
tx: &UnboundedSender<OutboundAction>, | ||
guilds_editing: &Arc<GuildsEditing>, | ||
context: &Arc<Context> | ||
) { | ||
let message = unwrap_or_close_and_return!( | ||
message.to_str(), tx, CloseReason::MessageIsNotString | ||
); | ||
|
||
let message: InboundMessage = unwrap_or_close_and_return!( | ||
serde_json::from_str(message), tx, CloseReason::CannotParseJSON | ||
); | ||
|
||
match message { | ||
InboundMessage::GuildConfigUpdate(changes) => { | ||
let _ = guilds_editing.marge_changes(info.user.id, guild.id, changes).await; | ||
guilds_editing.broadcast_changes(context, guild.id).await; | ||
} | ||
InboundMessage::ApplyChanges => {} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.