From a1c6a63f9af60344cf35f14f6bcfb010daacc8d6 Mon Sep 17 00:00:00 2001 From: Jacherr Date: Tue, 25 Jun 2024 20:27:08 +0100 Subject: [PATCH] add badtranslate --- assyst-common/src/config/config.rs | 1 + assyst-core/src/command/flags.rs | 44 ++++++- assyst-core/src/command/fun/colour.rs | 5 +- assyst-core/src/command/fun/mod.rs | 1 + assyst-core/src/command/fun/translation.rs | 58 ++++++++++ assyst-core/src/command/misc/mod.rs | 2 +- assyst-core/src/command/registry.rs | 1 + assyst-core/src/rest/bad_translation.rs | 128 +++++++++++++++++++++ assyst-core/src/rest/mod.rs | 1 + assyst-proc-macro/src/lib.rs | 1 + config.template.toml | 2 + 11 files changed, 237 insertions(+), 7 deletions(-) create mode 100644 assyst-core/src/command/fun/translation.rs create mode 100644 assyst-core/src/rest/bad_translation.rs diff --git a/assyst-common/src/config/config.rs b/assyst-common/src/config/config.rs index 582b8f6..911aa43 100644 --- a/assyst-common/src/config/config.rs +++ b/assyst-common/src/config/config.rs @@ -18,6 +18,7 @@ pub struct Urls { pub filer: String, pub eval: String, pub wsi: String, + pub bad_translation: String, } #[derive(Deserialize)] diff --git a/assyst-core/src/command/flags.rs b/assyst-core/src/command/flags.rs index abada96..3d4967c 100644 --- a/assyst-core/src/command/flags.rs +++ b/assyst-core/src/command/flags.rs @@ -100,6 +100,43 @@ impl FlagDecode for ColourRemoveAllFlags { } flag_parse_argument! { ColourRemoveAllFlags } +#[derive(Default)] +pub struct BadTranslateFlags { + pub chain: bool, + pub count: Option, +} +impl FlagDecode for BadTranslateFlags { + fn from_str(input: &str) -> anyhow::Result + where + Self: Sized, + { + let mut valid_flags = HashMap::new(); + valid_flags.insert("chain", FlagType::NoValue); + valid_flags.insert("count", FlagType::WithValue); + + let raw_decode = flags_from_str(input, valid_flags)?; + + let count = raw_decode + .get("count") + .map(|x| x.clone().map(|y| y.parse::())) + .flatten(); + + let count = if let Some(inner) = count { + Some(inner.context("Failed to parse translation count")?) + } else { + None + }; + + let result = Self { + chain: raw_decode.get("chain").is_some(), + count, + }; + + Ok(result) + } +} +flag_parse_argument! { BadTranslateFlags } + #[derive(Default)] pub struct ChargeFlags { pub verbose: bool, @@ -199,7 +236,7 @@ pub fn flags_from_str(input: &str, valid_flags: ValidFlags) -> anyhow::Result anyhow::Result, name: Word) -> anyhow::Result<()> { category = Category::Fun, usage = "--i-am-sure", examples = ["", "--i-am-sure"], + flag_descriptions = [ + ("i-am-sure", "Confirm this operation"), + ] )] pub async fn remove_all(ctxt: CommandCtxt<'_>, flags: ColourRemoveAllFlags) -> anyhow::Result<()> { if let Some(id) = ctxt.data.guild_id.map(|x| x.get()) { @@ -367,7 +370,7 @@ pub async fn default(ctxt: CommandCtxt<'_>, colour: Option) -> anyhow::Res define_commandgroup! { name: colour, access: Availability::Public, - category: Category::Misc, + category: Category::Fun, aliases: ["color", "colours", "colors"], cooldown: Duration::from_secs(5), description: "Assyst colour roles", diff --git a/assyst-core/src/command/fun/mod.rs b/assyst-core/src/command/fun/mod.rs index 36fdc84..95e1ba2 100644 --- a/assyst-core/src/command/fun/mod.rs +++ b/assyst-core/src/command/fun/mod.rs @@ -1 +1,2 @@ pub mod colour; +pub mod translation; diff --git a/assyst-core/src/command/fun/translation.rs b/assyst-core/src/command/fun/translation.rs new file mode 100644 index 0000000..46daaed --- /dev/null +++ b/assyst-core/src/command/fun/translation.rs @@ -0,0 +1,58 @@ +use std::time::Duration; + +use anyhow::{bail, Context}; +use assyst_proc_macro::command; + +use crate::command::arguments::Rest; +use crate::command::flags::BadTranslateFlags; +use crate::command::{Availability, Category, CommandCtxt}; +use crate::rest::bad_translation::{ + bad_translate as bad_translate_default, bad_translate_with_count, TranslateResult, Translation, +}; + +#[command( + aliases = ["bt"], + description = "Badly translate some text", + access = Availability::Public, + cooldown = Duration::from_secs(5), + category = Category::Fun, + usage = "[text]", + examples = ["hello i love assyst"], + flag_descriptions = [ + ("chain", "Show language chain"), + ("count", "Set the amount of translations to perform") + ], + send_processing = true +)] +pub async fn bad_translate(ctxt: CommandCtxt<'_>, text: Rest, flags: BadTranslateFlags) -> anyhow::Result<()> { + let TranslateResult { + result: Translation { text, .. }, + translations, + } = if let Some(count) = flags.count { + if count < 10 { + bad_translate_with_count(&ctxt.assyst().reqwest_client, &text.0, count as u32) + .await + .context("Failed to run bad translation")? + } else { + bail!("Translation count cannot exceed 10") + } + } else { + bad_translate_default(&ctxt.assyst().reqwest_client, &text.0) + .await + .context("Failed to run bad translation")? + }; + + let mut output = format!("**Output:**\n{}", text); + + if flags.chain { + output += "\n\n**Language chain:**\n"; + + for (idx, translation) in translations.iter().enumerate() { + output += &format!("{}) {}: {}\n", idx + 1, translation.lang, translation.text); + } + } + + ctxt.reply(output).await?; + + Ok(()) +} diff --git a/assyst-core/src/command/misc/mod.rs b/assyst-core/src/command/misc/mod.rs index 7561c2a..29e48f6 100644 --- a/assyst-core/src/command/misc/mod.rs +++ b/assyst-core/src/command/misc/mod.rs @@ -3,7 +3,7 @@ use std::time::{Duration, Instant}; use crate::command::Availability; use crate::rest::audio_identification::identify_song_notsoidentify; -use super::arguments::{Image, ImageUrl, Rest, Time, Word}; +use super::arguments::{Image, ImageUrl, Rest, Word}; use super::{Category, CommandCtxt}; use anyhow::Context; diff --git a/assyst-core/src/command/registry.rs b/assyst-core/src/command/registry.rs index 93b5b05..0b8274a 100644 --- a/assyst-core/src/command/registry.rs +++ b/assyst-core/src/command/registry.rs @@ -18,6 +18,7 @@ macro_rules! declare_commands { } declare_commands!( + fun::translation::bad_translate_command, fun::colour::colour_command, misc::enlarge_command, misc::exec_command, diff --git a/assyst-core/src/rest/bad_translation.rs b/assyst-core/src/rest/bad_translation.rs new file mode 100644 index 0000000..2e78003 --- /dev/null +++ b/assyst-core/src/rest/bad_translation.rs @@ -0,0 +1,128 @@ +use std::fmt::Display; + +use assyst_common::config::CONFIG; +use reqwest::{Client, Error as ReqwestError}; +use serde::Deserialize; +use std::error::Error; + +const MAX_ATTEMPTS: u8 = 5; + +mod routes { + pub const LANGUAGES: &str = "/languages"; +} + +#[derive(Debug)] +pub enum TranslateError { + Reqwest(ReqwestError), + Raw(&'static str), +} + +impl Display for TranslateError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TranslateError::Reqwest(_) => write!(f, "A network error occurred"), + TranslateError::Raw(s) => write!(f, "{}", s), + } + } +} + +impl Error for TranslateError {} + +#[derive(Deserialize)] +pub struct Translation { + pub lang: String, + pub text: String, +} + +#[derive(Deserialize)] +pub struct TranslateResult { + pub translations: Vec, + pub result: Translation, +} + +async fn translate_retry( + client: &Client, + text: &str, + target: Option<&str>, + count: Option, + additional_data: Option<&[(&str, String)]>, +) -> Result { + let mut query_args = vec![("text", text.to_owned())]; + + if let Some(target) = target { + query_args.push(("target", target.to_owned())); + } + + if let Some(count) = count { + query_args.push(("count", count.to_string())); + } + + if let Some(data) = additional_data { + for (k, v) in data.into_iter() { + query_args.push((k, v.to_string())); + } + } + + client + .get(&CONFIG.urls.bad_translation) + .query(&query_args) + .send() + .await + .map_err(TranslateError::Reqwest)? + .json() + .await + .map_err(TranslateError::Reqwest) +} + +async fn translate( + client: &Client, + text: &str, + target: Option<&str>, + count: Option, + additional_data: Option<&[(&str, String)]>, +) -> Result { + let mut attempt = 0; + + while attempt <= MAX_ATTEMPTS { + match translate_retry(client, text, target, count, additional_data).await { + Ok(result) => return Ok(result), + Err(e) => eprintln!("Proxy failed! {:?}", e), + }; + + attempt += 1; + } + + Err(TranslateError::Raw("BT Failed: Too many attempts")) +} + +pub async fn bad_translate(client: &Client, text: &str) -> Result { + translate(client, text, None, None, None).await +} + +pub async fn bad_translate_with_count( + client: &Client, + text: &str, + count: u32, +) -> Result { + translate(client, text, None, Some(count), None).await +} + +pub async fn translate_single(client: &Client, text: &str, target: &str) -> Result { + translate(client, text, Some(target), Some(1), None).await +} + +pub async fn get_languages(client: &Client) -> Result, Box)>, TranslateError> { + client + .get(format!("{}{}", CONFIG.urls.bad_translation, routes::LANGUAGES)) + .send() + .await + .map_err(TranslateError::Reqwest)? + .json() + .await + .map_err(TranslateError::Reqwest) +} + +pub async fn validate_language(client: &Client, provided_language: &str) -> Result { + let languages = get_languages(client).await?; + Ok(languages.iter().any(|(language, _)| &**language == provided_language)) +} diff --git a/assyst-core/src/rest/mod.rs b/assyst-core/src/rest/mod.rs index 6f8a725..b4d0a6b 100644 --- a/assyst-core/src/rest/mod.rs +++ b/assyst-core/src/rest/mod.rs @@ -1,4 +1,5 @@ pub mod audio_identification; +pub mod bad_translation; pub mod cooltext; pub mod eval; pub mod filer; diff --git a/assyst-proc-macro/src/lib.rs b/assyst-proc-macro/src/lib.rs index c0e90f6..ecf45d5 100644 --- a/assyst-proc-macro/src/lib.rs +++ b/assyst-proc-macro/src/lib.rs @@ -112,6 +112,7 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream { let flag_descriptions = fields.remove("flag_descriptions").unwrap_or_else(empty_array_expr); let following = quote::quote! { + #[allow(non_camel_case_types)] pub struct #struct_name; #[::async_trait::async_trait] diff --git a/config.template.toml b/config.template.toml index 551f952..f14f7bb 100644 --- a/config.template.toml +++ b/config.template.toml @@ -11,6 +11,8 @@ proxy = [] filer = "" # URL for the WSI TCP server. wsi = "" +# Bad translation URL. +bad_translation = "" [authentication] # Token to authenticate with Discord.