Add a repair task, which can take a task+taskrun+feedback, and improve the result.
scosman committed Oct 17, 2024
1 parent 2599549 commit 2024a04
Showing 2 changed files with 330 additions and 0 deletions.
74 changes: 74 additions & 0 deletions libs/core/kiln_ai/adapters/repair/repair_task.py
@@ -0,0 +1,74 @@
import json
from typing import Type

from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
from pydantic import BaseModel


class RepairData(BaseModel):
    original_task: Task
    task_run: TaskRun
    evaluator_feedback: str


# TODO add evaluator rating
class RepairTaskInput(BaseModel):
    original_prompt: str
    original_input: str
    original_output: str
    evaluator_feedback: str


class RepairTaskRun(Task, parent_of={}):
    def __init__(self, original_task: Task):
        # Keep the typechecker happy
        tmp_project = Project(name="Repair")
        super().__init__(
            name="Repair",
            parent=tmp_project,
            description="Repair a task run, given feedback from an evaluator about how the response can be improved.",
            instruction="You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), \
the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide \
feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response.",
            requirements=[
                TaskRequirement(
                    name="Follow Evaluator Feedback",
                    instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.",
                    priority=Priority.p0,
                )
            ],
            input_json_schema=json.dumps(RepairTaskInput.model_json_schema()),
            output_json_schema=original_task.output_json_schema,
        )

    @classmethod
    def _original_prompt(cls, run: TaskRun, task: Task) -> str:
        prompt_builder_class: Type[BasePromptBuilder] | None = None
        prompt_builder_name = run.output.source.properties.get(
            "prompt_builder_name", None
        )
        if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
            prompt_builder_class = prompt_builder_registry.get(
                prompt_builder_name, None
            )
        if prompt_builder_class is None:
            raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
        prompt_builder = prompt_builder_class(task=task)
        if not isinstance(prompt_builder, BasePromptBuilder):
            raise ValueError(
                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
            )
        return prompt_builder.build_prompt()

    @classmethod
    def build_repair_task_input(cls, repair_data: RepairData) -> RepairTaskInput:
        original_prompt = cls._original_prompt(
            repair_data.task_run, repair_data.original_task
        )
        return RepairTaskInput(
            original_prompt=original_prompt,
            original_input=repair_data.task_run.input,
            original_output=repair_data.task_run.output.output,
            evaluator_feedback=repair_data.evaluator_feedback,
        )
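
For orientation, here is a minimal sketch of how the classes above might be wired together end to end, mirroring the flow exercised in the tests below: build a RepairData from the original task, its unsatisfactory run, and the evaluator feedback, derive a RepairTaskInput, then run it through an adapter. The wrapper function, model name, and provider are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (not part of this commit): wire the repair task
# to an adapter the same way the tests below do.
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.repair.repair_task import RepairData, RepairTaskRun
from kiln_ai.datamodel import Task, TaskRun


async def run_repair(original_task: Task, failed_run: TaskRun, feedback: str):
    repair_task = RepairTaskRun(original_task)
    repair_input = RepairTaskRun.build_repair_task_input(
        RepairData(
            original_task=original_task,
            task_run=failed_run,
            evaluator_feedback=feedback,
        )
    )
    # Model and provider are placeholders taken from the tests below.
    adapter = LangChainPromptAdapter(
        repair_task, model_name="llama_3_1_8b", provider="groq"
    )
    # Returns an adapter response whose .run.output holds the repaired output.
    return await adapter.invoke_returning_run(repair_input.model_dump())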
256 changes: 256 additions & 0 deletions libs/core/kiln_ai/adapters/repair/test_repair_task.py
@@ -0,0 +1,256 @@
import json
import os
from unittest.mock import AsyncMock, patch

import pytest
from kiln_ai.adapters.langchain_adapters import (
    LangChainPromptAdapter,
)
from kiln_ai.adapters.repair.repair_task import (
    RepairData,
    RepairTaskInput,
    RepairTaskRun,
)
from kiln_ai.datamodel import (
    DataSource,
    DataSourceType,
    Priority,
    Task,
    TaskOutput,
    TaskRequirement,
    TaskRun,
)

json_joke_schema = """{
  "type": "object",
  "properties": {
    "setup": {
      "description": "The setup of the joke",
      "title": "Setup",
      "type": "string"
    },
    "punchline": {
      "description": "The punchline to the joke",
      "title": "Punchline",
      "type": "string"
    },
    "rating": {
      "anyOf": [
        {
          "type": "integer"
        },
        {
          "type": "null"
        }
      ],
      "default": null,
      "description": "How funny the joke is, from 1 to 10",
      "title": "Rating"
    }
  },
  "required": [
    "setup",
    "punchline"
  ]
}
"""


@pytest.fixture
def sample_task(tmp_path):
    task_path = tmp_path / "task.json"
    task = Task(
        name="Joke Generator",
        path=task_path,
        description="Generate a funny joke",
        instruction="Create a joke with a setup and punchline",
        requirements=[
            TaskRequirement(
                id="req1",
                name="Humor",
                instruction="The joke should be funny and appropriate",
                priority=Priority.p1,
            )
        ],
        output_json_schema=json_joke_schema,
    )
    task.save_to_file()
    return task


@pytest.fixture
def sample_task_run(sample_task):
    task_run = TaskRun(
        parent=sample_task,
        input='{"topic": "chicken"}',
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "Jane Doe"}
        ),
        output=TaskOutput(
            output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}',
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "gpt_4o",
                    "model_provider": "openai",
                    "adapter_name": "langchain_adapter",
                    "prompt_builder_name": "simple_prompt_builder",
                },
            ),
        ),
    )
    task_run.save_to_file()
    return task_run


@pytest.fixture
def sample_repair_data(sample_task, sample_task_run):
    return RepairData(
        original_task=sample_task,
        task_run=sample_task_run,
        evaluator_feedback="The joke is too cliché. Please come up with a more original chicken-related joke.",
    )


def test_build_repair_task_input(sample_repair_data):
    result = RepairTaskRun.build_repair_task_input(sample_repair_data)

    assert isinstance(result, RepairTaskInput)
    assert "Create a joke with a setup and punchline" in result.original_prompt
    assert "1) The joke should be funny and appropriate" in result.original_prompt
    assert result.original_input == '{"topic": "chicken"}'
    assert (
        result.original_output
        == '{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}'
    )
    assert (
        result.evaluator_feedback
        == "The joke is too cliché. Please come up with a more original chicken-related joke."
    )


def test_repair_input_schema():
    schema = RepairTaskInput.model_json_schema()
    assert schema["type"] == "object"
    assert "original_prompt" in schema["properties"]
    assert "original_input" in schema["properties"]
    assert "original_output" in schema["properties"]
    assert "evaluator_feedback" in schema["properties"]


def test_repair_task_initialization(sample_task):
    repair_task = RepairTaskRun(sample_task)

    assert repair_task.name == "Repair"
    assert "Repair a task run" in repair_task.description
    assert "You are an assistant which helps improve output" in repair_task.instruction
    assert len(repair_task.requirements) == 1
    assert repair_task.requirements[0].name == "Follow Evaluator Feedback"
    assert repair_task.input_json_schema == json.dumps(
        RepairTaskInput.model_json_schema()
    )
    assert repair_task.output_json_schema == sample_task.output_json_schema


def test_build_repair_task_input_with_empty_values(sample_repair_data):
    # Arrange
    sample_repair_data.task_run.input = ""
    sample_repair_data.task_run.output.output = ""
    sample_repair_data.evaluator_feedback = ""

    # Act
    result = RepairTaskRun.build_repair_task_input(sample_repair_data)

    # Assert
    assert isinstance(result, RepairTaskInput)
    assert "Create a joke with a setup and punchline" in result.original_prompt
    assert result.original_input == ""
    assert result.original_output == ""
    assert result.evaluator_feedback == ""


@pytest.mark.parametrize("invalid_input", [None, "", 123, {}])
def test_build_repair_task_input_with_invalid_input(invalid_input):
    # Act & Assert
    with pytest.raises(AttributeError):
        RepairTaskRun.build_repair_task_input(invalid_input)


@pytest.mark.paid
async def test_live_run(sample_task, sample_task_run, sample_repair_data):
    if os.getenv("GROQ_API_KEY") is None:
        pytest.skip("GROQ_API_KEY not set")
    repair_task = RepairTaskRun(sample_task)
    repair_task_input = RepairTaskRun.build_repair_task_input(sample_repair_data)
    assert isinstance(repair_task_input, RepairTaskInput)

    adapter = LangChainPromptAdapter(
        repair_task, model_name="llama_3_1_8b", provider="groq"
    )

    adapter_response = await adapter.invoke_returning_run(
        repair_task_input.model_dump()
    )
    print("output", adapter_response.output)
    assert adapter_response.run is not None
    assert adapter_response.run.id is not None
    assert (
        "Please come up with a more original chicken-related joke."
        in adapter_response.run.input
    )
    parsed_output = json.loads(adapter_response.run.output.output)
    assert "setup" in parsed_output
    assert "punchline" in parsed_output
    assert adapter_response.run.output.source.properties == {
        "adapter_name": "kiln_langchain_adapter",
        "model_name": "llama_3_1_8b",
        "model_provider": "groq",
        "prompt_builder_name": "simple_prompt_builder",
    }


@pytest.mark.asyncio
async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repair_data):
    repair_task = RepairTaskRun(sample_task)
    repair_task_input = RepairTaskRun.build_repair_task_input(sample_repair_data)
    assert isinstance(repair_task_input, RepairTaskInput)

    mocked_output = {
        "setup": "Why did the chicken join a band?",
        "punchline": "Because it had excellent drumsticks!",
        "rating": 8,
    }

    with patch.object(
        LangChainPromptAdapter, "_run", new_callable=AsyncMock
    ) as mock_run:
        mock_run.return_value = mocked_output

        adapter = LangChainPromptAdapter(
            repair_task, model_name="llama_3_1_8b", provider="groq"
        )

        adapter_response = await adapter.invoke_returning_run(
            repair_task_input.model_dump()
        )

        assert adapter_response.run is not None
        assert adapter_response.run.id is not None
        assert (
            "Please come up with a more original chicken-related joke."
            in adapter_response.run.input
        )

        parsed_output = json.loads(adapter_response.run.output.output)
        assert parsed_output == mocked_output
        assert adapter_response.run.output.source.properties == {
            "adapter_name": "kiln_langchain_adapter",
            "model_name": "llama_3_1_8b",
            "model_provider": "groq",
            "prompt_builder_name": "simple_prompt_builder",
        }
        assert adapter_response.run.input_source.type == DataSourceType.human
        assert "created_by" in adapter_response.run.input_source.properties

        # Verify that the mock was called
        mock_run.assert_called_once()
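
As an aside, the json_joke_schema string at the top of this test file has the shape that Pydantic's model_json_schema() produces for a small model. A sketch of such a model is below; the Joke class name and field definitions are illustrative assumptions, with only the schema shape taken from the string above.

# Illustrative only: a Pydantic model whose model_json_schema() output has
# the same shape as the json_joke_schema string used in these tests.
from typing import Optional

from pydantic import BaseModel, Field


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")
    rating: Optional[int] = Field(
        default=None, description="How funny the joke is, from 1 to 10"
    )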
