
Merge pull request #221 from Kiln-AI/better_prompts
Much better prompt system for evals
scosman authored Feb 23, 2025
2 parents 461a74c + a46b942 commit 0759fb2
Showing 24 changed files with 366 additions and 235 deletions.
43 changes: 32 additions & 11 deletions app/desktop/studio_server/eval_api.py
@@ -22,6 +22,7 @@
EvalOutputScore,
EvalTemplate,
)
from kiln_ai.datamodel.prompt_id import is_frozen_prompt
from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.task_api import task_from_id
@@ -168,6 +169,31 @@ async def create_task_run_config(
) -> TaskRunConfig:
task = task_from_id(project_id, task_id)
name = request.name or generate_memorable_name()

parent_project = task.parent_project()
if parent_project is None:
raise HTTPException(
status_code=400,
detail="Task must have a parent project.",
)

frozen_prompt: BasePrompt | None = None
if not is_frozen_prompt(request.prompt_id):
# For dynamic prompts, we "freeze" a copy of this prompt into the task run config so we don't accidentally invalidate evals if the user changes something that impacts the prompt (example: changing data for multi-shot, or changing the task for basic-prompt)
# We then point the task_run_config.run_config_properties.prompt_id to this new frozen prompt
prompt_builder = prompt_builder_from_id(request.prompt_id, task)
prompt_name = generate_memorable_name()
frozen_prompt = BasePrompt(
name=prompt_name,
long_name=prompt_name
+ " (frozen prompt from '"
+ request.prompt_id
+ "')",
generator_id=request.prompt_id,
prompt=prompt_builder.build_base_prompt(),
chain_of_thought_instructions=prompt_builder.chain_of_thought_prompt(),
)

task_run_config = TaskRunConfig(
parent=task,
name=name,
@@ -177,7 +203,13 @@
model_provider_name=request.model_provider_name,
prompt_id=request.prompt_id,
),
prompt=frozen_prompt,
)
if frozen_prompt is not None:
# Set after, because the ID isn't known until the TaskRunConfig is created
task_run_config.run_config_properties.prompt_id = (
f"task_run_config::{parent_project.id}::{task.id}::{task_run_config.id}"
)
task_run_config.save_to_file()
return task_run_config

@@ -190,19 +222,9 @@ async def create_eval_config(
eval_id: str,
request: CreateEvalConfigRequest,
) -> EvalConfig:
task = task_from_id(project_id, task_id)
eval = eval_from_id(project_id, task_id, eval_id)
name = request.name or generate_memorable_name()

# Create a prompt instance to save to the eval config
prompt_builder = prompt_builder_from_id(request.prompt_id, task)
prompt = BasePrompt(
name=request.prompt_id,
generator_id=request.prompt_id,
prompt=prompt_builder.build_base_prompt(),
chain_of_thought_instructions=prompt_builder.chain_of_thought_prompt(),
)

eval_config = EvalConfig(
name=name,
config_type=request.type,
Expand All @@ -215,7 +237,6 @@ async def create_eval_config(
"adapter_name": "kiln_eval",
},
),
prompt=prompt,
parent=eval,
)
eval_config.save_to_file()
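The handler above rewrites the run config's prompt_id to point at the frozen copy using the format task_run_config::{project_id}::{task_id}::{task_run_config_id}, and the tests below show that IDs which are already frozen (for example "id::prompt_123") are stored unchanged. The resolver side — is_frozen_prompt and the lookup that turns such an ID back into a prompt — is not part of this diff, so the sketch below is only an illustration of how that ID scheme could be built and unpacked; FrozenPromptRef and both helper functions are hypothetical names, not Kiln API.

from dataclasses import dataclass

FROZEN_RUN_CONFIG_PREFIX = "task_run_config::"  # assumed prefix, mirroring the f-string in create_task_run_config

@dataclass
class FrozenPromptRef:  # hypothetical helper, not part of the Kiln codebase
    project_id: str
    task_id: str
    task_run_config_id: str

def build_frozen_prompt_id(project_id: str, task_id: str, run_config_id: str) -> str:
    # Same shape as the ID written into run_config_properties.prompt_id above.
    return f"task_run_config::{project_id}::{task_id}::{run_config_id}"

def parse_frozen_prompt_id(prompt_id: str) -> FrozenPromptRef | None:
    # Dynamic generator IDs (e.g. "simple_chain_of_thought_prompt_builder") fall through as None.
    if not prompt_id.startswith(FROZEN_RUN_CONFIG_PREFIX):
        return None
    parts = prompt_id.split("::")
    if len(parts) != 4:
        raise ValueError(f"Malformed task_run_config prompt id: {prompt_id}")
    _, project_id, task_id, run_config_id = parts
    return FrozenPromptRef(project_id, task_id, run_config_id)

# Round-trip check with the IDs used in the tests below.
ref = parse_frozen_prompt_id(build_frozen_prompt_id("project1", "task1", "run_config_abc"))
assert ref == FrozenPromptRef("project1", "task1", "run_config_abc")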
100 changes: 75 additions & 25 deletions app/desktop/studio_server/test_eval_api.py
@@ -10,6 +10,7 @@
BasePrompt,
DataSource,
DataSourceType,
Project,
PromptId,
Task,
)
@@ -47,12 +48,19 @@ def client(app):

@pytest.fixture
def mock_task(tmp_path):
project = Project(
id="project1",
name="Test Project",
path=tmp_path / "project.kiln",
)
project.save_to_file()
task = Task(
id="task1",
name="Test Task",
description="Test Description",
instruction="Test Instructions",
path=tmp_path / "task.kiln",
parent=project,
)
task.save_to_file()
return task
@@ -207,19 +215,28 @@ async def test_create_evaluator(


@pytest.mark.asyncio
async def test_create_task_run_config(client, mock_task_from_id, mock_task):
async def test_create_task_run_config_with_freezing(
client, mock_task_from_id, mock_task
):
mock_task_from_id.return_value = mock_task

response = client.post(
"/api/projects/project1/tasks/task1/task_run_config",
json={
"name": "Test Task Run Config",
"description": "Test Description",
"model_name": "gpt-4o",
"model_provider_name": "openai",
"prompt_id": "simple_chain_of_thought_prompt_builder",
},
)
with (
patch(
"app.desktop.studio_server.eval_api.generate_memorable_name"
) as mock_generate_memorable_name,
):
mock_generate_memorable_name.return_value = "Custom Name"

response = client.post(
"/api/projects/project1/tasks/task1/task_run_config",
json={
"name": "Test Task Run Config",
"description": "Test Description",
"model_name": "gpt-4o",
"model_provider_name": "openai",
"prompt_id": "simple_chain_of_thought_prompt_builder",
},
)

assert response.status_code == 200
result = response.json()
@@ -229,16 +246,61 @@ async def test_create_task_run_config(client, mock_task_from_id, mock_task):
assert result["run_config_properties"]["model_provider_name"] == "openai"
assert (
result["run_config_properties"]["prompt_id"]
== "simple_chain_of_thought_prompt_builder"
== "task_run_config::project1::task1::" + result["id"]
)
assert result["prompt"]["name"] == "Custom Name"
assert (
result["prompt"]["long_name"]
== "Custom Name (frozen prompt from 'simple_chain_of_thought_prompt_builder')"
)

# Fetch it from API
fetch_response = client.get("/api/projects/project1/tasks/task1/task_run_configs")
assert fetch_response.status_code == 200
configs = fetch_response.json()
assert len(configs) == 1
assert configs[0]["id"] == result["id"]
assert configs[0]["name"] == result["name"]
assert configs[0]["prompt"]["name"] == "Custom Name"
assert configs[0]["prompt"]["long_name"] == (
"Custom Name (frozen prompt from 'simple_chain_of_thought_prompt_builder')"
)
assert configs[0]["run_config_properties"]["prompt_id"] == (
"task_run_config::project1::task1::" + result["id"]
)


@pytest.mark.asyncio
async def test_create_task_run_config_without_freezing(
client, mock_task_from_id, mock_task
):
mock_task_from_id.return_value = mock_task

with (
patch(
"app.desktop.studio_server.eval_api.generate_memorable_name"
) as mock_generate_memorable_name,
):
mock_generate_memorable_name.return_value = "Custom Name"

response = client.post(
"/api/projects/project1/tasks/task1/task_run_config",
json={
"name": "Test Task Run Config",
"description": "Test Description",
"model_name": "gpt-4o",
"model_provider_name": "openai",
"prompt_id": "id::prompt_123",
},
)

assert response.status_code == 200
result = response.json()
assert result["name"] == "Test Task Run Config"
assert result["description"] == "Test Description"
assert result["run_config_properties"]["model_name"] == "gpt-4o"
assert result["run_config_properties"]["model_provider_name"] == "openai"
assert result["run_config_properties"]["prompt_id"] == "id::prompt_123"
assert result["prompt"] is None


@pytest.mark.asyncio
@@ -249,15 +311,8 @@ async def test_create_eval_config(

with (
patch("app.desktop.studio_server.eval_api.eval_from_id") as mock_eval_from_id,
patch(
"app.desktop.studio_server.eval_api.prompt_builder_from_id"
) as mock_prompt_builder,
):
mock_eval_from_id.return_value = mock_eval
mock_prompt_builder.return_value.build_base_prompt.return_value = "base prompt"
mock_prompt_builder.return_value.chain_of_thought_prompt.return_value = (
"cot prompt"
)

response = client.post(
"/api/projects/project1/tasks/task1/eval/eval1/create_eval_config",
@@ -278,8 +333,6 @@
result["model"]["properties"]["model_provider"]
== valid_eval_config_request.provider
)
assert isinstance(result["prompt"], dict)
# mock_save.assert_called_once()

# Fetch disk
assert len(mock_eval.configs()) == 1
@@ -291,8 +344,6 @@
assert (
config.model.properties["model_provider"] == valid_eval_config_request.provider
)
assert config.prompt.prompt == "base prompt"
assert config.prompt.chain_of_thought_instructions == "cot prompt"
assert config.properties["eval_steps"][0] == "step1"
assert config.properties["eval_steps"][1] == "step2"

@@ -317,7 +368,6 @@ def test_get_eval_configs(
assert config["config_type"] == mock_eval_config.config_type
assert config["properties"] == mock_eval_config.properties
assert config["model"]["type"] == mock_eval_config.model.type
assert isinstance(config["prompt"], dict)

mock_eval_from_id.assert_called_once_with("project1", "task1", "eval1")

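For reference, the route exercised by these tests can be called directly once a Kiln server is running; the path and JSON fields below mirror the test payloads, while the host, port, and the use of httpx are assumptions on my part rather than anything specified in this diff.

import httpx  # any HTTP client would do; httpx is an assumption

# Hypothetical local server address; substitute whatever the desktop app actually exposes.
BASE_URL = "http://localhost:8757"

response = httpx.post(
    f"{BASE_URL}/api/projects/project1/tasks/task1/task_run_config",
    json={
        "name": "Test Task Run Config",
        "description": "Test Description",
        "model_name": "gpt-4o",
        "model_provider_name": "openai",
        "prompt_id": "simple_chain_of_thought_prompt_builder",
    },
)
response.raise_for_status()
config = response.json()
# For a dynamic prompt_id, the server freezes a copy and rewrites the run config's prompt_id:
print(config["run_config_properties"]["prompt_id"])  # task_run_config::project1::task1::<new config id>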
50 changes: 47 additions & 3 deletions app/web_ui/src/lib/api_schema.d.ts
@@ -814,6 +840,40 @@ export interface paths {
export type webhooks = Record<string, never>;
export interface components {
schemas: {
/** ApiPrompt */
ApiPrompt: {
/**
* Name
* @description A name for this entity.
*/
name: string;
/**
* Long Name
* @description A more detailed name for the prompt, usually incorporating the source of the prompt.
*/
long_name?: string | null;
/**
* Generator Id
* @description The id of the generator that created this prompt.
*/
generator_id?: string | null;
/**
* Prompt
* @description The prompt for the task.
*/
prompt: string;
/**
* Chain Of Thought Instructions
* @description Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting. COT will not be used unless this is provided.
*/
chain_of_thought_instructions?: string | null;
/** Id */
id: string;
/** Created At */
created_at?: string | null;
/** Created By */
created_by?: string | null;
};
/** AvailableModels */
AvailableModels: {
/** Provider Name */
@@ -835,6 +869,11 @@ export interface components {
* @description A name for this entity.
*/
name: string;
/**
* Long Name
* @description A more detailed name for the prompt, usually incorporating the source of the prompt.
*/
long_name?: string | null;
/**
* Generator Id
* @description The id of the generator that created this prompt.
@@ -1256,8 +1295,6 @@ export interface components {
* @default {}
*/
properties: Record<string, never>;
/** @description The prompt to use for this eval config. Both when running the task to generate outputs to evaluate and when explaining to the eval model what the goal of the task was. This is a frozen prompt, so this eval config is consistent over time (for example, if the user selects multi-shot prompting, this saves that dynamic prompt at the point the eval config is created). Freezing the prompt ensures consistent evals. */
prompt: components["schemas"]["BasePrompt"];
/** Model Type */
readonly model_type: string;
};
@@ -1658,6 +1695,11 @@ export interface components {
* @description A name for this entity.
*/
name: string;
/**
* Long Name
* @description A more detailed name for the prompt, usually incorporating the source of the prompt.
*/
long_name?: string | null;
/**
* Generator Id
* @description The id of the generator that created this prompt.
@@ -1726,7 +1768,7 @@ export interface components {
/** Generators */
generators: components["schemas"]["PromptGenerator"][];
/** Prompts */
prompts: components["schemas"]["Prompt"][];
prompts: components["schemas"]["ApiPrompt"][];
};
/** ProviderModel */
ProviderModel: {
@@ -2255,6 +2297,8 @@ export interface components {
description?: string | null;
/** @description The run config properties to use for this task run. */
run_config_properties: components["schemas"]["RunConfigProperties"];
/** @description A prompt to use for run config. */
prompt?: components["schemas"]["BasePrompt"] | null;
/** Model Type */
readonly model_type: string;
};
2 changes: 1 addition & 1 deletion app/web_ui/src/lib/stores.ts
@@ -238,7 +238,7 @@ export function prompt_name_from_id(prompt_id: string): string {
}
if (!prompt_name) {
prompt_name = get(current_task_prompts)?.prompts.find(
(prompt) => "id::" + prompt.id === prompt_id,
(prompt) => prompt.id === prompt_id,
)?.name
}
if (!prompt_name) {
@@ -237,7 +237,6 @@
parts.push(
model_name(eval_config.model.properties["model_name"], model_info),
)
parts.push(prompt_name_from_id(eval_config.prompt.name))
return eval_config.name + "" + parts.join(", ")
}
@@ -317,11 +316,6 @@
eval_config.model.properties["model_provider"] + "",
),
})
// TODO remove this once we consolidate prompts
properties.push({
name: "Prompt",
value: prompt_name_from_id(eval_config.prompt.name + ""),
})
return properties
}
@@ -658,9 +652,12 @@
)}
</div>
<div class="text-sm text-gray-500">
{prompt_name_from_id(
task_run_config?.run_config_properties?.prompt_id,
)}
Prompt:
{task_run_config.prompt?.long_name ||
task_run_config.prompt?.name ||
prompt_name_from_id(
task_run_config?.run_config_properties?.prompt_id,
)}
</div>
{#if percent_complete}
<div