Skip to content

Commit

Permalink
feat: add QA on given chunk refs (#181)
Browse files Browse the repository at this point in the history
Signed-off-by: Panos Vagenas <[email protected]>
  • Loading branch information
vagenas authored Jun 14, 2024
1 parent 5ecfe40 commit 995dda0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
5 changes: 5 additions & 0 deletions deepsearch/cps/queries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SemanticBackendResource,
)
from deepsearch.cps.client.queries import Query, TaskCoordinates
from deepsearch.cps.queries.results import ChunkRef


def Wf(wf_query: Dict[str, Any], kg: TaskCoordinates) -> Query:
Expand Down Expand Up @@ -110,6 +111,7 @@ class _APISemanticRagParameters(_APISemanticRetrievalParameters):
gen_ctx_window_size: int = 5000
gen_ctx_window_lead_weight: float = 0.5
return_prompt: bool = False
chunk_refs: Optional[List[ChunkRef]] = None
gen_timeout: Optional[float] = None


Expand All @@ -129,6 +131,7 @@ def RAGQuery(
gen_ctx_window_size: int = 5000,
gen_ctx_window_lead_weight: float = 0.5,
return_prompt: bool = False,
chunk_refs: Optional[List[ChunkRef]] = None,
gen_timeout: Optional[float] = None,
) -> Query:
"""Create a RAG query
Expand All @@ -147,6 +150,7 @@ def RAGQuery(
gen_ctx_window_size (int, optional): (relevant only if gen_ctx_extr_method=="window") max chars to use for extracted gen context (actual extraction quantized on doc item level); defaults to 5000
gen_ctx_window_lead_weight (float, optional): (relevant only if gen_ctx_extr_method=="window") weight of leading text for distributing remaining window size after extracting the `main_path`; defaults to 0.5 (centered around `main_path`)
return_prompt (bool, optional): whether to return the instantiated prompt; defaults to False
chunk_refs (Optional[List[ChunkRef]], optional): list of explicit chunk references to use instead of performing retrieval; defaults to None (i.e. retrieval-mode)
gen_timeout (float, optional): timeout for LLM generation; defaults to None, i.e. determined by system
"""

Expand Down Expand Up @@ -181,6 +185,7 @@ def RAGQuery(
gen_ctx_window_size=gen_ctx_window_size,
gen_ctx_window_lead_weight=gen_ctx_window_lead_weight,
return_prompt=return_prompt,
chunk_refs=chunk_refs,
gen_timeout=gen_timeout,
)

Expand Down
38 changes: 25 additions & 13 deletions deepsearch/cps/queries/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

from typing import List, Optional

from pydantic.v1 import BaseModel, root_validator
from pydantic.v1 import BaseModel

from deepsearch.cps.client.components.queries import RunQueryResult


class SearchResultItem(BaseModel):
class ChunkRef(BaseModel):
doc_hash: str
main_path: str # the anchor path among the contributing group
path_group: List[str] # the doc paths contributing to the encoding source


class SearchResultItem(ChunkRef):
chunk: str
main_path: str
path_group: List[str]
source_is_text: bool


class RAGGroundingInfo(BaseModel):
retr_items: List[SearchResultItem]
retr_items: Optional[List[SearchResultItem]] = None
gen_ctx_paths: List[str]


Expand Down Expand Up @@ -45,26 +48,35 @@ def __init__(self, msg="Search returned no results", *args, **kwargs):

class RAGResult(BaseModel):
answers: List[RAGAnswerItem]
search_result_items: List[SearchResultItem]
search_result_items: Optional[List[SearchResultItem]] = None

@classmethod
def from_api_output(cls, data: RunQueryResult, raise_on_error=True):
answers: List[RAGAnswerItem] = []
try:
search_result_items = data.outputs["retrieval"]["items"]
if raise_on_error and len(search_result_items) == 0:
raise NoSearchResultsError()
retrieval_part = data.outputs["retrieval"]
if retrieval_part is not None:
search_result_items = retrieval_part["items"]
if raise_on_error and len(search_result_items) == 0:
raise NoSearchResultsError()
else:
search_result_items = None
for answer_item in data.outputs["answers"]:
if raise_on_error and (gen_err := answer_item.get("gen_err")):
raise GenerationError(gen_err)
retr_idxs = answer_item["grounding_info"]["retr_idxs"]
answers.append(
RAGAnswerItem(
answer=answer_item["answer"],
grounding=RAGGroundingInfo(
retr_items=[
SearchResultItem.parse_obj(search_result_items[i])
for i in answer_item["grounding_info"]["retr_idxs"]
],
retr_items=(
[
SearchResultItem.parse_obj(search_result_items[i])
for i in retr_idxs
]
if retr_idxs is not None and retrieval_part is not None
else None
),
gen_ctx_paths=answer_item["grounding_info"][
"gen_ctx_paths"
],
Expand Down

0 comments on commit 995dda0

Please sign in to comment.