
Commit

Add eval config comparison summary API
scosman committed Feb 25, 2025
1 parent f0d4144 commit 7e51c3e
Showing 2 changed files with 510 additions and 5 deletions.
224 changes: 222 additions & 2 deletions app/desktop/studio_server/eval_api.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Dict, List, Set
+from typing import Any, Dict, List, Set, Tuple

from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
@@ -12,6 +12,7 @@
DataSourceType,
PromptId,
Task,
TaskRun,
)
from kiln_ai.datamodel.basemodel import ID_TYPE
from kiln_ai.datamodel.dataset_filters import DatasetFilterId, dataset_filter_from_id
Expand All @@ -23,6 +24,7 @@
EvalRun,
EvalTemplate,
)
from kiln_ai.datamodel.json_schema import string_to_json_key
from kiln_ai.datamodel.prompt_id import is_frozen_prompt
from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig
from kiln_ai.utils.name_generator import generate_memorable_name
@@ -119,12 +121,84 @@ class EvalResultSummary(BaseModel):
dataset_size: int


class EvalConfigScoreSummary(BaseModel):
mean_absolute_error: float
mean_squared_error: float


class EvalConfigCompareSummary(BaseModel):
# Summary of results: eval_config_id -> output_score_id -> EvalConfigScoreSummary
results: Dict[str, Dict[str, EvalConfigScoreSummary]]
# eval_config_id -> percent of the dataset that has been processed (run with eval scores)
eval_config_percent_complete: Dict[str, float]
# The total size of the dataset used for the eval config comparisons (eval.eval_configs_filter_id set size)
dataset_size: int
# Counts of dataset items that are fully rated, partially rated, or not rated at all.
fully_rated_count: int
partially_rated_count: int
not_rated_count: int
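
# Illustrative response shape for this model (the values below are
# hypothetical, not taken from the commit):
# {
#   "results": {
#     "cfg_abc": {
#       "overall_rating": {"mean_absolute_error": 0.6, "mean_squared_error": 0.5}
#     }
#   },
#   "eval_config_percent_complete": {"cfg_abc": 0.8},
#   "dataset_size": 10,
#   "fully_rated_count": 6,
#   "partially_rated_count": 2,
#   "not_rated_count": 2
# }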


def dataset_ids_in_filter(task: Task, filter_id: DatasetFilterId) -> Set[ID_TYPE]:
# Fetch the IDs of all dataset items matching a filter
filter = dataset_filter_from_id(filter_id)
return {run.id for run in task.runs() if filter(run)}


def human_score_from_task_run(
task_run: TaskRun,
score_key: str,
score_key_to_task_requirement_id: Dict[str, ID_TYPE],
) -> float | None:
if not task_run.output.rating:
return None

human_score: float | None = None
if score_key == "overall_rating":
human_score = task_run.output.rating.value
else:
req_rating = task_run.output.rating.requirement_ratings.get(
score_key_to_task_requirement_id[score_key], None
)
if req_rating is not None:
human_score = req_rating.value

return human_score
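
# Example (hypothetical keys): score_key "overall_rating" maps to
# task_run.output.rating.value, while a requirement-backed key such as
# "accuracy" is resolved through score_key_to_task_requirement_id and read
# from requirement_ratings; None means the item lacks that human rating.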


def count_human_evals(
items: Set[TaskRun],
eval: Eval,
score_key_to_task_requirement_id: Dict[str, ID_TYPE],
) -> Tuple[int, int, int]:
# Count dataset items by how completely they have been human-rated
fully_rated_count: int = 0
partially_rated_count: int = 0
not_rated_count: int = 0
for dataset_item in items:
# Check it has all scores
has_all_scores = True
has_any_scores = False
for output_score in eval.output_scores:
score_key = output_score.json_key()
score = human_score_from_task_run(
dataset_item, score_key, score_key_to_task_requirement_id
)
if score is None:
has_all_scores = False
else:
has_any_scores = True

if not has_any_scores:
not_rated_count += 1
elif has_all_scores:
fully_rated_count += 1
else:
partially_rated_count += 1

return fully_rated_count, partially_rated_count, not_rated_count
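
# Example (hypothetical): if eval.output_scores defines "overall_rating" and
# "accuracy", an item rated on both is fully rated, an item rated only on
# "overall_rating" is partially rated, and an unrated item is not rated.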


def connect_evals_api(app: FastAPI):
@app.post("/api/projects/{project_id}/tasks/{task_id}/create_evaluator")
async def create_evaluator(
@@ -168,6 +242,15 @@ async def get_eval_configs(
eval = eval_from_id(project_id, task_id, eval_id)
return eval.configs()

@app.get(
"/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}/eval_config/{eval_config_id}"
)
async def get_eval_config(
project_id: str, task_id: str, eval_id: str, eval_config_id: str
) -> EvalConfig:
eval_config = eval_config_from_id(project_id, task_id, eval_id, eval_config_id)
return eval_config

@app.post("/api/projects/{project_id}/tasks/{task_id}/task_run_config")
async def create_task_run_config(
project_id: str,
@@ -368,7 +451,7 @@ async def get_eval_config_score_summary(

# Check if we should count this eval_run. Not every eval_run has to go into the stats:
# - a dataset_id can be removed from the dataset filter (removed a tag)
-# - this dataset_id was already counted (okay there are dupes, but shouldn't be double counted)
+# - this dataset_id was already counted (not great there are dupes, but really shouldn't be double counted)
if eval_run.dataset_id not in remaining_expected_dataset_ids[run_config_id]:
continue
else:
@@ -421,3 +504,140 @@ async def get_eval_config_score_summary(
run_config_percent_complete=run_config_percent_complete,
dataset_size=len(expected_dataset_ids),
)

# Unlike the endpoint above, this compares all eval configs to each other instead of examining a single eval config
@app.get(
"/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}/eval_configs_score_summary"
)
async def get_eval_configs_score_summary(
project_id: str,
task_id: str,
eval_id: str,
) -> EvalConfigCompareSummary:
task = task_from_id(project_id, task_id)
eval = eval_from_id(project_id, task_id, eval_id)
eval_configs = eval.configs(readonly=True)

# Create a map of score_key -> Task requirement ID
score_key_to_task_requirement_id: Dict[str, ID_TYPE] = {}
for task_requirement in task.requirements:
score_key = string_to_json_key(task_requirement.name)
score_key_to_task_requirement_id[score_key] = task_requirement.id

# Build the set of dataset item IDs we expect to have scores for:
# fetch all the dataset items matching the filter into a map of dataset_id -> TaskRun
filter = dataset_filter_from_id(eval.eval_configs_filter_id)
expected_dataset_items = {run.id: run for run in task.runs() if filter(run)}
expected_dataset_ids = set(expected_dataset_items.keys())
if len(expected_dataset_ids) == 0:
return EvalConfigCompareSummary(
results={},
eval_config_percent_complete={},
dataset_size=0,
fully_rated_count=0,
partially_rated_count=0,
not_rated_count=0,
)

# Save a copy of the expected dataset IDs for each eval config; we'll remove entries from each set as we process that config's eval runs
remaining_expected_dataset_ids: Dict[str, Set[ID_TYPE]] = {
str(eval_config.id): set(expected_dataset_ids)
for eval_config in eval_configs
}

# eval_config_id -> output_score_id -> running error totals and counts
total_squared_error: Dict[str, Dict[str, float]] = {}
total_absolute_error: Dict[str, Dict[str, float]] = {}
total_count: Dict[str, Dict[str, int]] = {}

# important: readonly makes this much faster
for eval_config in eval_configs:
eval_config_id = str(eval_config.id)
for eval_run in eval_config.runs(readonly=True):
dataset_item = expected_dataset_items.get(eval_run.dataset_id, None)
if dataset_item is None:
# A dataset_id can be removed from the dataset filter (it ran previously, then the tag was removed, dropping it from the eval config set filter)
# A dataset_id could also belong to a run_config eval, not to eval config comparison at all
continue

# Check if we should count this eval_run. Not every eval_run has to go into the stats:
# Example: this dataset_id was already counted (duplicates are unfortunate, but they shouldn't be double counted)
if (
eval_run.dataset_id
not in remaining_expected_dataset_ids[eval_config_id]
):
continue
else:
remaining_expected_dataset_ids[eval_config_id].remove(
eval_run.dataset_id
)

for output_score in eval.output_scores:
score_key = output_score.json_key()
eval_score: float | None = eval_run.scores.get(score_key, None)

# Fetch the human eval score from the dataset item
human_score = human_score_from_task_run(
dataset_item, score_key, score_key_to_task_requirement_id
)

if human_score is None or eval_score is None:
# This score doesn't have both a human eval and eval score, so we can't compare
continue

if eval_config_id not in total_squared_error:
total_squared_error[eval_config_id] = {}
total_absolute_error[eval_config_id] = {}
total_count[eval_config_id] = {}
if score_key not in total_squared_error[eval_config_id]:
total_squared_error[eval_config_id][score_key] = 0
total_absolute_error[eval_config_id][score_key] = 0
total_count[eval_config_id][score_key] = 0

# TODO normalize MSE?
total_squared_error[eval_config_id][score_key] += (
eval_score - human_score
) ** 2
total_absolute_error[eval_config_id][score_key] += abs(
eval_score - human_score
)
total_count[eval_config_id][score_key] += 1
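# Worked example (hypothetical values): eval_score=4.0 vs human_score=5.0
# adds (4.0 - 5.0)**2 = 1.0 to the squared-error total and 1.0 to the
# absolute-error total; the per-key means are taken below once all runs
# are tallied.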

# Convert to score summaries
results: Dict[str, Dict[str, EvalConfigScoreSummary]] = {}
for eval_config_id in total_count.keys():
results[eval_config_id] = {}
for score_key in total_count[eval_config_id].keys():
count = total_count[eval_config_id][score_key]
if count > 0:
results[eval_config_id][score_key] = EvalConfigScoreSummary(
mean_squared_error=(
total_squared_error[eval_config_id][score_key] / count
),
mean_absolute_error=(
total_absolute_error[eval_config_id][score_key] / count
),
)

# Calculate the percent of the dataset that has been processed
eval_config_percent_complete: Dict[str, float] = {}
for eval_config in eval_configs:
eval_config_id = str(eval_config.id)
# Includes both partially incomplete items (missing scores) and fully incomplete items (no eval_run)
incomplete_count = len(remaining_expected_dataset_ids[eval_config_id])
percent_incomplete = incomplete_count / len(expected_dataset_ids)
eval_config_percent_complete[str(eval_config.id)] = 1 - percent_incomplete
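# Worked example (hypothetical): 10 expected items with 2 still
# unprocessed -> percent_incomplete = 0.2 -> 0.8 complete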

# Count how many dataset items have human evals
fully_rated_count, partially_rated_count, not_rated_count = count_human_evals(
expected_dataset_items.values(), eval, score_key_to_task_requirement_id
)

return EvalConfigCompareSummary(
results=results,
eval_config_percent_complete=eval_config_percent_complete,
dataset_size=len(expected_dataset_ids),
fully_rated_count=fully_rated_count,
partially_rated_count=partially_rated_count,
not_rated_count=not_rated_count,
)
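
For reference, a minimal client-side sketch of calling the new comparison endpoint. The base URL and every ID below are assumptions for illustration, not values from this commit:

import requests  # third-party HTTP client, assumed installed

# Hypothetical local studio server and project/task/eval IDs
BASE = "http://localhost:8757"
url = f"{BASE}/api/projects/proj_1/tasks/task_1/eval/eval_1/eval_configs_score_summary"

summary = requests.get(url, timeout=10).json()
print(summary["dataset_size"], summary["fully_rated_count"])
for eval_config_id, scores in summary["results"].items():
    for score_key, stats in scores.items():
        print(eval_config_id, score_key, stats["mean_absolute_error"], stats["mean_squared_error"])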