Skip to content

Commit

Permalink
Allow custom template class
Browse files Browse the repository at this point in the history
  • Loading branch information
giovcandido authored Feb 12, 2025
1 parent 5147469 commit 6ca87af
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions deepeval/metrics/faithfulness/faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
strict_mode: bool = False,
verbose_mode: bool = False,
truths_extraction_limit: Optional[int] = None,
custom_template_class: class = None
):
self.threshold = 1 if strict_mode else threshold
self.model, self.using_native_model = initialize_model(model)
Expand All @@ -50,6 +51,7 @@ def __init__(
self.async_mode = async_mode
self.strict_mode = strict_mode
self.verbose_mode = verbose_mode
self.template_class = custom_template_class if custom_template_class else FaithfulnessTemplate

self.truths_extraction_limit = truths_extraction_limit
if self.truths_extraction_limit is not None:
Expand Down Expand Up @@ -132,7 +134,7 @@ async def _a_generate_reason(self) -> str:
if verdict.verdict.strip().lower() == "no":
contradictions.append(verdict.reason)

prompt: dict = FaithfulnessTemplate.generate_reason(
prompt: dict = self.template_class.generate_reason(
contradictions=contradictions,
score=format(self.score, ".2f"),
)
Expand All @@ -159,7 +161,7 @@ def _generate_reason(self) -> str:
if verdict.verdict.strip().lower() == "no":
contradictions.append(verdict.reason)

prompt: dict = FaithfulnessTemplate.generate_reason(
prompt: dict = self.template_class.generate_reason(
contradictions=contradictions,
score=format(self.score, ".2f"),
)
Expand All @@ -182,7 +184,7 @@ async def _a_generate_verdicts(self) -> List[FaithfulnessVerdict]:
return []

verdicts: List[FaithfulnessVerdict] = []
prompt = FaithfulnessTemplate.generate_verdicts(
prompt = self.template_class.generate_verdicts(
claims=self.claims, retrieval_context="\n\n".join(self.truths)
)
if self.using_native_model:
Expand Down Expand Up @@ -210,7 +212,7 @@ def _generate_verdicts(self) -> List[FaithfulnessVerdict]:
return []

verdicts: List[FaithfulnessVerdict] = []
prompt = FaithfulnessTemplate.generate_verdicts(
prompt = self.template_class.generate_verdicts(
claims=self.claims, retrieval_context="\n\n".join(self.truths)
)
if self.using_native_model:
Expand All @@ -232,7 +234,7 @@ def _generate_verdicts(self) -> List[FaithfulnessVerdict]:
return verdicts

async def _a_generate_truths(self, retrieval_context: str) -> List[str]:
prompt = FaithfulnessTemplate.generate_truths(
prompt = self.template_class.generate_truths(
text="\n\n".join(retrieval_context),
extraction_limit=self.truths_extraction_limit,
)
Expand All @@ -250,7 +252,7 @@ async def _a_generate_truths(self, retrieval_context: str) -> List[str]:
return data["truths"]

def _generate_truths(self, retrieval_context: str) -> List[str]:
prompt = FaithfulnessTemplate.generate_truths(
prompt = self.template_class.generate_truths(
text="\n\n".join(retrieval_context),
extraction_limit=self.truths_extraction_limit,
)
Expand All @@ -268,7 +270,7 @@ def _generate_truths(self, retrieval_context: str) -> List[str]:
return data["truths"]

async def _a_generate_claims(self, actual_output: str) -> List[str]:
prompt = FaithfulnessTemplate.generate_claims(text=actual_output)
prompt = self.template_class.generate_claims(text=actual_output)
if self.using_native_model:
res, cost = await self.model.a_generate(prompt, schema=Claims)
self.evaluation_cost += cost
Expand All @@ -283,7 +285,7 @@ async def _a_generate_claims(self, actual_output: str) -> List[str]:
return data["claims"]

def _generate_claims(self, actual_output: str) -> List[str]:
prompt = FaithfulnessTemplate.generate_claims(text=actual_output)
prompt = self.template_class.generate_claims(text=actual_output)
if self.using_native_model:
res, cost = self.model.generate(prompt, schema=Claims)
self.evaluation_cost += cost
Expand Down

0 comments on commit 6ca87af

Please sign in to comment.