Skip to content

Commit

Permalink
fix: prevent special keywords (e.g. AND) from breaking query parsing (#695)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys authored Nov 3, 2023
1 parent 2adcc07 commit e4efcc4
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 43 deletions.
26 changes: 11 additions & 15 deletions crates/tabby/src/serve/completions/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use lazy_static::lazy_static;
use regex::Regex;
use strfmt::strfmt;
use tabby_common::languages::get_language;
use tantivy::{query::BooleanQuery, query_grammar::Occur};
use textdistance::Algorithm;
use tracing::warn;

Expand Down Expand Up @@ -106,17 +107,16 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {

fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<Snippet> {
let mut ret = Vec::new();
let mut tokens = Box::new(tokenize_text(text));
let mut tokens = tokenize_text(text);

let sanitized_text = tokens.join(" ");
let sanitized_text = sanitized_text.trim();
if sanitized_text.is_empty() {
return ret;
}

let query_text = format!("language:{} AND ({})", language, sanitized_text);
let language_query = index_server.language_query(language).unwrap();
let body_query = index_server.body_query(&tokens).unwrap();
let query = BooleanQuery::new(vec![
(Occur::Must, language_query),
(Occur::Must, body_query),
]);

let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) {
let serp = match index_server.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) {
Ok(serp) => serp,
Err(IndexServerError::NotReady) => {
// Ignore.
Expand Down Expand Up @@ -154,7 +154,7 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
// Prepend body tokens and update tokens, so future similarity calculation will consider
// added snippets.
body_tokens.append(&mut tokens);
*tokens = body_tokens;
tokens.append(&mut body_tokens);

count_characters += body.len();
ret.push(Snippet {
Expand All @@ -172,11 +172,7 @@ lazy_static! {
}

fn tokenize_text(text: &str) -> Vec<String> {
TOKENIZER
.split(text)
.filter(|s| *s != "AND" && *s != "OR" && *s != "NOT" && !s.is_empty())
.map(|x| x.to_owned())
.collect()
TOKENIZER.split(text).map(|x| x.to_owned()).collect()
}

#[cfg(test)]
Expand Down
92 changes: 64 additions & 28 deletions crates/tabby/src/serve/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use serde::{Deserialize, Serialize};
use tabby_common::{index::IndexExt, path};
use tantivy::{
collector::{Count, TopDocs},
query::QueryParser,
schema::Field,
DocAddress, Document, Index, IndexReader,
query::{QueryParser, TermQuery, TermSetQuery},
schema::{Field, IndexRecordOption},
DocAddress, Document, Index, IndexReader, Term,
};
use thiserror::Error;
use tokio::{sync::OnceCell, task, time::sleep};
use tracing::{debug, instrument, log::info};
use tracing::{debug, instrument, log::info, warn};
use utoipa::{IntoParams, ToSchema};

#[derive(Deserialize, IntoParams)]
Expand Down Expand Up @@ -70,15 +70,18 @@ pub async fn search(
State(state): State<Arc<IndexServer>>,
query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> {
let Ok(serp) = state.search(
match state.search(
&query.q,
query.limit.unwrap_or(20),
query.offset.unwrap_or(0),
) else {
return Err(StatusCode::NOT_IMPLEMENTED);
};

Ok(Json(serp))
) {
Ok(serp) => Ok(Json(serp)),
Err(IndexServerError::NotReady) => Err(StatusCode::NOT_IMPLEMENTED),
Err(IndexServerError::TantivyError(err)) => {
warn!("{}", err);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

struct IndexServerImpl {
Expand Down Expand Up @@ -119,17 +122,19 @@ impl IndexServerImpl {
}

pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
let query = self
.query_parser
.parse_query(q)
.expect("Parsing the query failed");
let query = self.query_parser.parse_query(q)?;
self.search_with_query(&query, limit, offset)
}

pub fn search_with_query(
&self,
q: &dyn tantivy::query::Query,
limit: usize,
offset: usize,
) -> tantivy::Result<SearchResponse> {
let searcher = self.reader.searcher();
let (top_docs, num_hits) = {
searcher.search(
&query,
&(TopDocs::with_limit(limit).and_offset(offset), Count),
)?
};
let (top_docs, num_hits) =
{ searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? };
let hits: Vec<Hit> = {
top_docs
.iter()
Expand Down Expand Up @@ -179,8 +184,15 @@ impl IndexServer {
Self {}
}

fn get_cell(&self) -> Option<&IndexServerImpl> {
IMPL.get()
/// Runs `op` against the index implementation stored in `IMPL`.
///
/// Returns the result of `op` when `IMPL` currently holds a value, and
/// `IndexServerError::NotReady` when it does not.
fn with_impl<T, F>(&self, op: F) -> Result<T, IndexServerError>
where
    F: FnOnce(&IndexServerImpl) -> Result<T, IndexServerError>,
{
    match IMPL.get() {
        Some(imp) => op(imp),
        None => Err(IndexServerError::NotReady),
    }
}

async fn worker() -> IndexServerImpl {
Expand All @@ -199,17 +211,41 @@ impl IndexServer {
}
}

/// Builds a boxed `TermQuery` for the term `language` on the index's
/// language field.
///
/// # Errors
///
/// Returns `IndexServerError::NotReady` when the index implementation is
/// not available yet.
pub fn language_query(&self, language: &str) -> Result<Box<TermQuery>, IndexServerError> {
    self.with_impl(|imp| {
        let language_term = Term::from_field_text(imp.field_language, language);
        let term_query = TermQuery::new(language_term, IndexRecordOption::WithFreqsAndPositions);
        Ok(Box::new(term_query))
    })
}

/// Builds a boxed `TermSetQuery` over the index's body field, with one term
/// per entry in `tokens`.
///
/// Each token is turned into a term directly rather than being passed
/// through the query parser.
///
/// # Errors
///
/// Returns `IndexServerError::NotReady` when the index implementation is
/// not available yet.
pub fn body_query(&self, tokens: &[String]) -> Result<Box<TermSetQuery>, IndexServerError> {
    self.with_impl(|imp| {
        let body_terms = tokens
            .iter()
            .map(|token| Term::from_field_text(imp.field_body, token));
        Ok(Box::new(TermSetQuery::new(body_terms)))
    })
}

pub fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<SearchResponse, IndexServerError> {
if let Some(imp) = self.get_cell() {
Ok(imp.search(q, limit, offset)?)
} else {
Err(IndexServerError::NotReady)
}
self.with_impl(|imp| Ok(imp.search(q, limit, offset)?))
}

pub fn search_with_query(
&self,
q: &dyn tantivy::query::Query,
limit: usize,
offset: usize,
) -> Result<SearchResponse, IndexServerError> {
self.with_impl(|imp| Ok(imp.search_with_query(q, limit, offset)?))
}
}

Expand All @@ -218,6 +254,6 @@ pub enum IndexServerError {
#[error("index not ready")]
NotReady,

#[error("underlying tantivy error")]
#[error("{0}")]
TantivyError(#[from] tantivy::TantivyError),
}

0 comments on commit e4efcc4

Please sign in to comment.