Skip to content

Commit

Permalink
separate run and task into their own files
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Oct 17, 2024
1 parent 36460c8 commit 2e1f373
Show file tree
Hide file tree
Showing 5 changed files with 619 additions and 584 deletions.
118 changes: 118 additions & 0 deletions libs/studio/kiln_studio/run_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import json
from asyncio import Lock
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.datamodel import TaskRun
from pydantic import BaseModel

from libs.studio.kiln_studio.project_api import project_from_id

# Lock to prevent overwriting via concurrent updates. We use a load/update/write pattern that is not atomic.
update_run_lock = Lock()


def deep_update(
source: Dict[str, Any] | None, update: Dict[str, Any | None]
) -> Dict[str, Any]:
if source is None:
return {k: v for k, v in update.items() if v is not None}
for key, value in update.items():
if value is None:
source.pop(key, None)
elif isinstance(value, dict):
if key not in source or not isinstance(source[key], dict):
source[key] = {}
source[key] = deep_update(source[key], value)
else:
source[key] = value
return {k: v for k, v in source.items() if v is not None}


class RunTaskRequest(BaseModel):
model_name: str
provider: str
plaintext_input: str | None = None
structured_input: Dict[str, Any] | None = None


class RunTaskResponse(BaseModel):
run: TaskRun | None = None
raw_output: str | None = None


def connect_run_api(app: FastAPI):
@app.post("/api/projects/{project_id}/task/{task_id}/run")
async def run_task(
project_id: str, task_id: str, request: RunTaskRequest
) -> RunTaskResponse:
parent_project = project_from_id(project_id)
task = next(
(task for task in parent_project.tasks() if task.id == task_id), None
)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Task not found. ID: {task_id}",
)

adapter = LangChainPromptAdapter(
task, model_name=request.model_name, provider=request.provider
)

input = request.plaintext_input
if task.input_schema() is not None:
input = request.structured_input

if input is None:
raise HTTPException(
status_code=400,
detail="No input provided. Ensure your provided the proper format (plaintext or structured).",
)

adapter_run = await adapter.invoke_returning_run(input)
response_output = None
if isinstance(adapter_run.output, str):
response_output = adapter_run.output
else:
response_output = json.dumps(adapter_run.output)

return RunTaskResponse(raw_output=response_output, run=adapter_run.run)

@app.patch("/api/projects/{project_id}/task/{task_id}/run/{run_id}")
async def update_run_route(
project_id: str, task_id: str, run_id: str, run_data: Dict[str, Any]
) -> TaskRun:
return await update_run(project_id, task_id, run_id, run_data)


async def update_run(
project_id: str, task_id: str, run_id: str, run_data: Dict[str, Any]
) -> TaskRun:
# Lock to prevent overwriting concurrent updates
async with update_run_lock:
parent_project = project_from_id(project_id)
task = next(
(task for task in parent_project.tasks() if task.id == task_id), None
)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Task not found. ID: {task_id}",
)

run = next((run for run in task.runs() if run.id == run_id), None)
if run is None:
raise HTTPException(
status_code=404,
detail=f"Run not found. ID: {run_id}",
)

# Update and save
old_run_dumped = run.model_dump()
merged = deep_update(old_run_dumped, run_data)
updated_run = TaskRun.model_validate(merged)
updated_run.path = run.path
updated_run.save_to_file()
return updated_run
2 changes: 2 additions & 0 deletions libs/studio/kiln_studio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .custom_errors import connect_custom_errors
from .project_api import connect_project_api
from .provider_api import connect_provider_api
from .run_api import connect_run_api
from .settings import connect_settings
from .task_api import connect_task_api
from .webhost import connect_webhost
Expand All @@ -31,6 +32,7 @@ def ping():
connect_project_api(app)
connect_provider_api(app)
connect_task_api(app)
connect_run_api(app)
connect_settings(app)
connect_custom_errors(app)

Expand Down
112 changes: 1 addition & 111 deletions libs/studio/kiln_studio/task_api.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,10 @@
import json
from asyncio import Lock
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.datamodel import Task, TaskRun
from pydantic import BaseModel
from kiln_ai.datamodel import Task

from libs.studio.kiln_studio.project_api import project_from_id

# Add this at the module level
update_run_lock = Lock()


class RunTaskRequest(BaseModel):
model_name: str
provider: str
plaintext_input: str | None = None
structured_input: Dict[str, Any] | None = None


class RunTaskResponse(BaseModel):
run: TaskRun | None = None
raw_output: str | None = None


def deep_update(
source: Dict[str, Any] | None, update: Dict[str, Any | None]
) -> Dict[str, Any]:
if source is None:
return {k: v for k, v in update.items() if v is not None}
for key, value in update.items():
if value is None:
source.pop(key, None)
elif isinstance(value, dict):
if key not in source or not isinstance(source[key], dict):
source[key] = {}
source[key] = deep_update(source[key], value)
else:
source[key] = value
return {k: v for k, v in source.items() if v is not None}


def connect_task_api(app: FastAPI):
@app.post("/api/projects/{project_id}/task")
Expand Down Expand Up @@ -75,77 +39,3 @@ async def get_task(project_id: str, task_id: str):
status_code=404,
detail=f"Task not found. ID: {task_id}",
)

@app.post("/api/projects/{project_id}/task/{task_id}/run")
async def run_task(
project_id: str, task_id: str, request: RunTaskRequest
) -> RunTaskResponse:
parent_project = project_from_id(project_id)
task = next(
(task for task in parent_project.tasks() if task.id == task_id), None
)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Task not found. ID: {task_id}",
)

adapter = LangChainPromptAdapter(
task, model_name=request.model_name, provider=request.provider
)

input = request.plaintext_input
if task.input_schema() is not None:
input = request.structured_input

if input is None:
raise HTTPException(
status_code=400,
detail="No input provided. Ensure your provided the proper format (plaintext or structured).",
)

adapter_run = await adapter.invoke_returning_run(input)
response_output = None
if isinstance(adapter_run.output, str):
response_output = adapter_run.output
else:
response_output = json.dumps(adapter_run.output)

return RunTaskResponse(raw_output=response_output, run=adapter_run.run)

@app.patch("/api/projects/{project_id}/task/{task_id}/run/{run_id}")
async def update_run_route(
project_id: str, task_id: str, run_id: str, run_data: Dict[str, Any]
) -> TaskRun:
return await update_run(project_id, task_id, run_id, run_data)


async def update_run(
project_id: str, task_id: str, run_id: str, run_data: Dict[str, Any]
) -> TaskRun:
# Lock to prevent overwriting concurrent updates
async with update_run_lock:
parent_project = project_from_id(project_id)
task = next(
(task for task in parent_project.tasks() if task.id == task_id), None
)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Task not found. ID: {task_id}",
)

run = next((run for run in task.runs() if run.id == run_id), None)
if run is None:
raise HTTPException(
status_code=404,
detail=f"Run not found. ID: {run_id}",
)

# Update and save
old_run_dumped = run.model_dump()
merged = deep_update(old_run_dumped, run_data)
updated_run = TaskRun.model_validate(merged)
updated_run.path = run.path
updated_run.save_to_file()
return updated_run
Loading

0 comments on commit 2e1f373

Please sign in to comment.