fix(wren-ai-service): Settings Loading and Multi-Pipeline Component Mechanism for Evaluation Framework (#1198)
paopa authored Jan 21, 2025
1 parent 51876dd commit cfd8367
Showing 15 changed files with 1,301 additions and 1,342 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -27,6 +27,7 @@ wren-ai-service/demo/custom_dataset
wren-ai-service/demo/.env
wren-ai-service/tools/dev/etc/**
.deepeval-cache.json
.deepeval_telemtry.txt
docker/config.yaml

# python
12 changes: 0 additions & 12 deletions wren-ai-service/eval/.env.example

This file was deleted.

3 changes: 2 additions & 1 deletion wren-ai-service/eval/.gitignore
@@ -1 +1,2 @@
.env
.env
config.yaml
21 changes: 21 additions & 0 deletions wren-ai-service/eval/__init__.py
@@ -0,0 +1,21 @@
from pydantic import Field, SecretStr

from src.config import Settings


class EvalSettings(Settings):
langfuse_project_id: str = ""
batch_size: int = 4
batch_interval: int = 1
datasource: str = "bigquery"
config_path: str = "eval/config.yaml"
openai_api_key: SecretStr = Field(alias="LLM_OPENAI_API_KEY")

@property
def langfuse_url(self) -> str:
if not self.langfuse_project_id:
return ""
return f"{self.langfuse_host.rstrip('/')}/project/{self.langfuse_project_id}"

def get_openai_api_key(self) -> str:
return self.openai_api_key.get_secret_value()
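
For orientation, a minimal usage sketch of the new EvalSettings class (not part of the commit), assuming LLM_OPENAI_API_KEY is set in the environment and that src.config.Settings supplies the langfuse_host field used by the langfuse_url property:

# Minimal sketch, not part of the commit. Assumes LLM_OPENAI_API_KEY is set and
# that the parent Settings class provides langfuse_host.
from eval import EvalSettings

settings = EvalSettings()

# The API key is a SecretStr aliased to LLM_OPENAI_API_KEY, so callers go through
# the accessor rather than reading the raw field.
api_key = settings.get_openai_api_key()

# langfuse_url stays empty unless langfuse_project_id is configured, so callers
# can skip building trace links when Langfuse is not in use.
trace_base = settings.langfuse_url or "(no Langfuse project configured)"
print(settings.datasource, settings.config_path, trace_base)
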
40 changes: 22 additions & 18 deletions wren-ai-service/eval/data_curation/app.py
@@ -1,15 +1,27 @@
import asyncio
import os
import re
import sys
import uuid
from datetime import datetime
from pathlib import Path

import orjson
import pandas as pd
import streamlit as st
import tomlkit
from openai import AsyncClient
from streamlit_tags import st_tags

sys.path.append(f"{Path().parent.resolve()}")
from eval import EvalSettings
from eval.utils import (
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
prepare_duckdb_init_sql,
prepare_duckdb_session_sql,
)
from utils import (
DATA_SOURCES,
WREN_ENGINE_ENDPOINT,
@@ -21,22 +21,14 @@
prettify_sql,
)

from eval.utils import (
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
prepare_duckdb_init_sql,
prepare_duckdb_session_sql,
)

st.set_page_config(layout="wide")
st.title("WrenAI Data Curation App")


LLM_OPTIONS = ["gpt-4o-mini", "gpt-4o"]

llm_client = get_openai_client()

settings = EvalSettings()
llm_client = get_openai_client(api_key=settings.get_openai_api_key())

# session states
if "llm_model" not in st.session_state:
@@ -63,9 +67,9 @@
def on_change_upload_eval_dataset():
doc = tomlkit.parse(st.session_state.uploaded_eval_file.getvalue().decode("utf-8"))

assert (
doc["mdl"] == st.session_state["mdl_json"]
), "The model in the uploaded dataset is different from the deployed model"
assert doc["mdl"] == st.session_state["mdl_json"], (
"The model in the uploaded dataset is different from the deployed model"
)
st.session_state["candidate_dataset"] = doc["eval_dataset"]


@@ -93,7 +97,7 @@ def on_click_setup_uploaded_file():
uploaded_file = st.session_state.get("uploaded_mdl_file")
if uploaded_file:
match = re.match(
r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$",
r".+_(" + "|".join(DATA_SOURCES) + r")(_.+)?_mdl\.json$",
uploaded_file.name,
)
if not match:
@@ -126,7 +130,7 @@ def on_click_setup_uploaded_file():


def on_change_llm_model():
st.toast(f"Switching LLM model to {st.session_state["select_llm_model"]}")
st.toast(f"Switching LLM model to {st.session_state['select_llm_model']}")
st.session_state["llm_model"] = st.session_state["select_llm_model"]


@@ -338,7 +342,7 @@ def on_click_remove_candidate_dataset_button(i: int):
)
else:
st.error(
f"SQL is invalid: {st.session_state["llm_question_sql_pairs"][i]["error"]}"
f"SQL is invalid: {st.session_state['llm_question_sql_pairs'][i]['error']}"
)

st.button(
@@ -417,7 +421,7 @@ def on_click_remove_candidate_dataset_button(i: int):
)
else:
st.error(
f"SQL is invalid: {st.session_state.get("user_question_sql_pair", {}).get('error', '')}"
f"SQL is invalid: {st.session_state.get('user_question_sql_pair', {}).get('error', '')}"
)

st.button(
@@ -475,7 +479,7 @@ def on_click_remove_candidate_dataset_button(i: int):
with st.popover("Save as Evaluation Dataset", use_container_width=True):
file_name = st.text_input(
"File Name",
f'eval_dataset_{datetime.today().strftime("%Y_%m_%d")}.toml',
f"eval_dataset_{datetime.today().strftime('%Y_%m_%d')}.toml",
key="eval_dataset_file_name",
)
download_btn = st.download_button(
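
The loosened MDL filename pattern above now tolerates an extra suffix between the data-source segment and _mdl.json. A quick illustrative check of the behaviour (the DATA_SOURCES values here are stand-ins, not the constant imported in app.py):

import re

# Illustrative only; DATA_SOURCES stands in for the constant imported in app.py.
DATA_SOURCES = ["bigquery", "duckdb"]
pattern = r".+_(" + "|".join(DATA_SOURCES) + r")(_.+)?_mdl\.json$"

print(bool(re.match(pattern, "ecommerce_bigquery_mdl.json")))         # True, also matched before this change
print(bool(re.match(pattern, "ecommerce_bigquery_spider_mdl.json")))  # True, suffix newly accepted
print(bool(re.match(pattern, "ecommerce_mdl.json")))                  # False, no data-source segment
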
50 changes: 23 additions & 27 deletions wren-ai-service/eval/evaluation.py
@@ -1,27 +1,21 @@
import argparse
import os
import sys
from pathlib import Path
from typing import Tuple

import dotenv
from deepeval import evaluate
from deepeval.evaluate import TestResult
from deepeval.test_case import LLMTestCase
from dspy_modules.prompt_optimizer import (
build_optimizing_module,
configure_llm_provider,
optimizer_parameters,
prepare_dataset,
)
from langfuse import Langfuse
from langfuse.decorators import langfuse_context, observe

sys.path.append(f"{Path().parent.resolve()}")
import traceback

import eval.pipelines as pipelines
from eval.utils import parse_toml, trace_metadata
import src.providers as provider
from eval import EvalSettings
from eval.utils import engine_config, parse_toml, trace_metadata
from src import utils


@@ -88,7 +82,7 @@ def eval(self, meta: dict, predictions: list) -> None:

try:
test_case = LLMTestCase(**formatter(prediction, meta))
result = evaluate([test_case], self._metrics, ignore_errors=True)[0]
result = evaluate([test_case], self._metrics, ignore_errors=True).test_results[0]
self._score_metrics(test_case, result)
[metric.collect(test_case, result) for metric in self._post_metrics]
except Exception:
@@ -149,31 +143,33 @@ def _average_score(self, meta: dict) -> None:
if __name__ == "__main__":
args = parse_args()

dotenv.load_dotenv()
utils.load_env_vars()
settings = EvalSettings()
pipe_components = provider.generate_components(settings.components)
utils.init_langfuse(settings)

predicted_file = parse_toml(f"outputs/predictions/{args.file}")
meta = predicted_file["meta"]
predictions = predicted_file["predictions"]

dataset = parse_toml(meta["evaluation_dataset"])
metrics = pipelines.metrics_initiator(
meta["pipeline"], dataset["mdl"], args.semantics
)
engine_info = engine_config(dataset["mdl"], pipe_components)
metrics = pipelines.metrics_initiator(meta["pipeline"], engine_info, args.semantics)

evaluator = Evaluator(**metrics)
if args.training_dataset:
optimizer_parameters["evaluator"] = evaluator
optimizer_parameters["metrics"] = metrics
optimizer_parameters["meta"] = meta
optimizer_parameters["predictions"] = predictions
configure_llm_provider(
os.getenv("GENERATION_MODEL"), os.getenv("LLM_OPENAI_API_KEY")
)
trainset, devset = prepare_dataset(args.training_dataset)
build_optimizing_module(trainset)
else:
evaluator.eval(meta, predictions)
evaluator.eval(meta, predictions)
# if args.training_dataset:
# # todo: for now comment dspy related code
# optimizer_parameters["evaluator"] = evaluator
# optimizer_parameters["metrics"] = metrics
# optimizer_parameters["meta"] = meta
# optimizer_parameters["predictions"] = predictions
# configure_llm_provider(
# os.getenv("GENERATION_MODEL"), os.getenv("LLM_OPENAI_API_KEY")
# )
# trainset, devset = prepare_dataset(args.training_dataset)
# build_optimizing_module(trainset)
# else:
# evaluator.eval(meta, predictions)

langfuse_context.flush()
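
Pulling the hunks above together, the rewritten __main__ block roughly follows the flow below. This is a condensed sketch of the lines already shown, not extra code from the commit; the keys inside the engine_info dict are not visible in this diff and are left opaque here.

# Condensed sketch of the new evaluation entrypoint. Names (parse_args, Evaluator,
# provider, pipelines, utils, engine_config, parse_toml) come from the module and
# the imports shown above; the contents of engine_info are an assumption.
args = parse_args()

settings = EvalSettings()
pipe_components = provider.generate_components(settings.components)  # components resolved per pipeline
utils.init_langfuse(settings)

predicted_file = parse_toml(f"outputs/predictions/{args.file}")
meta, predictions = predicted_file["meta"], predicted_file["predictions"]

dataset = parse_toml(meta["evaluation_dataset"])
engine_info = engine_config(dataset["mdl"], pipe_components)
metrics = pipelines.metrics_initiator(meta["pipeline"], engine_info, args.semantics)

Evaluator(**metrics).eval(meta, predictions)
langfuse_context.flush()
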

12 changes: 6 additions & 6 deletions wren-ai-service/eval/metrics/accuracy.py
@@ -12,12 +12,12 @@


class AccuracyMetric(BaseMetric):
def __init__(self, engine_config: dict, enable_semantics_comparison: bool = False):
def __init__(self, engine_info: dict, enable_semantics_comparison: bool = False):
self.threshold = 0
self.score = 0
self._engine_config = engine_config
self._enable_semantics_comparison = enable_semantics_comparison
if self._enable_semantics_comparison:
self.engine_info = engine_info
self.enable_semantics_comparison = enable_semantics_comparison
if self.enable_semantics_comparison:
self._openai_client = get_openai_client()

def measure(self, test_case: LLMTestCase):
@@ -82,7 +82,7 @@ def _rewrite_sql(self, sql: str) -> str:
return sql

async def _retrieve_data(self, sql: str) -> pd.DataFrame:
response = await get_data_from_wren_engine(sql=sql, **self._engine_config)
response = await get_data_from_wren_engine(sql=sql, **self.engine_info)

df = pd.DataFrame(**response)
sorted_columns = sorted(df.columns)
@@ -145,7 +145,7 @@ async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):

self.score = self._count_partial_matches(expected_dataset, actual_dataset)
# use llm to check sql semantics
if self.score == 0 and self._enable_semantics_comparison:
if self.score == 0 and self.enable_semantics_comparison:
# TODO: we may need to upload the sql semantics result to langfuse
print(f"before _check_sql_semantics: {self.score}")
print(f"expected sql: {rewritten_expected_output}")
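
The same engine_config to engine_info rename recurs in the metrics below; in each case the dict is simply splatted into the engine helpers. A hypothetical instantiation (module paths and the illustrative SQL are assumptions; engine_info is the dict produced by eval.utils.engine_config as in evaluation.py above):

# Hypothetical usage, not from the commit. Module paths are assumed from the file
# layout; engine_info must carry the keyword arguments expected by
# get_data_from_wren_engine / get_contexts_from_sql, since the metrics unpack it
# directly into those calls.
from deepeval.test_case import LLMTestCase

from eval.metrics.accuracy import AccuracyMetric
from eval.metrics.answer_relevancy import AnswerRelevancyMetric

accuracy = AccuracyMetric(engine_info=engine_info)
relevancy = AnswerRelevancyMetric(engine_info=engine_info)

test_case = LLMTestCase(
    input="How many orders were placed last month?",  # illustrative values only
    actual_output="SELECT COUNT(*) FROM orders",
    expected_output="SELECT COUNT(*) FROM orders",
)

# measure() wraps the async a_measure() with asyncio.run(), as the metric classes
# in this diff show.
score = accuracy.measure(test_case)
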
8 changes: 4 additions & 4 deletions wren-ai-service/eval/metrics/answer_relevancy.py
@@ -7,21 +7,21 @@


class AnswerRelevancyMetric(BaseMetric):
def __init__(self, engine_config: dict):
def __init__(self, engine_info: dict):
self.threshold = 0
self.score = 0
self._engine_config = engine_config
self.engine_info = engine_info

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
actual_units = await get_contexts_from_sql(
sql=test_case.actual_output, **self._engine_config
sql=test_case.actual_output, **self.engine_info
)

expected_units = await get_contexts_from_sql(
sql=test_case.expected_output, **self._engine_config
sql=test_case.expected_output, **self.engine_info
)

intersection = set(actual_units) & set(expected_units)
6 changes: 3 additions & 3 deletions wren-ai-service/eval/metrics/context_recall.py
@@ -7,17 +7,17 @@


class ContextualRecallMetric(BaseMetric):
def __init__(self, engine_config: dict):
def __init__(self, engine_info: dict):
self.threshold = 0
self.score = 0
self._engine_config = engine_config
self.engine_info = engine_info

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
expected_units = await get_contexts_from_sql(
sql=test_case.expected_output, **self._engine_config
sql=test_case.expected_output, **self.engine_info
)

intersection = set(test_case.retrieval_context) & set(expected_units)
6 changes: 3 additions & 3 deletions wren-ai-service/eval/metrics/faithfulness.py
@@ -7,17 +7,17 @@


class FaithfulnessMetric(BaseMetric):
def __init__(self, engine_config: dict):
def __init__(self, engine_info: dict):
self.threshold = 0
self.score = 0
self._engine_config = engine_config
self.engine_info = engine_info

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
actual_units = await get_contexts_from_sql(
sql=test_case.actual_output, **self._engine_config
sql=test_case.actual_output, **self.engine_info
)
intersection = set(actual_units) & set(test_case.retrieval_context)
self.score = len(intersection) / len(actual_units)
