diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index 78cb1a764de..fddd308e02e 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -1,11 +1,11 @@ use std::sync::Arc; use dashmap::DashMap; -use regex::Regex; +use regex::RegexSet; use tokenizers::tokenizer::Tokenizer; pub struct DecodingFactory { - stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, + stop_regex_cache: DashMap<&'static Vec<&'static str>, RegexSet>, } fn reverse<T>(s: T) -> String @@ -33,7 +33,7 @@ impl DecodingFactory { IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) } - fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> { + fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<RegexSet> { if stop_words.is_empty() { None } else { @@ -48,18 +48,19 @@ impl DecodingFactory { } } -fn create_stop_regex(stop_words: &[&str]) -> Regex { - let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect(); - +fn create_stop_regex(stop_words: &[&str]) -> RegexSet { // (?m) enables multi-line matching mode. // \A means absolute begins of string.
- let regex_string = r"(?m)\A".to_owned() + &tokens.join("|"); - Regex::new(&regex_string).unwrap() + let tokens: Vec<String> = stop_words + .iter() + .map(|x| r"(?m)\A".to_owned() + &reverse(*x)) + .collect(); + RegexSet::new(tokens).expect("Failed to create regex set") } pub struct IncrementalDecoding { tokenizer: Arc<Tokenizer>, - stop_re: Option<Regex>, + stop_re: Option<RegexSet>, token_ids: Vec<u32>, prefix_offset: usize, @@ -69,7 +70,11 @@ pub struct IncrementalDecoding { } impl IncrementalDecoding { - pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self { + pub fn new( + tokenizer: Arc<Tokenizer>, + stop_re: Option<RegexSet>, + input_token_ids: &[u32], + ) -> Self { let text = tokenizer .decode(input_token_ids, /* skip_special_token = */ true) .expect("Cannot decode token from tokenizer."); @@ -112,7 +117,7 @@ impl IncrementalDecoding { self.reversed_text = reverse(new_text) + &self.reversed_text; if let Some(re) = &self.stop_re { - if re.find(&self.reversed_text).is_some() { + if re.is_match(&self.reversed_text) { return None; } } @@ -121,3 +126,16 @@ impl IncrementalDecoding { Some(new_text.to_owned()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_it_should_not_match() { + let stop_words = vec!["\n\n", "\n\n "]; + let re = create_stop_regex(&stop_words); + let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid"); + assert!(!re.is_match(&text)) + } +} diff --git a/crates/tabby/src/serve/completions/languages.rs b/crates/tabby/src/serve/completions/languages.rs index eed3325f6f8..8dbe04fdfed 100644 --- a/crates/tabby/src/serve/completions/languages.rs +++ b/crates/tabby/src/serve/completions/languages.rs @@ -12,7 +12,6 @@ lazy_static! { "\n\n ", "\n\n ", "\n\n ", - "\n\n", "\n\n\t", "\n\n\t\t", "\n\n\t\t\t",