diff --git a/backend/app/api/endpoints/base/example.py b/backend/app/api/endpoints/base/example.py index 99757c54..a1908759 100644 --- a/backend/app/api/endpoints/base/example.py +++ b/backend/app/api/endpoints/base/example.py @@ -2,6 +2,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import json + from fastapi import APIRouter, Response from app.domain.schemas.base.example import ( @@ -69,15 +71,32 @@ def partial_creation_generative_example( def create_example( model: CreateExampleRequest, ): - return ExampleService().create_example( - model.context_id, - model.user_id, - model.model_wrong, - model.model_endpoint_name, - model.input_json, - model.output_json, - model.metadata, - model.tag, + return ( + ExampleService().create_example_and_increment_counters( + model.context_id, + model.user_id, + model.model_wrong, + model.model_endpoint_name, + json.dumps(model.input_json), + json.dumps(model.output_json), + json.dumps(model.metadata), + model.tag, + model.round_id, + model.task_id, + text=model.text, + ) + if model.increment_context + else ExampleService().create_example( + model.context_id, + model.user_id, + model.model_wrong, + model.model_endpoint_name, + model.input_json, + model.output_json, + model.metadata, + model.tag, + model.text, + ) ) diff --git a/backend/app/domain/schemas/base/example.py b/backend/app/domain/schemas/base/example.py index f0ed6088..125a7d9c 100644 --- a/backend/app/domain/schemas/base/example.py +++ b/backend/app/domain/schemas/base/example.py @@ -35,6 +35,10 @@ class CreateExampleRequest(BaseModel): output_json: Optional[dict] = None metadata: Optional[dict] = None tag: Optional[str] = "generative" + increment_context: Optional[bool] = False + text: Optional[str] = None + task_id: Optional[int] = None + round_id: Optional[int] = None class PartialCreationExampleRequest(BaseModel): diff --git a/backend/app/domain/services/base/context.py b/backend/app/domain/services/base/context.py index db54f2ea..154ac7a9 100644 --- a/backend/app/domain/services/base/context.py +++ b/backend/app/domain/services/base/context.py @@ -314,5 +314,13 @@ def get_random_context_from_key_value(self, key_name: str, key_value: dict) -> d ) if not contexts: return None - contexts = [json.loads(context.context_json) for context in contexts] + contexts = [ + { + "id": context.id, + "round_id": context.r_realid, + **json.loads(context.context_json), + } + for context in contexts + ] + return random.choice(contexts) diff --git a/backend/app/domain/services/base/example.py b/backend/app/domain/services/base/example.py index 9b4bac95..00339747 100644 --- a/backend/app/domain/services/base/example.py +++ b/backend/app/domain/services/base/example.py @@ -54,6 +54,7 @@ def create_example( output_json: Json, metadata: Json, tag: str, + text: str = None, ) -> dict: return self.example_repository.create_example( context_id, @@ -64,6 +65,7 @@ def create_example( output_json, metadata, tag, + text, ) def increment_counter_examples_submitted( @@ -112,6 +114,7 @@ def create_example_and_increment_counters( amount_necessary_examples: int = -1, url_external_provider: str = None, provider_artifacts: dict = None, + text: str = None, ) -> dict: new_sample_info = self.create_example( context_id, @@ -122,6 +125,7 @@ def create_example_and_increment_counters( output_json, metadata, tag, + text, ) self.increment_counter_examples_submitted( round_id, user_id, context_id, task_id, model_wrong diff --git a/backend/app/infrastructure/models/models.py b/backend/app/infrastructure/models/models.py index ad8cba50..183dc14f 100644 --- a/backend/app/infrastructure/models/models.py +++ b/backend/app/infrastructure/models/models.py @@ -439,6 +439,7 @@ class Example(Base): cid = Column(ForeignKey("contexts.id"), nullable=False, index=True) uid = Column(ForeignKey("users.id"), index=True) tag = Column(Text) + text = Column(Text) input_json = Column(Text) output_json = Column(Text) metadata_json = Column(Text) diff --git a/backend/app/infrastructure/repositories/example.py b/backend/app/infrastructure/repositories/example.py index 70583c68..c13d2250 100644 --- a/backend/app/infrastructure/repositories/example.py +++ b/backend/app/infrastructure/repositories/example.py @@ -27,6 +27,7 @@ def create_example( output_json: Json, metadata: Json, tag: str, + text: str, ) -> dict: return self.add( { @@ -42,6 +43,7 @@ def create_example( "split": "undecided", "flagged": 0, "total_verified": 0, + "text": text, } )