Refactor rag metrics and judges (#1515)
lilacheden authored Jan 19, 2025
1 parent b66c23f commit 6dcf08e
Showing 190 changed files with 2,851 additions and 471 deletions.
22 changes: 8 additions & 14 deletions examples/evaluate_external_rag_results_with_binary_llm_as_judge.py
@@ -53,22 +53,17 @@

# Select the desired metric(s).
# Each metric measures a certain aspect of the generated answer (answer_correctness, faithfulness,
# answer_relevance, context_relevance and correctness_holistic).
# All available metrics are under "catalog.metrics.rag"
# Those with extension "logprobs" provide a real value prediction in [0,1], the others provide a binary prediction.
# By default, all judges use llama_3_1_70b_instruct_wml. We will soon see how to change this.
# answer_relevance and context_relevance).
# All available metrics are under "catalog.metrics.rag.autorag.", ending with "judge"
# By default, all judges use llama_3_3_70b_instruct. We will soon see how to change this.
metric_names = [
"metrics.rag.answer_correctness.llama_3_1_70b_instruct_wml_q_a_gt_loose_logprobs",
"metrics.rag.faithfulness.llama_3_1_70b_instruct_wml_q_c_a_logprobs",
"metrics.rag.autorag.answer_correctness.llama_3_3_70b_instruct_wml_judge",
"metrics.rag.autorag.faithfulness.llama_3_3_70b_instruct_wml_judge",
]

# select the desired model.
# all available models are under "catalog.engines.classification"
model_names = [
"engines.classification.mixtral_8x7b_instruct_v01_wml",
"engines.classification.llama_3_1_70b_instruct_wml",
# "engines.classification.gpt_4_turbo_openai",
]
model_names = ["engines.classification.mixtral_8x7b_instruct_v01_wml"]

if __name__ == "__main__":
multi_stream = MultiStream.from_iterables({"test": test_examples}, copying=True)
@@ -79,9 +74,8 @@

for metric_name in metric_names:
for model_name in model_names:
# override the metric with the inference model. the default model is llama_3_1_70b_instruct_wml so
# no need to override when using it.
llmaj_metric_name = f"{metric_name}[model={model_name}]"
# override the metric with the inference model (to use a model different from the one in the metric name)
llmaj_metric_name = f"{metric_name}[inference_model={model_name}]"

# apply the metric over the input
metrics_operator = SequentialOperator(steps=[llmaj_metric_name])
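For reference, a minimal sketch of how the override string above resolves (assuming the catalog entries installed by this commit; the expected judge class follows the TaskBasedLLMasJudge definitions later in this diff):

from unitxt.artifact import fetch_artifact

# Hypothetical check of the "[inference_model=...]" override used above: fetching the
# overridden metric id should yield a judge whose engine is the mixtral classification
# engine rather than the default llama_3_3_70b judge.
metric_name = (
    "metrics.rag.autorag.faithfulness.llama_3_3_70b_instruct_wml_judge"
    "[inference_model=engines.classification.mixtral_8x7b_instruct_v01_wml]"
)
metric = fetch_artifact(metric_name)[0]
print(type(metric).__name__)                  # expected (per this commit): TaskBasedLLMasJudge
print(type(metric.inference_model).__name__)  # the overriding engine class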
10 changes: 10 additions & 0 deletions examples/evaluate_rag_end_to_end_dataset_with_given_predictions.py
@@ -45,11 +45,21 @@
},
]

# select recommended metrics according to your available resources.
metrics = [
"metrics.rag.end_to_end.recommended.cpu_only.all",
# "metrics.rag.end_to_end.recommended.small_llm.all",
# "metrics.rag.end_to_end.recommended.llmaj_watsonx.all",
# "metrics.rag.end_to_end.recommended.llmaj_rits.all"
# "metrics.rag.end_to_end.recommended.llmaj_azure.all"
]

dataset = create_dataset(
task="tasks.rag.end_to_end",
test_set=dataset,
split="test",
postprocessors=[],
metrics=metrics,
)

results = evaluate(predictions, dataset)
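As a usage sketch, switching to one of the commented-out bundles only changes the metrics list; the rest of the flow in this file stays the same (the llmaj_watsonx bundle assumes watsonx credentials are configured, and the imports mirror the ones in the full example file):

from unitxt import create_dataset, evaluate

# Hypothetical variant of the selection above, for a machine with watsonx access.
metrics = ["metrics.rag.end_to_end.recommended.llmaj_watsonx.all"]

dataset = create_dataset(
    task="tasks.rag.end_to_end",
    test_set=dataset,  # the list of end-to-end RAG instances defined earlier in this file
    split="test",
    postprocessors=[],
    metrics=metrics,
)
results = evaluate(predictions, dataset)
# Each scored instance carries a "score" dict; the exact accessor may vary by unitxt version,
# e.g. results[0]["score"]["global"].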
10 changes: 10 additions & 0 deletions examples/evaluate_rag_response_generation.py
@@ -58,13 +58,23 @@
),
)

# select recommended metrics according to your available resources.
metrics = [
"metrics.rag.response_generation.recommended.cpu_only.all",
# "metrics.rag.response_generation.recommended.small_llm.all",
# "metrics.rag.response_generation.recommended.llmaj_watsonx.all",
# "metrics.rag.response_generation.recommended.llmaj_rits.all"
# "metrics.rag.response_generation.recommended.llmaj_azure.all"
]

# Verbalize the dataset using the template
dataset = load_dataset(
card=card,
template_card_index="simple",
format="formats.chat_api",
split="test",
max_test_instances=10,
metrics=metrics,
)


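A hedged sketch of the remaining steps (inference plus evaluation), using the CrossProviderInferenceEngine introduced later in this commit; the model/provider pair is one of the combinations listed in the engines prepare script below:

from unitxt import evaluate
from unitxt.inference import CrossProviderInferenceEngine

# Generate answers for the verbalized test split and score them with the
# response-generation metrics selected above.
model = CrossProviderInferenceEngine(model="llama-3-3-70b-instruct", provider="watsonx")
predictions = model.infer(dataset)
results = evaluate(predictions, dataset)
print(results[0]["score"]["global"])  # accessor may vary across unitxt versions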
57 changes: 29 additions & 28 deletions prepare/engines/classification/classification_engines.py
@@ -1,56 +1,57 @@
from unitxt import add_to_catalog
from unitxt.inference import (
AzureOpenAIInferenceEngine,
IbmGenAiInferenceEngine,
RITSInferenceEngine,
CrossProviderInferenceEngine,
WMLInferenceEngineGeneration,
)

model_names_to_provider = {
"llama-3-3-70b-instruct": ["watsonx", "rits"],
"llama-3-1-70b-instruct": ["watsonx", "rits"],
"gpt-4o": ["open-ai"],
"gpt-4-turbo": ["open-ai"],
"gpt-4-turbo-2024-04-09": ["azure"],
"gpt-4o-2024-08-06": ["azure"],
"mistralai/mixtral-8x7b-instruct-v01": ["ibm_wml"],
"meta-llama/llama-3-3-70b-instruct": ["ibm_wml"],
"meta-llama/llama-3-1-70b-instruct": ["ibm_wml"],
"meta-llama/llama-3-405b-instruct": ["ibm_wml"],
"llama-3-1-405b-instruct-fp8": ["rits"],
}


def get_inference_engine(model_name, framework_name):
if framework_name == "ibm_wml":
def get_inference_engine(model_name, provider):
if provider == "ibm_wml":
return WMLInferenceEngineGeneration(
model_name=model_name,
max_new_tokens=5,
random_seed=42,
decoding_method="greedy",
)
if framework_name == "ibm_gen_ai":
return IbmGenAiInferenceEngine(
model_name=model_name,
max_new_tokens=5,
random_seed=42,
decoding_method="greedy",
)
if framework_name == "openai":

if provider == "azure":
return AzureOpenAIInferenceEngine(
model_name=model_name,
logprobs=True,
max_tokens=5,
temperature=0.0,
top_logprobs=5,
)
if framework_name == "rits":
return RITSInferenceEngine(
model_name=model_name, logprobs=True, max_tokens=5, temperature=0.0
)
raise ValueError("Unsupported framework name " + framework_name)

return CrossProviderInferenceEngine(
model=model_name,
logprobs=True,
max_tokens=5,
temperature=0.0,
top_logprobs=5,
provider=provider,
)

model_names_to_infer_framework = {
"meta-llama/llama-3-1-70b-instruct": ["ibm_wml", "rits", "ibm_gen_ai"],
"meta-llama/llama-3-3-70b-instruct": ["ibm_wml", "rits"],
"gpt-4-turbo-2024-04-09": ["openai"],
"gpt-4o-2024-08-06": ["openai"],
"mistralai/mixtral-8x7b-instruct-v01": ["ibm_wml", "ibm_gen_ai", "rits"],
"meta-llama/llama-3-1-405b-instruct-fp8": ["ibm_gen_ai", "rits"],
"meta-llama/llama-3-405b-instruct": ["ibm_wml"],
}

for judge_model_name, infer_frameworks in model_names_to_infer_framework.items():
for judge_model_name, infer_frameworks in model_names_to_provider.items():
for infer_framework in infer_frameworks:
inference_engine = get_inference_engine(judge_model_name, infer_framework)
inference_engine_label = inference_engine.get_engine_id()
inference_engine_label = inference_engine.get_engine_id().replace("-", "_")

add_to_catalog(
inference_engine,
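A brief illustrative sketch of the helper's dispatch under the provider mapping above (for orientation only; the prepare script itself only calls it inside the registration loop):

# ibm_wml models go through WMLInferenceEngineGeneration, azure through
# AzureOpenAIInferenceEngine, and the remaining providers (watsonx, rits, open-ai)
# through CrossProviderInferenceEngine.
wml_judge = get_inference_engine("mistralai/mixtral-8x7b-instruct-v01", "ibm_wml")
watsonx_judge = get_inference_engine("llama-3-3-70b-instruct", "watsonx")
print(type(wml_judge).__name__)      # WMLInferenceEngineGeneration
print(type(watsonx_judge).__name__)  # CrossProviderInferenceEngine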
26 changes: 5 additions & 21 deletions prepare/metrics/hhem.py
@@ -1,6 +1,5 @@
from unitxt import add_to_catalog
from unitxt.metrics import FaithfulnessHHEM, MetricPipeline
from unitxt.operators import Copy
from unitxt.metrics import FaithfulnessHHEM
from unitxt.test_utils.metrics import test_metric

pairs = [
@@ -12,21 +11,6 @@
predictions = [p[1] for p in pairs]
task_data = [{"contexts": [p[0]]} for p in pairs]

## This metric pipeline supports two usecases:
## 1. Regular unitxt flow: predictions are taken from model prediction and contexts appears in the task data
## 2. Running on external rag output: each instance contains field "answer" and field "contexts"
metric = MetricPipeline(
main_score="hhem_score",
preprocess_steps=[
Copy(
field_to_field={"task_data/contexts": "references", "answer": "prediction"},
not_exist_do_nothing=True,
),
Copy(field_to_field={"contexts": "references"}, not_exist_do_nothing=True),
],
metric=FaithfulnessHHEM(),
__description__="Vectara's halucination detection model, HHEM2.1, compares contexts and generated answer to determine faithfulness.",
)
instance_targets = [
{"score": 0.01, "score_name": "hhem_score", "hhem_score": 0.01},
{"score": 0.65, "score_name": "hhem_score", "hhem_score": 0.65},
@@ -43,13 +27,13 @@
"hhem_score_ci_high": 0.65,
}


references = [[p[0]] for p in pairs]
metric = FaithfulnessHHEM()
outputs = test_metric(
metric=metric,
predictions=predictions,
references=[[""]] * len(instance_targets),
task_data=task_data,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
add_to_catalog(metric, "metrics.rag.faithfulness.vectara_hhem_2_1", overwrite=True)
add_to_catalog(metric, "metrics.vectara_groundedness_hhem_2_1", overwrite=True)
67 changes: 58 additions & 9 deletions prepare/metrics/llm_as_judge/rag_judge.py
@@ -1,5 +1,4 @@
from unitxt import add_to_catalog
from unitxt.artifact import UnitxtArtifactNotFoundError, fetch_artifact
from unitxt.inference import GenericInferenceEngine
from unitxt.llm_as_judge import (
TaskBasedLLMasJudge,
@@ -31,12 +30,12 @@


def get_prediction_field(metric_type):
return None if metric_type == "context_relevance" else "answer"
return "contexts" if metric_type == "context_relevance" else "answer"


for metric_type, template_dict in metric_type_to_template_dict.items():
for template_short_name, template_name in template_dict.items():
task_name = f"tasks.rag_eval.{metric_type}.binary"
judge_task_name = f"tasks.rag_eval.{metric_type}.binary"
for logprobs_label in [
"",
"_logprobs",
@@ -46,10 +45,7 @@ def get_prediction_field(metric_type):
template = (
f"templates.rag_eval.{metric_type}.{template_name}{logprobs_label}"
)
try:
t = fetch_artifact(template)[0]
except UnitxtArtifactNotFoundError:
continue

for inf_label, inference_model in inference_models.items():
if (
use_logprobs and inf_label == generic_engine_label
@@ -60,7 +56,7 @@ def get_prediction_field(metric_type):
metric = TaskBasedLLMasJudge(
inference_model=inference_model,
template=template,
task=task_name,
task=judge_task_name,
format=None,
main_score=metric_label,
prediction_field=get_prediction_field(metric_type),
@@ -79,7 +75,7 @@ def get_prediction_field(metric_type):
metric = TaskBasedLLMasJudge(
inference_model=inference_model,
template=template,
task=task_name,
task=judge_task_name,
format=None,
main_score=metric_label,
prediction_field=get_prediction_field(metric_type),
@@ -92,3 +88,56 @@ def get_prediction_field(metric_type):
f"metrics.llm_as_judge.binary.{inf_label}_{metric_label}",
overwrite=True,
)


# now add new metrics under unitxt rag tasks
metric_type_to_template_v2 = {
"faithfulness": "judge_with_question_simplified",
"context_relevance": "judge_context_relevance_ares",
"answer_correctness": "judge_loose_match_no_context",
"answer_relevance": "judge_answer_relevance",
}

inference_models_v2 = {
"llama_3_3_70b_instruct_watsonx": "engines.classification.llama_3_3_70b_instruct_watsonx",
"llama_3_3_70b_instruct_rits": "engines.classification.llama_3_3_70b_instruct_rits",
"gpt_4o_azure": "engines.classification.gpt_4o_2024_08_06_azure_openai",
generic_engine_label: GenericInferenceEngine(),
}

for metric_type, template_name in metric_type_to_template_v2.items():
judge_task_name = f"tasks.rag_eval.{metric_type}.binary"
realization_sufffix = metric_type_to_realization[metric_type]
template = f"templates.rag_eval.{metric_type}.{template_name}{realization_sufffix}"
for inf_label, inference_model in inference_models_v2.items():
for rag_unitxt_task in ["external_rag", "response_generation", "end_to_end"]:
if (
rag_unitxt_task == "response_generation"
and metric_type == "context_relevance"
):
continue

judge_to_generator_fields_mapping = (
{}
if rag_unitxt_task == "external_rag"
else {"ground_truths": "reference_answers"}
)

new_catalog_name = (
f"metrics.rag.{rag_unitxt_task}.{metric_type}.{inf_label}_judge"
)
metric = TaskBasedLLMasJudge(
inference_model=inference_model,
template=template,
task=judge_task_name,
format=None,
main_score=f"{metric_type}_judge",
prediction_field=get_prediction_field(metric_type),
infer_log_probs=False,
judge_to_generator_fields_mapping=judge_to_generator_fields_mapping,
)
add_to_catalog(
metric,
new_catalog_name,
overwrite=True,
)
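The loop above registers catalog ids of the form metrics.rag.<task>.<metric_type>.<judge>_judge, for example metrics.rag.end_to_end.answer_correctness.llama_3_3_70b_instruct_watsonx_judge. A hedged usage sketch (assuming a list of end-to-end RAG instances as in the examples earlier in this commit):

from unitxt import create_dataset

# Score answer correctness with the watsonx llama-3-3-70b judge registered above;
# for the end_to_end and response_generation tasks, the judge maps the task's
# "reference_answers" field onto its "ground_truths" input.
dataset = create_dataset(
    task="tasks.rag.end_to_end",
    test_set=rag_instances,  # hypothetical list of end-to-end RAG records
    split="test",
    postprocessors=[],
    metrics=["metrics.rag.end_to_end.answer_correctness.llama_3_3_70b_instruct_watsonx_judge"],
)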
27 changes: 21 additions & 6 deletions prepare/metrics/rag.py
@@ -347,12 +347,24 @@
# metrics.rag.recall
# metrics.rag.bert_recall

for axis, base_metric, main_score in [
("correctness", "token_overlap", "f1"),
("correctness", "bert_score.deberta_large_mnli", "recall"),
("correctness", "bert_score.deberta_v3_base_mnli_xnli_ml", "recall"),
("faithfullness", "token_overlap", "precision"),
for axis, base_metric, main_score, new_metric in [
("correctness", "token_overlap", "f1", "answer_correctness.token_recall"),
(
"correctness",
"bert_score.deberta_large_mnli",
"recall",
"answer_correctness.bert_score_recall",
),
(
"correctness",
"bert_score.deberta_v3_base_mnli_xnli_ml",
"recall",
"answer_correctness.bert_score_recall_ml",
),
("faithfullness", "token_overlap", "precision", "faithfulness.token_k_precision"),
]:
deprecated_path = f"metrics.rag.response_generation.{axis}.{base_metric}"
new_metric_path = f"metrics.rag.response_generation.{new_metric}"
preprocess_steps = (
[
Copy(field="task_data/contexts", to_field="references"),
@@ -379,10 +391,13 @@
],
metric=f"metrics.{base_metric}",
prediction_type=str,
__deprecated_msg__=f"Metric {deprecated_path} is deprecated. Please use {new_metric_path} instead.",
)

add_to_catalog(
metric, f"metrics.rag.response_generation.{axis}.{base_metric}", overwrite=True
metric,
f"metrics.rag.response_generation.{axis}.{base_metric}",
overwrite=True,
)

# end to end
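For reference, the renaming above amounts to the following mapping; the old ids remain registered but now emit a deprecation message pointing to the new names:

# Old (deprecated) response-generation metric id -> new id
deprecations = {
    "metrics.rag.response_generation.correctness.token_overlap":
        "metrics.rag.response_generation.answer_correctness.token_recall",
    "metrics.rag.response_generation.correctness.bert_score.deberta_large_mnli":
        "metrics.rag.response_generation.answer_correctness.bert_score_recall",
    "metrics.rag.response_generation.correctness.bert_score.deberta_v3_base_mnli_xnli_ml":
        "metrics.rag.response_generation.answer_correctness.bert_score_recall_ml",
    "metrics.rag.response_generation.faithfullness.token_overlap":
        "metrics.rag.response_generation.faithfulness.token_k_precision",
}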