Skip to content

Commit

Permalink
Switch to synchronous locks for data
Browse files Browse the repository at this point in the history
This required some changes to avoid holding a lock across an await
boundary

See
[here](https://docs.rs/tokio/1.0.2/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use)
for rationale
  • Loading branch information
sam-kirby committed Jan 21, 2021
1 parent 36e9edd commit beeded8
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 48 deletions.
58 changes: 57 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "taskinator"
version = "0.2.0"
version = "0.2.1"
authors = ["Sam Kirby <[email protected]>"]
edition = "2018"
license = "AGPL-3.0-or-later"
Expand All @@ -9,6 +9,7 @@ license = "AGPL-3.0-or-later"

[dependencies]
futures = "0.3"
parking_lot = "0.11"
tokio-stream = "0.1"
toml = "0.5"
tracing = "0.1"
Expand Down
96 changes: 50 additions & 46 deletions src/bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::{
time::Duration,
};

use tokio::{signal::ctrl_c, sync::RwLock, task::JoinHandle, time::sleep};
use parking_lot::RwLock;
use tokio::{signal::ctrl_c, task::JoinHandle, time::sleep};
use tracing::error;
use twilight_cache_inmemory::{model::CachedMember, InMemoryCache as DiscordCache, ResourceType};
use twilight_command_parser::{Arguments, Command, CommandParserConfig, Parser};
Expand Down Expand Up @@ -107,7 +108,7 @@ impl Bot<'_> {
(owners, UserId(app_info.id.0))
};

self.owners.write().await.replace(owners);
self.owners.write().replace(owners);
self.bot_id.replace(current_user);

Ok(())
Expand Down Expand Up @@ -161,7 +162,7 @@ impl Bot<'_> {
.delete_message(msg.channel_id, msg.id)
.await?;

if self.is_in_control(msg.author.id).await {
if self.is_in_control(msg.author.id) {
self.end_game().await?;
}
}
Expand All @@ -181,8 +182,8 @@ impl Bot<'_> {
.delete_message(msg.channel_id, msg.id)
.await?;

if self.is_in_control(msg.author.id).await {
if self.is_game_in_progress().await {
if self.is_in_control(msg.author.id) {
if self.is_game_in_progress() {
self.end_game().await?;
}

Expand All @@ -196,10 +197,10 @@ impl Bot<'_> {
}

pub async fn reaction_add_handler(&self, reaction: &Reaction) -> Result<()> {
if self.is_reacting_to_control(&reaction).await {
if self.is_reacting_to_control(&reaction) {
match reaction.emoji {
ReactionType::Unicode { ref name } if name == EMER_EMOJI => {
if self.is_in_control(reaction.user_id).await {
if self.is_in_control(reaction.user_id) {
self.begin_meeting().await?;
}
}
Expand All @@ -214,8 +215,8 @@ impl Bot<'_> {
}

pub async fn reaction_remove_handler(&self, reaction: &Reaction) -> Result<()> {
if self.is_reacting_to_control(&reaction).await
&& self.is_in_control(reaction.user_id).await
if self.is_reacting_to_control(&reaction)
&& self.is_in_control(reaction.user_id)
&& matches!(reaction.emoji, ReactionType::Unicode { ref name } if name == EMER_EMOJI)
{
self.end_meeting().await?;
Expand Down Expand Up @@ -261,7 +262,7 @@ impl Bot<'_> {
Ok(())
});

self.game.write().await.replace(Game {
self.game.write().replace(Game {
dead: HashSet::new(),
ctrl_channel: msg.channel_id,
ctrl_msg: ctrl_msg.id,
Expand All @@ -288,7 +289,8 @@ impl Bot<'_> {
}

async fn end_game(&self) -> Result<()> {
if let Some(game) = self.game.write().await.take() {
let game = self.game.write().take();
if let Some(game) = game {
self.discord_http
.delete_message(game.ctrl_channel, game.ctrl_msg)
.await?;
Expand All @@ -304,15 +306,15 @@ impl Bot<'_> {

let mut futures = Vec::new();

for member in self.get_members_in_channel(living_channel).await {
for member in self.get_members_in_channel(&living_channel) {
futures.push(
self.discord_http
.update_guild_member(member.guild_id, member.user.id)
.mute(false),
);
}

for member in self.get_members_in_channel(dead_channel).await {
for member in self.get_members_in_channel(&dead_channel) {
futures.push(
self.discord_http
.update_guild_member(member.guild_id, member.user.id)
Expand All @@ -335,10 +337,10 @@ impl Bot<'_> {
let mut futures = Vec::new();

{
let game_lock = self.game.read().await;
let game_lock = self.game.read();
let game = game_lock.as_ref().unwrap();

for member in self.get_members_in_channel(living_channel).await {
for member in self.get_members_in_channel(&living_channel) {
if game.dead.contains(&member.user.id) {
continue;
}
Expand All @@ -350,7 +352,7 @@ impl Bot<'_> {
}
}

for member in self.get_members_in_channel(dead_channel).await {
for member in self.get_members_in_channel(&dead_channel) {
futures.push(
self.discord_http
.update_guild_member(member.guild_id, member.user.id)
Expand All @@ -361,7 +363,7 @@ impl Bot<'_> {

self.batch(futures).await;

let mut game_lock = self.game.write().await;
let mut game_lock = self.game.write();
let g = game_lock.as_mut().expect("expected game");
g.meeting_in_progress = true;

Expand All @@ -375,11 +377,10 @@ impl Bot<'_> {
.unwrap();

let (alive_players, dead_players): (Vec<_>, Vec<_>) = {
let game_lock = self.game.read().await;
let game_lock = self.game.read();
let game = game_lock.as_ref().unwrap();

self.get_members_in_channel(living_channel)
.await
self.get_members_in_channel(&living_channel)
.into_iter()
.partition(|p| !game.dead.contains(&p.user.id))
};
Expand All @@ -405,16 +406,16 @@ impl Bot<'_> {

self.batch(futures).await;

let mut game_lock = self.game.write().await;
let mut game_lock = self.game.write();
let g = game_lock.as_mut().expect("expected game");
g.meeting_in_progress = false;

Ok(())
}

async fn deadify(&self, msg: &Message, mut args: Arguments<'_>) -> Result<()> {
if let Some(broadcast) = self.broadcast().await {
if self.is_in_control(msg.author.id).await {
if let Some(broadcast) = self.broadcast() {
if self.is_in_control(msg.author.id) {
match args.next().map(UserId::parse) {
Some(Ok(target)) => {
let reply = broadcast
Expand Down Expand Up @@ -451,16 +452,22 @@ impl Bot<'_> {
}

async fn make_dead(&self, target: UserId) {
if let Some(game) = self.game.write().await.as_mut() {
let mut fut = None;

if let Some(game) = self.game.write().as_mut() {
if game.dead.insert(target) && game.meeting_in_progress {
if let Err(why) = self
.discord_http
.update_guild_member(game.guild_id, target)
.mute(true)
.await
{
error!("Error occurred when making {} dead:\n{}", target, why);
}
let guild_id = game.guild_id;
fut = Some(
self.discord_http
.update_guild_member(guild_id, target)
.mute(true),
);
}
}

if let Some(fut) = fut {
if let Err(why) = fut.await {
error!("Error occured when making {} dead\n{}", target, why);
}
}
}
Expand All @@ -475,7 +482,8 @@ impl Bot<'_> {
.filter_map(StdResult::err)
.collect::<Vec<_>>();
if !errors.is_empty() {
if let Some(channel) = self.game.read().await.as_ref().map(|g| g.ctrl_channel) {
let channel = self.game.read().as_ref().map(|g| g.ctrl_channel);
if let Some(channel) = channel {
let _ = self
.discord_http
.create_message(channel)
Expand All @@ -489,18 +497,14 @@ impl Bot<'_> {
}
}

async fn broadcast(&self) -> Option<CreateMessage<'_>> {
fn broadcast(&self) -> Option<CreateMessage<'_>> {
self.game
.read()
.await
.as_ref()
.map(|g| self.discord_http.create_message(g.ctrl_channel))
}

async fn get_members_in_channel(
&self,
voice_channel: Arc<GuildChannel>,
) -> Vec<Arc<CachedMember>> {
fn get_members_in_channel(&self, voice_channel: &GuildChannel) -> Vec<Arc<CachedMember>> {
self.cache
.voice_channel_states(voice_channel.id())
.map_or(Vec::new(), |vs| {
Expand All @@ -511,18 +515,18 @@ impl Bot<'_> {
})
}

async fn is_game_in_progress(&self) -> bool {
self.game.read().await.is_some()
fn is_game_in_progress(&self) -> bool {
self.game.read().is_some()
}

async fn is_in_control(&self, user_id: UserId) -> bool {
matches!(self.owners.read().await.as_ref(), Some(owners) if owners.contains(&user_id))
|| matches!(self.game.read().await.as_ref(), Some(game) if game.ctrl_user == user_id)
fn is_in_control(&self, user_id: UserId) -> bool {
matches!(self.owners.read().as_ref(), Some(owners) if owners.contains(&user_id))
|| matches!(self.game.read().as_ref(), Some(game) if game.ctrl_user == user_id)
}

async fn is_reacting_to_control(&self, reaction: &Reaction) -> bool {
fn is_reacting_to_control(&self, reaction: &Reaction) -> bool {
matches!(
self.game.read().await.as_ref(),
self.game.read().as_ref(),
Some(game) if game.ctrl_msg == reaction.message_id,
)
}
Expand Down

0 comments on commit beeded8

Please sign in to comment.