From dbf480658cf53d2c00311084dcf9c12636022a29 Mon Sep 17 00:00:00 2001 From: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Date: Fri, 1 Sep 2023 14:11:56 +0200 Subject: [PATCH 1/2] feat: extend QA gen types Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --- .../model/examples/dummy_qa_generator/model.py | 14 ++++++++++++-- deepsearch/model/kinds/qagen/types.py | 13 +++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/deepsearch/model/examples/dummy_qa_generator/model.py b/deepsearch/model/examples/dummy_qa_generator/model.py index 04666a24..68754fc6 100644 --- a/deepsearch/model/examples/dummy_qa_generator/model.py +++ b/deepsearch/model/examples/dummy_qa_generator/model.py @@ -2,7 +2,11 @@ from deepsearch.model.base.types import Kind from deepsearch.model.kinds.qagen.model import BaseQAGenerator -from deepsearch.model.kinds.qagen.types import GenerateAnswersOutput, QAGenConfig +from deepsearch.model.kinds.qagen.types import ( + GenerateAnswersOutEntry, + GenerateAnswersOutput, + QAGenConfig, +) class DummyQAGenerator(BaseQAGenerator): @@ -24,4 +28,10 @@ def generate_answers( Args: texts: a list of context, question pairs. """ - return [question for _, question in texts] + return [ + GenerateAnswersOutEntry( + answer=question, + metadata={"foo": "bar"}, + ) + for _, question in texts + ] diff --git a/deepsearch/model/kinds/qagen/types.py b/deepsearch/model/kinds/qagen/types.py index 19e7670e..9bad5bee 100644 --- a/deepsearch/model/kinds/qagen/types.py +++ b/deepsearch/model/kinds/qagen/types.py @@ -1,6 +1,6 @@ -from typing import List, Literal +from typing import Any, Dict, List, Literal -from pydantic import root_validator +from pydantic import BaseModel, root_validator from deepsearch.model.base.types import ( BaseAppPredInput, @@ -29,7 +29,7 @@ def check_lengths_match(cls, values): return values -class QAGenReqSpec(StrictModel): +class QAGenReqSpec(BaseModel): generateAnswers: GenerateAnswers @@ -38,7 +38,12 @@ class QAGenAppPredInput(BaseAppPredInput): spec: QAGenReqSpec -GenerateAnswersOutput = List[str] +class GenerateAnswersOutEntry(StrictModel): + answer: str + metadata: Dict[str, Any] + + +GenerateAnswersOutput = List[GenerateAnswersOutEntry] class QAGenCtrlPredOutput(StrictModel): From eab4110979130db363fed0507b92054a2491e34f Mon Sep 17 00:00:00 2001 From: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:28:18 +0200 Subject: [PATCH 2/2] add explicit extras Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --- deepsearch/model/examples/dummy_qa_generator/model.py | 7 +++++-- deepsearch/model/kinds/qagen/controller.py | 1 + deepsearch/model/kinds/qagen/model.py | 4 ++-- deepsearch/model/kinds/qagen/types.py | 7 ++++--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/deepsearch/model/examples/dummy_qa_generator/model.py b/deepsearch/model/examples/dummy_qa_generator/model.py index 68754fc6..2f3b0419 100644 --- a/deepsearch/model/examples/dummy_qa_generator/model.py +++ b/deepsearch/model/examples/dummy_qa_generator/model.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from deepsearch.model.base.types import Kind from deepsearch.model.kinds.qagen.model import BaseQAGenerator @@ -22,11 +22,14 @@ def get_qagen_config(self) -> QAGenConfig: return self._config def generate_answers( - self, texts: List[Tuple[List[Dict], str]] + self, + texts: List[Tuple[List[Dict], str]], + extras: Dict[str, Any], ) -> GenerateAnswersOutput: """Just answers with the question itself. Args: texts: a list of context, question pairs. + extras: any extras to pass. """ return [ GenerateAnswersOutEntry( diff --git a/deepsearch/model/kinds/qagen/controller.py b/deepsearch/model/kinds/qagen/controller.py index b43e87ca..99ee138a 100644 --- a/deepsearch/model/kinds/qagen/controller.py +++ b/deepsearch/model/kinds/qagen/controller.py @@ -43,6 +43,7 @@ def dispatch_predict(self, spec: CtrlPredInput) -> CtrlPredOutput: ([ctx_entry.dict() for ctx_entry in ctx_list], q) for ctx_list, q in zip(gen_answers.contexts, gen_answers.questions) ], + extras=gen_answers.extras or {}, ) return QAGenCtrlPredOutput( answers=answers, diff --git a/deepsearch/model/kinds/qagen/model.py b/deepsearch/model/kinds/qagen/model.py index aec3d4e9..d7b9f174 100644 --- a/deepsearch/model/kinds/qagen/model.py +++ b/deepsearch/model/kinds/qagen/model.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from deepsearch.model.base.model import BaseDSModel from deepsearch.model.base.types import BaseModelConfig @@ -9,7 +9,7 @@ class BaseQAGenerator(BaseDSModel): @abstractmethod def generate_answers( - self, texts: List[Tuple[List[Dict], str]] + self, texts: List[Tuple[List[Dict], str]], extras: Dict[str, Any] ) -> GenerateAnswersOutput: raise NotImplementedError() diff --git a/deepsearch/model/kinds/qagen/types.py b/deepsearch/model/kinds/qagen/types.py index 9bad5bee..714eb3c0 100644 --- a/deepsearch/model/kinds/qagen/types.py +++ b/deepsearch/model/kinds/qagen/types.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, root_validator +from pydantic import root_validator from deepsearch.model.base.types import ( BaseAppPredInput, @@ -20,6 +20,7 @@ class ContextEntry(StrictModel): class GenerateAnswers(StrictModel): contexts: List[List[ContextEntry]] questions: List[str] + extras: Optional[Dict[str, Any]] = None @root_validator def check_lengths_match(cls, values): @@ -29,7 +30,7 @@ def check_lengths_match(cls, values): return values -class QAGenReqSpec(BaseModel): +class QAGenReqSpec(StrictModel): generateAnswers: GenerateAnswers