diff --git a/deepsearch/model/examples/dummy_qa_generator/model.py b/deepsearch/model/examples/dummy_qa_generator/model.py index 04666a24..2f3b0419 100644 --- a/deepsearch/model/examples/dummy_qa_generator/model.py +++ b/deepsearch/model/examples/dummy_qa_generator/model.py @@ -1,8 +1,12 @@ -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from deepsearch.model.base.types import Kind from deepsearch.model.kinds.qagen.model import BaseQAGenerator -from deepsearch.model.kinds.qagen.types import GenerateAnswersOutput, QAGenConfig +from deepsearch.model.kinds.qagen.types import ( + GenerateAnswersOutEntry, + GenerateAnswersOutput, + QAGenConfig, +) class DummyQAGenerator(BaseQAGenerator): @@ -18,10 +22,19 @@ def get_qagen_config(self) -> QAGenConfig: return self._config def generate_answers( - self, texts: List[Tuple[List[Dict], str]] + self, + texts: List[Tuple[List[Dict], str]], + extras: Dict[str, Any], ) -> GenerateAnswersOutput: """Just answers with the question itself. Args: texts: a list of context, question pairs. + extras: any extras to pass. """ - return [question for _, question in texts] + return [ + GenerateAnswersOutEntry( + answer=question, + metadata={"foo": "bar"}, + ) + for _, question in texts + ] diff --git a/deepsearch/model/kinds/qagen/controller.py b/deepsearch/model/kinds/qagen/controller.py index b43e87ca..99ee138a 100644 --- a/deepsearch/model/kinds/qagen/controller.py +++ b/deepsearch/model/kinds/qagen/controller.py @@ -43,6 +43,7 @@ def dispatch_predict(self, spec: CtrlPredInput) -> CtrlPredOutput: ([ctx_entry.dict() for ctx_entry in ctx_list], q) for ctx_list, q in zip(gen_answers.contexts, gen_answers.questions) ], + extras=gen_answers.extras or {}, ) return QAGenCtrlPredOutput( answers=answers, diff --git a/deepsearch/model/kinds/qagen/model.py b/deepsearch/model/kinds/qagen/model.py index aec3d4e9..d7b9f174 100644 --- a/deepsearch/model/kinds/qagen/model.py +++ b/deepsearch/model/kinds/qagen/model.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from deepsearch.model.base.model import BaseDSModel from deepsearch.model.base.types import BaseModelConfig @@ -9,7 +9,7 @@ class BaseQAGenerator(BaseDSModel): @abstractmethod def generate_answers( - self, texts: List[Tuple[List[Dict], str]] + self, texts: List[Tuple[List[Dict], str]], extras: Dict[str, Any] ) -> GenerateAnswersOutput: raise NotImplementedError() diff --git a/deepsearch/model/kinds/qagen/types.py b/deepsearch/model/kinds/qagen/types.py index 19e7670e..714eb3c0 100644 --- a/deepsearch/model/kinds/qagen/types.py +++ b/deepsearch/model/kinds/qagen/types.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import Any, Dict, List, Literal, Optional from pydantic import root_validator @@ -20,6 +20,7 @@ class ContextEntry(StrictModel): class GenerateAnswers(StrictModel): contexts: List[List[ContextEntry]] questions: List[str] + extras: Optional[Dict[str, Any]] = None @root_validator def check_lengths_match(cls, values): @@ -38,7 +39,12 @@ class QAGenAppPredInput(BaseAppPredInput): spec: QAGenReqSpec -GenerateAnswersOutput = List[str] +class GenerateAnswersOutEntry(StrictModel): + answer: str + metadata: Dict[str, Any] + + +GenerateAnswersOutput = List[GenerateAnswersOutEntry] class QAGenCtrlPredOutput(StrictModel):