Skip to content

Commit

Permalink
chore(wren-ai-service): allow retrieving sql pairs while retrieving h…
Browse files Browse the repository at this point in the history
…istorical questions (#1318)
  • Loading branch information
cyyeh authored Feb 20, 2025
1 parent b7ad839 commit 1bb6d4b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import sys
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -26,9 +27,9 @@ def run(self, documents: List[Document]):
for doc in documents:
formatted = {
"question": doc.content,
"summary": doc.meta.get("summary"),
"statement": doc.meta.get("statement"),
"viewId": doc.meta.get("viewId"),
"summary": doc.meta.get("summary", ""),
"statement": doc.meta.get("statement") or doc.meta.get("sql"),
"viewId": doc.meta.get("viewId", ""),
}
list.append(formatted)

Expand All @@ -37,7 +38,11 @@ def run(self, documents: List[Document]):

## Start of Pipeline
@observe(capture_input=False)
async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None) -> int:
async def count_documents(
view_questions_store: QdrantDocumentStore,
sql_pair_store: QdrantDocumentStore,
id: Optional[str] = None,
) -> int:
filters = (
{
"operator": "AND",
Expand All @@ -48,8 +53,11 @@ async def count_documents(store: QdrantDocumentStore, id: Optional[str] = None)
if id
else None
)
document_count = await store.count_documents(filters=filters)
return document_count
view_question_count, sql_pair_count = await asyncio.gather(
view_questions_store.count_documents(filters=filters),
sql_pair_store.count_documents(filters=filters),
)
return view_question_count + sql_pair_count


@observe(capture_input=False, capture_output=False)
Expand All @@ -61,7 +69,9 @@ async def embedding(count_documents: int, query: str, embedder: Any) -> dict:


@observe(capture_input=False)
async def retrieval(embedding: dict, id: str, retriever: Any) -> dict:
async def retrieval(
embedding: dict, id: str, view_questions_retriever: Any, sql_pair_retriever: Any
) -> dict:
if embedding:
filters = (
{
Expand All @@ -74,11 +84,19 @@ async def retrieval(embedding: dict, id: str, retriever: Any) -> dict:
else None
)

res = await retriever.run(
query_embedding=embedding.get("embedding"),
filters=filters,
view_question_res, sql_pair_res = await asyncio.gather(
view_questions_retriever.run(
query_embedding=embedding.get("embedding"),
filters=filters,
),
sql_pair_retriever.run(
query_embedding=embedding.get("embedding"),
filters=filters,
),
)
return dict(
documents=view_question_res.get("documents") + sql_pair_res.get("documents")
)
return dict(documents=res.get("documents"))

return {}

Expand Down Expand Up @@ -111,12 +129,19 @@ def __init__(
document_store_provider: DocumentStoreProvider,
**kwargs,
) -> None:
store = document_store_provider.get_store(dataset_name="view_questions")
view_questions_store = document_store_provider.get_store(
dataset_name="view_questions"
)
sql_pair_store = document_store_provider.get_store(dataset_name="sql_pairs")
self._components = {
"store": store,
"view_questions_store": view_questions_store,
"sql_pair_store": sql_pair_store,
"embedder": embedder_provider.get_text_embedder(),
"retriever": document_store_provider.get_retriever(
document_store=store,
"view_questions_retriever": document_store_provider.get_retriever(
document_store=view_questions_store,
),
"sql_pair_retriever": document_store_provider.get_retriever(
document_store=sql_pair_store,
),
"score_filter": ScoreFilter(),
# TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ async def ask(
sql_generation_reasoning = None
sql_samples = []
api_results = []
table_names = []
error_message = ""

try:
Expand Down

0 comments on commit 1bb6d4b

Please sign in to comment.