Skip to content

Commit

Permalink
split massive bot code into smaller chunks, add lfm + beautiful loadi…
Browse files Browse the repository at this point in the history
…ng msg
  • Loading branch information
duckfromdiscord committed Dec 27, 2023
1 parent 18a1b42 commit e697f59
Show file tree
Hide file tree
Showing 10 changed files with 403 additions and 233 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ serde_json = "1.0.108"
env_logger = "0.10.1"
futures = "0.3.29"
mljcl = { git = "https://github.com/duckfromdiscord/mljcl", version = "0.4.1" }
lastfm = "0.6.1"
lastfm = { version = "0.6.1", git = "https://github.com/duckfromdiscord/lastfm" }
url = "2.5.0"
shuttle-runtime = { optional = true, version = "0.34.0" }
shuttle-secrets = { optional = true, version = "0.34.0" }
shuttle-serenity = { optional = true, version = "0.34.1", default-features = false, features = ["serenity-0-12-rustls_backend"] }
shuttle-shared-db = { optional = true, version = "0.34.0", features = ["postgres-rustls"] }
sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio"] }
futures-util = "0.3.30"
14 changes: 8 additions & 6 deletions src/db/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub async fn get_websites(pool: &PgPool, formatted_user: String) -> Vec<DiscordW
r#"
SELECT * FROM discord_websites
WHERE discord_username = $1
"#
"#,
)
.bind(formatted_user)
.fetch_all(pool)
Expand Down Expand Up @@ -52,7 +52,8 @@ pub async fn insert_website(pool: &PgPool, formatted_user: String, website: Stri
VALUES ( $1, $2 )
"#,
)
.bind(formatted_user).bind(website)
.bind(formatted_user)
.bind(website)
.execute(pool)
.await
.expect("Failed to add website to DB");
Expand All @@ -63,9 +64,10 @@ pub async fn insert_discord_pairing_code(pool: &PgPool, formatted_user: String,
r#"
INSERT INTO discord_pairing_codes
VALUES ( $1, $2 )
"#
"#,
)
.bind(formatted_user).bind(key)
.bind(formatted_user)
.bind(key)
.execute(pool)
.await
.expect("Failed to add pairing code to DB");
Expand All @@ -76,7 +78,7 @@ pub async fn delete_discord_pairing_code(pool: &PgPool, formatted_user: String)
r#"
DELETE FROM discord_pairing_codes
WHERE discord_username = $1
"#
"#,
)
.bind(formatted_user)
.execute(pool)
Expand All @@ -90,7 +92,7 @@ pub async fn delete_website(pool: &PgPool, formatted_user: String) -> u64 {
r#"
DELETE FROM discord_websites
WHERE discord_username = $1
"#
"#,
)
.bind(formatted_user)
.execute(pool)
Expand Down
252 changes: 28 additions & 224 deletions src/discord/bot.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
use crate::db::postgres::{
delete_discord_pairing_code, delete_website, get_discord_pairing_code, get_websites,
insert_discord_pairing_code, insert_website,
};
use crate::db::postgres::{get_discord_pairing_code, get_websites};
use crate::hos::*;
use core::num::NonZeroU16;
use mljcl::history::numscrobbles_async;
use mljcl::range::Range;
use mljcl::MalojaCredentials;
use serenity::all::CreateMessage;
use serenity::async_trait;
use serenity::builder::CreateEmbed;
use serenity::model::channel::PrivateChannel;
use serenity::model::gateway::Ready;
use serenity::model::prelude::Message;
use serenity::model::user::User;
Expand Down Expand Up @@ -46,34 +38,6 @@ pub fn format_user(user: User) -> String {
)
}

pub async fn try_dm_channel(
author: User,
original_message: Option<Message>,
ctx: Context,
) -> Option<PrivateChannel> {
let dm_channel = author.create_dm_channel(ctx.clone()).await;
match dm_channel {
Ok(dm_channel) => Some(dm_channel),
Err(err) => {
let _ = original_message?
.reply_ping(ctx, "Couldn't create a DM channel with you")
.await;
log::error!(
"Error creating a DM channel with {}, {}",
format_user(author),
err
);
None
}
}
}

macro_rules! dm_channel {
( $msg:ident, $ctx:ident ) => {
try_dm_channel($msg.clone().author, Some($msg.clone()), $ctx.clone()).await
};
}

impl Handler {
pub async fn handle_hos_user(
&self,
Expand Down Expand Up @@ -195,11 +159,11 @@ impl Handler {
let creds = self
.handle_website_user(formatted_user.clone(), ctx.clone(), msg.clone())
.await;
if creds.is_err() {
if let Ok(creds) = creds {
Some(creds)
} else {
self.handle_hos_user(formatted_user, ctx.clone(), msg.clone())
.await
} else {
Some(creds.unwrap())
}
}
}
Expand All @@ -216,197 +180,37 @@ impl EventHandler for Handler {
let formatted_user = format_user(msg.author.clone());

if msg.content == "!hos_setup" {
if let Some(dm_channel) = dm_channel!(msg, ctx) {
let mut match_found = false;

let query = get_discord_pairing_code(&self.pool, formatted_user.clone()).await;

if query.len() >= 1 {
match_found = true;
}

if !match_found {
let key = crate::generate_api_key();
//TODO: check unique
insert_discord_pairing_code(&self.pool, formatted_user, key.clone()).await;
dm_channel.send_message(ctx,
CreateMessage::new().content(format!("You have been assigned the pairing code `{}`. Make sure to pass this to your HOS client.", key))
).await.unwrap();
} else {
dm_channel
.send_message(ctx,
CreateMessage::new().content(
"You've already made a pairing code, or you have a website linked. Do `!reset` to revoke the code and/or remove the website.",
)
)
.await
.unwrap();
}
}
super::setups::hos_setup(ctx.clone(), msg.clone(), &self.pool, formatted_user).await;
} else if msg.content.starts_with("!website_setup") {
let arg = get_arg(msg.clone().content);
let mut match_found = false;

let query = get_websites(&self.pool, formatted_user.clone()).await;
if query.len() >= 1 {
match_found = true;
}

if match_found {
msg.reply_ping(
ctx.clone(),
"You've already set a website. Do `!reset` to remove it.",
)
.await
.unwrap();
}
if !arg.is_empty() {
if !(arg.starts_with("http://") || arg.starts_with("https://")) {
msg.reply_ping(ctx.clone(), "Remember that your website has to start with `http://` or `https://`. Try again with \
one of those two, and keep in mind if you're using https you cannot use an invalid certificate.").await.unwrap();
return;
}
msg.reply_ping(ctx.clone(), format!("Setting your website to {}.", arg))
.await
.unwrap();

insert_website(&self.pool, formatted_user, arg.clone()).await;
} else {
msg.reply_ping(ctx, "No website provided.").await.unwrap();
}
super::setups::website_setup(ctx.clone(), msg.clone(), &self.pool, formatted_user, arg)
.await;
} else if msg.content == "!reset" {
if let Some(dm_channel) = dm_channel!(msg, ctx) {
for row in get_websites(&self.pool, formatted_user.clone()).await {
if row.discord_username == Some(formatted_user.clone()) {
dm_channel
.send_message(
ctx.clone(),
CreateMessage::new().content(format!(
"Removing your website `{}` from mljboard's database. \
Run `!site_setup` to assign yourself one.",
row.website.unwrap_or("[none]".to_string())
)),
)
.await
.unwrap();
}
}

let query = delete_website(&self.pool, formatted_user.clone()).await;

if query >= 1 {
dm_channel
.send_message(
ctx.clone(),
CreateMessage::new().content(format!("Removed {} entries.", query)),
)
.await
.unwrap();
}

let query = get_discord_pairing_code(&self.pool, formatted_user.clone()).await;

let mut affected: u16 = 0;

for row in query {
affected += 1;
dm_channel
.send_message(
ctx.clone(),
CreateMessage::new().content(format!(
"Removing your pairing code `{}` from mljboard's database. \
Run `!hos_setup` to be issued a new one.",
row.pairing_code.unwrap_or("[none]".to_string())
)),
)
.await
.unwrap();
}

let query = delete_discord_pairing_code(&self.pool, formatted_user).await;

if query >= 1 {
dm_channel
.send_message(
ctx.clone(),
CreateMessage::new().content(format!("Removed {} entries.", query)),
)
.await
.unwrap();
}

if affected == 0 {
dm_channel
.send_message(
ctx.clone(),
CreateMessage::new()
.content("We couldn't find any pairing codes that were yours."),
)
.await
.unwrap();
}
}
super::setups::reset(ctx.clone(), msg.clone(), &self.pool, formatted_user).await;
} else if msg.content == "!scrobbles" {
if let Some(creds) = self
let creds = self
.handle_creds(formatted_user, ctx.clone(), msg.clone())
.await
{
let all_time_scrobbles = numscrobbles_async(
None,
Range::AllTime,
creds.clone(),
self.reqwest_client.clone(),
)
.await
.unwrap();
let this_year_scrobbles = numscrobbles_async(
None,
Range::In("thisyear".to_string()),
creds,
self.reqwest_client.clone(),
)
.await
.unwrap();
msg.channel_id
.send_message(
ctx,
CreateMessage::new().embed(
CreateEmbed::new()
.title(format!("{}'s scrobbles", msg.author.name))
.field("All time", all_time_scrobbles.to_string(), false)
.field("This year", this_year_scrobbles.to_string(), false),
),
)
.await
.unwrap();
};
.await;

super::ops::scrobbles_cmd(msg.clone(), self.reqwest_client.clone(), creds, ctx.clone())
.await;
} else if msg.content.starts_with("!artistscrobbles") {
let arg = get_arg(msg.clone().content);

if let Some(creds) = self
let creds = self
.handle_creds(formatted_user, ctx.clone(), msg.clone())
.await
{
let all_time_scrobbles = numscrobbles_async(
Some(arg.clone()),
Range::AllTime,
creds.clone(),
self.reqwest_client.clone(),
)
.await
.unwrap();
msg.channel_id
.send_message(
ctx,
CreateMessage::new().embed(
CreateEmbed::new()
.title(format!("{}'s scrobbles for {}", msg.author.name, arg))
.field("All time", all_time_scrobbles.to_string(), false),
),
)
.await
.unwrap();
}
.await;

super::ops::artistscrobbles_cmd(
msg.clone(),
self.reqwest_client.clone(),
creds,
ctx.clone(),
arg,
)
.await;
} else if msg.content.starts_with("!lfmuser") {
let arg = get_arg(msg.clone().content);
super::lastfm::lfmuser_cmd(ctx.clone(), msg, self.lastfm_api.clone(), arg).await;
}
}

Expand All @@ -428,7 +232,7 @@ pub async fn build_bot(
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::MESSAGE_CONTENT;

Client::builder(&token, intents).event_handler(Handler {
Client::builder(token, intents).event_handler(Handler {
pool,
hos_server_ip,
hos_server_port,
Expand Down
Loading

0 comments on commit e697f59

Please sign in to comment.