From 79f122189a327f97b37d8843c1d317b787ff764e Mon Sep 17 00:00:00 2001 From: Shroominic Date: Wed, 15 Nov 2023 18:06:44 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Refactor=20extraction=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/extraction_test.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/extraction_test.py b/tests/extraction_test.py index e214baa..264ad32 100644 --- a/tests/extraction_test.py +++ b/tests/extraction_test.py @@ -1,6 +1,7 @@ from funcchain import achain, chain, settings, BaseModel settings.MODEL_TEMPERATURE = 0 +settings.MODEL_NAME = "gpt-3.5-turbo-1106" class Task(BaseModel): @@ -22,18 +23,15 @@ def description(task: Task) -> str: async def extract_task(task_description: str) -> Task: """ - EXTRACT TASK: {task_description} + Extract the task based on the task description: + {task_description} """ return await achain() -def compare_tasks(task: Task, task2: Task) -> bool: +def compare_tasks(task1: Task, task2: Task) -> bool: """ - COMPARE TASKS: - 1: {task} - 2: {task2} - - Are the tasks kind of equal? + Are the task1 and task2 similar? """ return chain() @@ -41,17 +39,17 @@ def compare_tasks(task: Task, task2: Task) -> bool: def test_extraction() -> None: from asyncio import run as _await - task = Task( + task1 = Task( name="Do dishes", description="Do the dishes in the kitchen.", difficulty="easy", keywords=["kitchen", "dishes"], ) - task_description = description(task) - _task = _await(extract_task(task_description)) + task_description = description(task1) + task2 = _await(extract_task(task_description)) - assert compare_tasks(task, _task) + assert compare_tasks(task1, task2) if __name__ == "__main__":