Skip to content

Commit

Permalink
add badtranslate
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacherr committed Jun 25, 2024
1 parent f989cb6 commit a1c6a63
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 7 deletions.
1 change: 1 addition & 0 deletions assyst-common/src/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct Urls {
pub filer: String,
pub eval: String,
pub wsi: String,
pub bad_translation: String,
}

#[derive(Deserialize)]
Expand Down
44 changes: 39 additions & 5 deletions assyst-core/src/command/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,43 @@ impl FlagDecode for ColourRemoveAllFlags {
}
flag_parse_argument! { ColourRemoveAllFlags }

#[derive(Default)]
pub struct BadTranslateFlags {
pub chain: bool,
pub count: Option<u64>,
}
impl FlagDecode for BadTranslateFlags {
fn from_str(input: &str) -> anyhow::Result<Self>
where
Self: Sized,
{
let mut valid_flags = HashMap::new();
valid_flags.insert("chain", FlagType::NoValue);
valid_flags.insert("count", FlagType::WithValue);

let raw_decode = flags_from_str(input, valid_flags)?;

let count = raw_decode
.get("count")
.map(|x| x.clone().map(|y| y.parse::<u64>()))
.flatten();

let count = if let Some(inner) = count {
Some(inner.context("Failed to parse translation count")?)
} else {
None
};

let result = Self {
chain: raw_decode.get("chain").is_some(),
count,
};

Ok(result)
}
}
flag_parse_argument! { BadTranslateFlags }

#[derive(Default)]
pub struct ChargeFlags {
pub verbose: bool,
Expand Down Expand Up @@ -199,7 +236,7 @@ pub fn flags_from_str(input: &str, valid_flags: ValidFlags) -> anyhow::Result<Ha

if let FlagType::NoValue = flag {
entries.insert(c.clone(), None);
current_flag = None;
current_flag = Some(arg[2..].to_owned());
} else {
bail!("Flag {c} expects a value, but none was provided");
}
Expand All @@ -219,10 +256,7 @@ pub fn flags_from_str(input: &str, valid_flags: ValidFlags) -> anyhow::Result<Ha
} else {
bail!("Flag {c} does not expect a value, even though one was provided");
}
} //else {
// random value not following any flag: ignore?

//}
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion assyst-core/src/command/fun/colour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ pub async fn remove(ctxt: CommandCtxt<'_>, name: Word) -> anyhow::Result<()> {
category = Category::Fun,
usage = "--i-am-sure",
examples = ["", "--i-am-sure"],
flag_descriptions = [
("i-am-sure", "Confirm this operation"),
]
)]
pub async fn remove_all(ctxt: CommandCtxt<'_>, flags: ColourRemoveAllFlags) -> anyhow::Result<()> {
if let Some(id) = ctxt.data.guild_id.map(|x| x.get()) {
Expand Down Expand Up @@ -367,7 +370,7 @@ pub async fn default(ctxt: CommandCtxt<'_>, colour: Option<Word>) -> anyhow::Res
define_commandgroup! {
name: colour,
access: Availability::Public,
category: Category::Misc,
category: Category::Fun,
aliases: ["color", "colours", "colors"],
cooldown: Duration::from_secs(5),
description: "Assyst colour roles",
Expand Down
1 change: 1 addition & 0 deletions assyst-core/src/command/fun/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod colour;
pub mod translation;
58 changes: 58 additions & 0 deletions assyst-core/src/command/fun/translation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use std::time::Duration;

use anyhow::{bail, Context};
use assyst_proc_macro::command;

use crate::command::arguments::Rest;
use crate::command::flags::BadTranslateFlags;
use crate::command::{Availability, Category, CommandCtxt};
use crate::rest::bad_translation::{
bad_translate as bad_translate_default, bad_translate_with_count, TranslateResult, Translation,
};

#[command(
aliases = ["bt"],
description = "Badly translate some text",
access = Availability::Public,
cooldown = Duration::from_secs(5),
category = Category::Fun,
usage = "[text]",
examples = ["hello i love assyst"],
flag_descriptions = [
("chain", "Show language chain"),
("count", "Set the amount of translations to perform")
],
send_processing = true
)]
pub async fn bad_translate(ctxt: CommandCtxt<'_>, text: Rest, flags: BadTranslateFlags) -> anyhow::Result<()> {
let TranslateResult {
result: Translation { text, .. },
translations,
} = if let Some(count) = flags.count {
if count < 10 {
bad_translate_with_count(&ctxt.assyst().reqwest_client, &text.0, count as u32)
.await
.context("Failed to run bad translation")?
} else {
bail!("Translation count cannot exceed 10")
}
} else {
bad_translate_default(&ctxt.assyst().reqwest_client, &text.0)
.await
.context("Failed to run bad translation")?
};

let mut output = format!("**Output:**\n{}", text);

if flags.chain {
output += "\n\n**Language chain:**\n";

for (idx, translation) in translations.iter().enumerate() {
output += &format!("{}) {}: {}\n", idx + 1, translation.lang, translation.text);
}
}

ctxt.reply(output).await?;

Ok(())
}
2 changes: 1 addition & 1 deletion assyst-core/src/command/misc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::time::{Duration, Instant};
use crate::command::Availability;
use crate::rest::audio_identification::identify_song_notsoidentify;

use super::arguments::{Image, ImageUrl, Rest, Time, Word};
use super::arguments::{Image, ImageUrl, Rest, Word};
use super::{Category, CommandCtxt};

use anyhow::Context;
Expand Down
1 change: 1 addition & 0 deletions assyst-core/src/command/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ macro_rules! declare_commands {
}

declare_commands!(
fun::translation::bad_translate_command,
fun::colour::colour_command,
misc::enlarge_command,
misc::exec_command,
Expand Down
128 changes: 128 additions & 0 deletions assyst-core/src/rest/bad_translation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::fmt::Display;

use assyst_common::config::CONFIG;
use reqwest::{Client, Error as ReqwestError};
use serde::Deserialize;
use std::error::Error;

const MAX_ATTEMPTS: u8 = 5;

mod routes {
pub const LANGUAGES: &str = "/languages";
}

#[derive(Debug)]
pub enum TranslateError {
Reqwest(ReqwestError),
Raw(&'static str),
}

impl Display for TranslateError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TranslateError::Reqwest(_) => write!(f, "A network error occurred"),
TranslateError::Raw(s) => write!(f, "{}", s),
}
}
}

impl Error for TranslateError {}

#[derive(Deserialize)]
pub struct Translation {
pub lang: String,
pub text: String,
}

#[derive(Deserialize)]
pub struct TranslateResult {
pub translations: Vec<Translation>,
pub result: Translation,
}

async fn translate_retry(
client: &Client,
text: &str,
target: Option<&str>,
count: Option<u32>,
additional_data: Option<&[(&str, String)]>,
) -> Result<TranslateResult, TranslateError> {
let mut query_args = vec![("text", text.to_owned())];

if let Some(target) = target {
query_args.push(("target", target.to_owned()));
}

if let Some(count) = count {
query_args.push(("count", count.to_string()));
}

if let Some(data) = additional_data {
for (k, v) in data.into_iter() {
query_args.push((k, v.to_string()));
}
}

client
.get(&CONFIG.urls.bad_translation)
.query(&query_args)
.send()
.await
.map_err(TranslateError::Reqwest)?
.json()
.await
.map_err(TranslateError::Reqwest)
}

async fn translate(
client: &Client,
text: &str,
target: Option<&str>,
count: Option<u32>,
additional_data: Option<&[(&str, String)]>,
) -> Result<TranslateResult, TranslateError> {
let mut attempt = 0;

while attempt <= MAX_ATTEMPTS {
match translate_retry(client, text, target, count, additional_data).await {
Ok(result) => return Ok(result),
Err(e) => eprintln!("Proxy failed! {:?}", e),
};

attempt += 1;
}

Err(TranslateError::Raw("BT Failed: Too many attempts"))
}

pub async fn bad_translate(client: &Client, text: &str) -> Result<TranslateResult, TranslateError> {
translate(client, text, None, None, None).await
}

pub async fn bad_translate_with_count(
client: &Client,
text: &str,
count: u32,
) -> Result<TranslateResult, TranslateError> {
translate(client, text, None, Some(count), None).await
}

pub async fn translate_single(client: &Client, text: &str, target: &str) -> Result<TranslateResult, TranslateError> {
translate(client, text, Some(target), Some(1), None).await
}

pub async fn get_languages(client: &Client) -> Result<Vec<(Box<str>, Box<str>)>, TranslateError> {
client
.get(format!("{}{}", CONFIG.urls.bad_translation, routes::LANGUAGES))
.send()
.await
.map_err(TranslateError::Reqwest)?
.json()
.await
.map_err(TranslateError::Reqwest)
}

pub async fn validate_language(client: &Client, provided_language: &str) -> Result<bool, TranslateError> {
let languages = get_languages(client).await?;
Ok(languages.iter().any(|(language, _)| &**language == provided_language))
}
1 change: 1 addition & 0 deletions assyst-core/src/rest/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod audio_identification;
pub mod bad_translation;
pub mod cooltext;
pub mod eval;
pub mod filer;
Expand Down
1 change: 1 addition & 0 deletions assyst-proc-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream {
let flag_descriptions = fields.remove("flag_descriptions").unwrap_or_else(empty_array_expr);

let following = quote::quote! {
#[allow(non_camel_case_types)]
pub struct #struct_name;

#[::async_trait::async_trait]
Expand Down
2 changes: 2 additions & 0 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ proxy = []
filer = ""
# URL for the WSI TCP server.
wsi = ""
# Bad translation URL.
bad_translation = ""

[authentication]
# Token to authenticate with Discord.
Expand Down

0 comments on commit a1c6a63

Please sign in to comment.