
Fix our weird API where we're returning a task_run and another raw_output. Task run has output... we don't need it twice.

Always return the task run, and clear its ID if it's ephemeral (not persisted).
scosman committed Oct 18, 2024
1 parent e62278e commit d02405c
Showing 9 changed files with 136 additions and 156 deletions.
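
For reviewers, a minimal sketch of the new adapter API after this commit. Names come from the diff below; the import, task setup, and model/provider values are assumed for illustration (they mirror run_simple_task in test_prompt_adaptors.py).

async def demo(task):
    # Imports omitted; construction mirrors the tests below (illustrative model/provider).
    adapter = LangChainPromptAdapter(task, model_name="llama_3_1_8b", provider="groq")

    # invoke() now always returns a TaskRun; the raw model output lives on run.output.output.
    run = await adapter.invoke("You should answer the following question: four plus six times 10")
    if run.id is None:
        # Autosave was off (or the task has no path), so the run is ephemeral (not persisted).
        print("ephemeral run:", run.output.output)

    # invoke_returning_raw() returns only the output: a str, or a parsed dict when the
    # task defines an output JSON schema.
    raw = await adapter.invoke_returning_raw("You should answer the following question: four plus six times 10")
    return run, raw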
23 changes: 12 additions & 11 deletions libs/core/kiln_ai/adapters/base_adapter.py
@@ -24,30 +24,28 @@ class AdapterInfo:
prompt_builder_name: str


class AdapterRun(BaseModel):
run: TaskRun | None
output: Dict | str


class BaseAdapter(metaclass=ABCMeta):
def __init__(self, kiln_task: Task):
self.kiln_task = kiln_task
self.output_schema = self.kiln_task.output_json_schema
self.input_schema = self.kiln_task.input_json_schema

async def invoke(
async def invoke_returning_raw(
self,
input: Dict | str,
input_source: DataSource | None = None,
) -> Dict | str:
result = await self.invoke_returning_run(input, input_source)
return result.output
result = await self.invoke(input, input_source)
if self.kiln_task.output_json_schema is None:
return result.output.output
else:
return json.loads(result.output.output)

async def invoke_returning_run(
async def invoke(
self,
input: Dict | str,
input_source: DataSource | None = None,
) -> AdapterRun:
) -> TaskRun:
# validate input
if self.input_schema is not None:
if not isinstance(input, dict):
@@ -74,8 +72,11 @@ async def invoke_returning_run(
# Save the run if configured to do so, and we have a path to save to
if Config.shared().autosave_runs and self.kiln_task.path is not None:
run.save_to_file()
else:
# Clear the ID to indicate it's not persisted
run.id = None

return AdapterRun(run=run, output=result)
return run

def has_structured_output(self) -> bool:
return self.output_schema is not None
41 changes: 16 additions & 25 deletions libs/core/kiln_ai/adapters/repair/test_repair_task.py
@@ -188,20 +188,15 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
repair_task, model_name="llama_3_1_8b", provider="groq"
)

adapter_response = await adapter.invoke_returning_run(
repair_task_input.model_dump()
)
print("output", adapter_response.output)
assert adapter_response.run is not None
assert adapter_response.run.id is not None
assert (
"Please come up with a more original chicken-related joke."
in adapter_response.run.input
)
parsed_output = json.loads(adapter_response.run.output.output)
run = await adapter.invoke(repair_task_input.model_dump())
print("output", run.output)
assert run is not None
assert run.id is not None
assert "Please come up with a more original chicken-related joke." in run.input
parsed_output = json.loads(run.output.output)
assert "setup" in parsed_output
assert "punchline" in parsed_output
assert adapter_response.run.output.source.properties == {
assert run.output.source.properties == {
"adapter_name": "kiln_langchain_adapter",
"model_name": "llama_3_1_8b",
"model_provider": "groq",
@@ -230,27 +225,23 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
repair_task, model_name="llama_3_1_8b", provider="groq"
)

adapter_response = await adapter.invoke_returning_run(
repair_task_input.model_dump()
)
run = await adapter.invoke(repair_task_input.model_dump())

assert adapter_response.run is not None
assert adapter_response.run.id is not None
assert (
"Please come up with a more original chicken-related joke."
in adapter_response.run.input
)
assert run is not None
# because it's mocked, the run is not saved to the disk, and returns no ID
assert run.id is None
assert "Please come up with a more original chicken-related joke." in run.input

parsed_output = json.loads(adapter_response.run.output.output)
parsed_output = json.loads(run.output.output)
assert parsed_output == mocked_output
assert adapter_response.run.output.source.properties == {
assert run.output.source.properties == {
"adapter_name": "kiln_langchain_adapter",
"model_name": "llama_3_1_8b",
"model_provider": "groq",
"prompt_builder_name": "simple_prompt_builder",
}
assert adapter_response.run.input_source.type == DataSourceType.human
assert "created_by" in adapter_response.run.input_source.properties
assert run.input_source.type == DataSourceType.human
assert "created_by" in run.input_source.properties

# Verify that the mock was called
mock_run.assert_called_once()
36 changes: 16 additions & 20 deletions libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -77,24 +77,22 @@ async def test_mock(tmp_path):
task = build_test_task(tmp_path)
mockChatModel = FakeListChatModel(responses=["mock response"])
adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
answer = await adapter.invoke("You are a mock, send me the response!")
assert "mock response" in answer
run = await adapter.invoke("You are a mock, send me the response!")
assert "mock response" in run.output.output


async def test_mock_returning_run(tmp_path):
task = build_test_task(tmp_path)
mockChatModel = FakeListChatModel(responses=["mock response"])
adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
adapter_response = await adapter.invoke_returning_run(
"You are a mock, send me the response!"
)
assert adapter_response.output == "mock response"
assert adapter_response.run is not None
assert adapter_response.run.id is not None
assert adapter_response.run.input == "You are a mock, send me the response!"
assert adapter_response.run.output.output == "mock response"
assert "created_by" in adapter_response.run.input_source.properties
assert adapter_response.run.output.source.properties == {
run = await adapter.invoke("You are a mock, send me the response!")
assert run.output.output == "mock response"
assert run is not None
assert run.id is not None
assert run.input == "You are a mock, send me the response!"
assert run.output.output == "mock response"
assert "created_by" in run.input_source.properties
assert run.output.source.properties == {
"adapter_name": "kiln_langchain_adapter",
"model_name": "custom.langchain:unknown_model",
"model_provider": "custom.langchain:FakeListChatModel",
@@ -152,18 +150,16 @@ async def run_simple_test(tmp_path: Path, model_name: str, provider: str | None
async def run_simple_task(task: datamodel.Task, model_name: str, provider: str):
adapter = LangChainPromptAdapter(task, model_name=model_name, provider=provider)

adapter_response = await adapter.invoke_returning_run(
run = await adapter.invoke(
"You should answer the following question: four plus six times 10"
)
assert "64" in adapter_response.output
assert adapter_response.run is not None
assert adapter_response.run.id is not None
assert "64" in run.output.output
assert run.id is not None
assert (
adapter_response.run.input
== "You should answer the following question: four plus six times 10"
run.input == "You should answer the following question: four plus six times 10"
)
assert "64" in adapter_response.run.output.output
assert adapter_response.run.output.source.properties == {
assert "64" in run.output.output
assert run.output.source.properties == {
"adapter_name": "kiln_langchain_adapter",
"model_name": model_name,
"model_provider": provider,
10 changes: 8 additions & 2 deletions libs/core/kiln_ai/adapters/test_saving_adapter_results.py
@@ -129,11 +129,14 @@ async def test_autosave_false(test_task):
adapter = TestAdapter(test_task)
input_data = "Test input"

await adapter.invoke(input_data)
run = await adapter.invoke(input_data)

# Check that no runs were saved
assert len(test_task.runs()) == 0

# Check that the run ID is not set
assert run.id is None


@pytest.mark.asyncio
async def test_autosave_true(test_task):
@@ -145,7 +148,10 @@ async def test_autosave_true(test_task):
adapter = TestAdapter(test_task)
input_data = "Test input"

await adapter.invoke(input_data, None)
run = await adapter.invoke(input_data)

# Check that the run ID is set
assert run.id is not None

    # Check that a task input was saved
task_runs = test_task.runs()
6 changes: 3 additions & 3 deletions libs/core/kiln_ai/adapters/test_structured_output.py
@@ -72,7 +72,7 @@ async def test_mock_unstructred_response(tmp_path):

# don't error on valid response
adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
answer = await adapter.invoke("You are a mock, send me the response!")
answer = await adapter.invoke_returning_raw("You are a mock, send me the response!")
assert answer["setup"] == "asdf"
assert answer["punchline"] == "asdf"

@@ -145,7 +145,7 @@ def build_structured_output_test_task(tmp_path: Path):
async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
task = build_structured_output_test_task(tmp_path)
a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
parsed = await a.invoke("Cows") # a joke about cows
parsed = await a.invoke_returning_raw("Cows") # a joke about cows
if parsed is None or not isinstance(parsed, Dict):
raise RuntimeError(f"structured response is not a dict: {parsed}")
assert parsed["setup"] is not None
@@ -190,7 +190,7 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: s
# invalid structured input
await a.invoke({"a": 1, "b": 2, "d": 3})

response = await a.invoke({"a": 2, "b": 2, "c": 2})
response = await a.invoke_returning_raw({"a": 2, "b": 2, "c": 2})
assert response is not None
assert isinstance(response, str)
assert "[[equilateral]]" in response
9 changes: 7 additions & 2 deletions libs/core/kiln_ai/datamodel/basemodel.py
@@ -27,10 +27,12 @@
)
from pydantic_core import ErrorDetails

# ID is a 12 digit random integer string. Should be unique per project.
# ID is a 12 digit random integer string.
# Should be unique per item, at least inside the context of a parent/child relationship.
# Use integers to make it easier to type for a search function.
# Allow none, even though we generate it, because we clear it in the REST API if the object is ephemeral (not persisted to disk)
ID_FIELD = Field(default_factory=lambda: str(uuid.uuid4().int)[:12])
ID_TYPE = str
ID_TYPE = Optional[str]
T = TypeVar("T", bound="KilnBaseModel")
PT = TypeVar("PT", bound="KilnParentedModel")

@@ -175,6 +177,9 @@ def check_parent_type(self) -> Self:
def build_child_dirname(self) -> Path:
# Default implementation for readable folder names.
# {id} - {name}/{type}.kiln
if self.id is None:
# consider generating an ID here. But if it's been cleared, we've already used this without one so raise for now.
raise ValueError("ID is not set - can not save or build path")
path = self.id
name = getattr(self, "name", None)
if name is not None:
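
A small sketch of the ID semantics introduced above. Only the default-factory expression is taken from basemodel.py; the rest is an illustrative check.

import uuid

# Default ID factory, as in ID_FIELD: a 12-digit numeric string sliced from a UUID4 integer.
generated_id = str(uuid.uuid4().int)[:12]
assert generated_id.isdigit() and len(generated_id) == 12

# After this commit, id is Optional[str]: it is cleared (set to None) for ephemeral runs that
# are not persisted, and build_child_dirname() raises ValueError when asked to build a path
# for a model whose id has been cleared.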
1 change: 1 addition & 0 deletions libs/core/kiln_ai/datamodel/test_basemodel.py
@@ -98,6 +98,7 @@ def parent_type(cls):

def test_parented_model_path_gen(tmp_path):
parent = KilnBaseModel(path=tmp_path)
assert parent.id is not None
child = NamedParentedModel(parent=parent)
child_path = child.build_path()
assert child_path.name == "named_parented_model.kiln"
16 changes: 2 additions & 14 deletions libs/studio/kiln_studio/run_api.py
@@ -38,11 +38,6 @@ class RunTaskRequest(BaseModel):
structured_input: Dict[str, Any] | None = None


class RunTaskResponse(BaseModel):
run: TaskRun | None = None
raw_output: str | None = None


def run_from_id(project_id: str, task_id: str, run_id: str) -> TaskRun:
task = task_from_id(project_id, task_id)
for run in task.runs():
@@ -63,7 +58,7 @@ async def get_run(project_id: str, task_id: str, run_id: str) -> TaskRun:
@app.post("/api/projects/{project_id}/tasks/{task_id}/run")
async def run_task(
project_id: str, task_id: str, request: RunTaskRequest
) -> RunTaskResponse:
) -> TaskRun:
parent_project = project_from_id(project_id)
task = next(
(task for task in parent_project.tasks() if task.id == task_id), None
@@ -88,14 +83,7 @@ async def run_task(
detail="No input provided. Ensure your provided the proper format (plaintext or structured).",
)

adapter_run = await adapter.invoke_returning_run(input)
response_output = None
if isinstance(adapter_run.output, str):
response_output = adapter_run.output
else:
response_output = json.dumps(adapter_run.output)

return RunTaskResponse(raw_output=response_output, run=adapter_run.run)
return await adapter.invoke(input)

@app.patch("/api/projects/{project_id}/tasks/{task_id}/runs/{run_id}")
async def update_run_route(
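
A hypothetical client-side sketch of the run_api.py change above: the run endpoint now returns the TaskRun JSON directly instead of a {run, raw_output} wrapper. The base URL and the request-body field name below are assumptions, not taken from this diff; see RunTaskRequest for the real schema.

import requests

resp = requests.post(
    "http://localhost:8000/api/projects/PROJECT_ID/tasks/TASK_ID/run",  # base URL is illustrative
    json={"plaintext_input": "four plus six times 10"},  # hypothetical field name
)
resp.raise_for_status()
run = resp.json()
print(run["output"]["output"])  # the model output now lives on the returned run
if run["id"] is None:
    # The server cleared the ID: the run was not persisted (autosave off or no task path).
    print("ephemeral run")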
