-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a repair task, which can take a task+taskrun+feedback, and improv…
…e the result.
- Loading branch information
Showing
2 changed files
with
330 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import json | ||
from typing import Type | ||
|
||
from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry | ||
from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun | ||
from pydantic import BaseModel | ||
|
||
|
||
class RepairData(BaseModel): | ||
original_task: Task | ||
task_run: TaskRun | ||
evaluator_feedback: str | ||
|
||
|
||
# TODO add evaluator rating | ||
class RepairTaskInput(BaseModel): | ||
original_prompt: str | ||
original_input: str | ||
original_output: str | ||
evaluator_feedback: str | ||
|
||
|
||
class RepairTaskRun(Task, parent_of={}): | ||
def __init__(self, original_task: Task): | ||
# Keep the typechecker happy | ||
tmp_project = Project(name="Repair") | ||
super().__init__( | ||
name="Repair", | ||
parent=tmp_project, | ||
description="Repair a task run, given feedback from an evaluator about how the response can be improved.", | ||
instruction="You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), \ | ||
the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide \ | ||
feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response.", | ||
requirements=[ | ||
TaskRequirement( | ||
name="Follow Evaluator Feedback", | ||
instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.", | ||
priority=Priority.p0, | ||
) | ||
], | ||
input_json_schema=json.dumps(RepairTaskInput.model_json_schema()), | ||
output_json_schema=original_task.output_json_schema, | ||
) | ||
|
||
@classmethod | ||
def _original_prompt(cls, run: TaskRun, task: Task) -> str: | ||
prompt_builder_class: Type[BasePromptBuilder] | None = None | ||
prompt_builder_name = run.output.source.properties.get( | ||
"prompt_builder_name", None | ||
) | ||
if prompt_builder_name is not None and isinstance(prompt_builder_name, str): | ||
prompt_builder_class = prompt_builder_registry.get( | ||
prompt_builder_name, None | ||
) | ||
if prompt_builder_class is None: | ||
raise ValueError(f"No prompt builder found for name: {prompt_builder_name}") | ||
prompt_builder = prompt_builder_class(task=task) | ||
if not isinstance(prompt_builder, BasePromptBuilder): | ||
raise ValueError( | ||
f"Prompt builder {prompt_builder_name} is not a valid prompt builder" | ||
) | ||
return prompt_builder.build_prompt() | ||
|
||
@classmethod | ||
def build_repair_task_input(cls, repair_data: RepairData) -> RepairTaskInput: | ||
original_prompt = cls._original_prompt( | ||
repair_data.task_run, repair_data.original_task | ||
) | ||
return RepairTaskInput( | ||
original_prompt=original_prompt, | ||
original_input=repair_data.task_run.input, | ||
original_output=repair_data.task_run.output.output, | ||
evaluator_feedback=repair_data.evaluator_feedback, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
import json | ||
import os | ||
from unittest.mock import AsyncMock, patch | ||
|
||
import pytest | ||
from kiln_ai.adapters.langchain_adapters import ( | ||
LangChainPromptAdapter, | ||
) | ||
from kiln_ai.adapters.repair.repair_task import ( | ||
RepairData, | ||
RepairTaskInput, | ||
RepairTaskRun, | ||
) | ||
from kiln_ai.datamodel import ( | ||
DataSource, | ||
DataSourceType, | ||
Priority, | ||
Task, | ||
TaskOutput, | ||
TaskRequirement, | ||
TaskRun, | ||
) | ||
|
||
json_joke_schema = """{ | ||
"type": "object", | ||
"properties": { | ||
"setup": { | ||
"description": "The setup of the joke", | ||
"title": "Setup", | ||
"type": "string" | ||
}, | ||
"punchline": { | ||
"description": "The punchline to the joke", | ||
"title": "Punchline", | ||
"type": "string" | ||
}, | ||
"rating": { | ||
"anyOf": [ | ||
{ | ||
"type": "integer" | ||
}, | ||
{ | ||
"type": "null" | ||
} | ||
], | ||
"default": null, | ||
"description": "How funny the joke is, from 1 to 10", | ||
"title": "Rating" | ||
} | ||
}, | ||
"required": [ | ||
"setup", | ||
"punchline" | ||
] | ||
} | ||
""" | ||
|
||
|
||
@pytest.fixture | ||
def sample_task(tmp_path): | ||
task_path = tmp_path / "task.json" | ||
task = Task( | ||
name="Joke Generator", | ||
path=task_path, | ||
description="Generate a funny joke", | ||
instruction="Create a joke with a setup and punchline", | ||
requirements=[ | ||
TaskRequirement( | ||
id="req1", | ||
name="Humor", | ||
instruction="The joke should be funny and appropriate", | ||
priority=Priority.p1, | ||
) | ||
], | ||
output_json_schema=json_joke_schema, | ||
) | ||
task.save_to_file() | ||
return task | ||
|
||
|
||
@pytest.fixture | ||
def sample_task_run(sample_task): | ||
task_run = TaskRun( | ||
parent=sample_task, | ||
input='{"topic": "chicken"}', | ||
input_source=DataSource( | ||
type=DataSourceType.human, properties={"created_by": "Jane Doe"} | ||
), | ||
output=TaskOutput( | ||
output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}', | ||
source=DataSource( | ||
type=DataSourceType.synthetic, | ||
properties={ | ||
"model_name": "gpt_4o", | ||
"model_provider": "openai", | ||
"adapter_name": "langchain_adapter", | ||
"prompt_builder_name": "simple_prompt_builder", | ||
}, | ||
), | ||
), | ||
) | ||
task_run.save_to_file() | ||
return task_run | ||
|
||
|
||
@pytest.fixture | ||
def sample_repair_data(sample_task, sample_task_run): | ||
return RepairData( | ||
original_task=sample_task, | ||
task_run=sample_task_run, | ||
evaluator_feedback="The joke is too cliché. Please come up with a more original chicken-related joke.", | ||
) | ||
|
||
|
||
def test_build_repair_task_input(sample_repair_data): | ||
result = RepairTaskRun.build_repair_task_input(sample_repair_data) | ||
|
||
assert isinstance(result, RepairTaskInput) | ||
assert "Create a joke with a setup and punchline" in result.original_prompt | ||
assert "1) The joke should be funny and appropriate" in result.original_prompt | ||
assert result.original_input == '{"topic": "chicken"}' | ||
assert ( | ||
result.original_output | ||
== '{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}' | ||
) | ||
assert ( | ||
result.evaluator_feedback | ||
== "The joke is too cliché. Please come up with a more original chicken-related joke." | ||
) | ||
|
||
|
||
def test_repair_input_schema(): | ||
schema = RepairTaskInput.model_json_schema() | ||
assert schema["type"] == "object" | ||
assert "original_prompt" in schema["properties"] | ||
assert "original_input" in schema["properties"] | ||
assert "original_output" in schema["properties"] | ||
assert "evaluator_feedback" in schema["properties"] | ||
|
||
|
||
def test_repair_task_initialization(sample_task): | ||
repair_task = RepairTaskRun(sample_task) | ||
|
||
assert repair_task.name == "Repair" | ||
assert "Repair a task run" in repair_task.description | ||
assert "You are an assistant which helps improve output" in repair_task.instruction | ||
assert len(repair_task.requirements) == 1 | ||
assert repair_task.requirements[0].name == "Follow Evaluator Feedback" | ||
assert repair_task.input_json_schema == json.dumps( | ||
RepairTaskInput.model_json_schema() | ||
) | ||
assert repair_task.output_json_schema == sample_task.output_json_schema | ||
|
||
|
||
def test_build_repair_task_input_with_empty_values(sample_repair_data): | ||
# Arrange | ||
sample_repair_data.task_run.input = "" | ||
sample_repair_data.task_run.output.output = "" | ||
sample_repair_data.evaluator_feedback = "" | ||
|
||
# Act | ||
result = RepairTaskRun.build_repair_task_input(sample_repair_data) | ||
|
||
# Assert | ||
assert isinstance(result, RepairTaskInput) | ||
assert "Create a joke with a setup and punchline" in result.original_prompt | ||
assert result.original_input == "" | ||
assert result.original_output == "" | ||
assert result.evaluator_feedback == "" | ||
|
||
|
||
@pytest.mark.parametrize("invalid_input", [None, "", 123, {}]) | ||
def test_build_repair_task_input_with_invalid_input(invalid_input): | ||
# Act & Assert | ||
with pytest.raises(AttributeError): | ||
RepairTaskRun.build_repair_task_input(invalid_input) | ||
|
||
|
||
@pytest.mark.paid | ||
async def test_live_run(sample_task, sample_task_run, sample_repair_data): | ||
if os.getenv("GROQ_API_KEY") is None: | ||
pytest.skip("GROQ_API_KEY not set") | ||
repair_task = RepairTaskRun(sample_task) | ||
repair_task_input = RepairTaskRun.build_repair_task_input(sample_repair_data) | ||
assert isinstance(repair_task_input, RepairTaskInput) | ||
|
||
adapter = LangChainPromptAdapter( | ||
repair_task, model_name="llama_3_1_8b", provider="groq" | ||
) | ||
|
||
adapter_response = await adapter.invoke_returning_run( | ||
repair_task_input.model_dump() | ||
) | ||
print("output", adapter_response.output) | ||
assert adapter_response.run is not None | ||
assert adapter_response.run.id is not None | ||
assert ( | ||
"Please come up with a more original chicken-related joke." | ||
in adapter_response.run.input | ||
) | ||
parsed_output = json.loads(adapter_response.run.output.output) | ||
assert "setup" in parsed_output | ||
assert "punchline" in parsed_output | ||
assert adapter_response.run.output.source.properties == { | ||
"adapter_name": "kiln_langchain_adapter", | ||
"model_name": "llama_3_1_8b", | ||
"model_provider": "groq", | ||
"prompt_builder_name": "simple_prompt_builder", | ||
} | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repair_data): | ||
repair_task = RepairTaskRun(sample_task) | ||
repair_task_input = RepairTaskRun.build_repair_task_input(sample_repair_data) | ||
assert isinstance(repair_task_input, RepairTaskInput) | ||
|
||
mocked_output = { | ||
"setup": "Why did the chicken join a band?", | ||
"punchline": "Because it had excellent drumsticks!", | ||
"rating": 8, | ||
} | ||
|
||
with patch.object( | ||
LangChainPromptAdapter, "_run", new_callable=AsyncMock | ||
) as mock_run: | ||
mock_run.return_value = mocked_output | ||
|
||
adapter = LangChainPromptAdapter( | ||
repair_task, model_name="llama_3_1_8b", provider="groq" | ||
) | ||
|
||
adapter_response = await adapter.invoke_returning_run( | ||
repair_task_input.model_dump() | ||
) | ||
|
||
assert adapter_response.run is not None | ||
assert adapter_response.run.id is not None | ||
assert ( | ||
"Please come up with a more original chicken-related joke." | ||
in adapter_response.run.input | ||
) | ||
|
||
parsed_output = json.loads(adapter_response.run.output.output) | ||
assert parsed_output == mocked_output | ||
assert adapter_response.run.output.source.properties == { | ||
"adapter_name": "kiln_langchain_adapter", | ||
"model_name": "llama_3_1_8b", | ||
"model_provider": "groq", | ||
"prompt_builder_name": "simple_prompt_builder", | ||
} | ||
assert adapter_response.run.input_source.type == DataSourceType.human | ||
assert "created_by" in adapter_response.run.input_source.properties | ||
|
||
# Verify that the mock was called | ||
mock_run.assert_called_once() |