Add vectara to provenance LLM #9

Open · wants to merge 4 commits into base: main
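For context, here is a rough sketch of how the new flags added in this PR might be exercised through Guardrails. The import path, constructor arguments, and the `query_function` stub are assumptions for illustration; only the `use_vectara` and `pass_on_invalid` metadata keys come from this diff.

```python
# Illustrative sketch only; import path and constructor arguments are assumed.
from guardrails import Guard
from validator import ProvenanceLLM  # hypothetical local import path


def query_function(text: str, k: int) -> list:
    # Stub: return the k most relevant context chunks for `text`
    # from your own retrieval system.
    return ["<retrieved context 1>", "<retrieved context 2>"]


guard = Guard().use(ProvenanceLLM(validation_method="sentence", on_fail="fix"))

guard.validate(
    "The generated answer to check.",
    metadata={
        "query_function": query_function,
        "use_vectara": True,      # new: route evaluation through the Vectara model
        "pass_on_invalid": True,  # new: treat unparseable results as supported
    },
)
```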
126 changes: 105 additions & 21 deletions validator/main.py
@@ -1,14 +1,16 @@
import os
import re
import itertools
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn
from transformers import pipeline, AutoTokenizer
import torch

import nltk
import numpy as np
from guardrails.utils.docs_utils import get_chunks_from_text
from guardrails.utils.validator_utils import PROVENANCE_V1_PROMPT
from guardrails.validator_base import (
FailResult,
PassResult,
@@ -21,6 +23,47 @@
from tenacity import retry, stop_after_attempt, wait_random_exponential
from sentence_transformers import SentenceTransformer

PROVENANCE_V1_PROMPT = """Instruction:
As an Attribution Validator, your task is to determine if the given contexts provide irrefutable evidence to support the claim. Follow these strict guidelines:

Respond "Yes" ONLY if ALL of the following conditions are met:

The contexts explicitly and unambiguously state information that fully confirms ALL aspects of the claim.
There is NO room for alternative interpretations or assumptions.
The support is direct and doesn't require complex inference chains.
If numbers or specific details are mentioned in the claim, they MUST be exactly matched in the contexts.


Respond "No" if ANY of the following are true:

The contexts do not provide explicit information that fully substantiates every part of the claim.
The claim requires any degree of inference or assumption not directly stated in the contexts.
The contexts only partially support the claim or support it with slight differences in details.
There is any ambiguity, vagueness, or room for interpretation in how the contexts relate to the claim.
The claim includes any information not present in the contexts, even if it seems common knowledge.
The contexts contradict any part of the claim, no matter how minor.


Treat the contexts as the ONLY source of truth. Do not use any outside knowledge or assumptions.
For multi-part claims, EVERY single part must be explicitly supported by the contexts for a "Yes" response.
If there is ANY doubt whatsoever, respond with "No".
Be extremely literal in your interpretation. Do not extrapolate or generalize from the given information.

Provide your analysis in this format:
<reasoning>

Point 1
Point 2
Point 3 (if needed)
</reasoning>


<decision>Yes</decision> OR <decision>No</decision>
Claim:
{}
Contexts:
{}
Response:"""
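The template above has two positional placeholders; the prompt-construction step is not shown in this hunk, but it would presumably be filled roughly as follows (the variable names are illustrative):

```python
# Illustrative only: the sentence under evaluation and its retrieved
# contexts are substituted into the template's two positional placeholders.
claim = "The company reported $5M in revenue in 2023."
contexts = "\n".join([
    "Context 1: The 2023 annual report lists revenue of $5M.",
    "Context 2: Revenue grew 20% year over year.",
])
prompt = PROVENANCE_V1_PROMPT.format(claim, contexts)
```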

@register_validator(name="guardrails/provenance_llm", data_type="string")
class ProvenanceLLM(Validator):
@@ -125,8 +168,32 @@ def call_llm(self, prompt: str) -> str:
response (str): String representing the LLM response.
"""
return self._llm_callable(prompt)

def evaluate_with_vectara(self, text: str, pass_on_invalid: bool) -> bool:
"""Evaluate the text with Vectara's hallucination evaluation model.

Returns True if the text is classified as consistent and False if it is
classified as hallucinated; otherwise falls back to `pass_on_invalid`.
"""
classifier = pipeline(
"text-classification",
model="vectara/hallucination_evaluation_model",
tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
trust_remote_code=True,
device="cuda" if torch.cuda.is_available() else "cpu",
)
result = classifier(text, batch_size=1)
if result[0]['label'] == 'consistent':
return True
if result[0]['label'] == 'hallucinated':
return False
if pass_on_invalid:
warn(
"The Vectara model returned an invalid response. Considering the sentence as supported..."
)
return True
else:
warn(
"The Vectara model returned an invalid response. Considering the sentence as unsupported..."
)
return False
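Note that `evaluate_with_vectara` rebuilds the classifier pipeline on every call, and `validate_each_sentence` calls it once per sentence. A minimal sketch of constructing the pipeline once and reusing it (this helper is not part of the PR; it only illustrates the idea):

```python
from functools import lru_cache

import torch
from transformers import AutoTokenizer, pipeline


@lru_cache(maxsize=1)
def _get_vectara_classifier():
    # Load the Vectara hallucination-evaluation pipeline once and reuse it
    # across calls instead of re-initializing the model per sentence.
    return pipeline(
        "text-classification",
        model="vectara/hallucination_evaluation_model",
        tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
        trust_remote_code=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
```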

def evaluate_with_llm(self, text: str, query_function: Callable) -> str:
def evaluate_with_llm(self, text: str, query_function: Callable, pass_on_invalid: bool) -> bool:
"""Validate that the LLM-generated text is supported by the provided
contexts.

@@ -145,35 +212,28 @@ def evaluate_with_llm(self, text: str, query_function: Callable) -> str:

# Get evaluation response
eval_response = self.call_llm(prompt)
return eval_response
return self.parse_response(eval_response, pass_on_invalid=pass_on_invalid)

def validate_each_sentence(
self, value: Any, query_function: Callable, metadata: Dict[str, Any]
) -> ValidationResult:
"""Validate each sentence in the response."""
pass_on_invalid = metadata.get("pass_on_invalid", False) # Default to False
use_vectara = metadata.get("use_vectara", False)

# Split the value into sentences using nltk sentence tokenizer.
sentences = nltk.sent_tokenize(value)

unsupported_sentences, supported_sentences = [], []
for sentence in sentences:
eval_response = self.evaluate_with_llm(sentence, query_function)
if eval_response == "yes":
if use_vectara:
eval_response = self.evaluate_with_vectara(sentence, pass_on_invalid=pass_on_invalid)
else:
eval_response = self.evaluate_with_llm(sentence, query_function, pass_on_invalid=pass_on_invalid)
if eval_response:
supported_sentences.append(sentence)
elif eval_response == "no":
else:
unsupported_sentences.append(sentence)
else:
if pass_on_invalid:
warn(
"The LLM returned an invalid response. Considering the sentence as supported..."
)
supported_sentences.append(sentence)
else:
warn(
"The LLM returned an invalid response. Considering the sentence as unsupported..."
)
unsupported_sentences.append(sentence)

if unsupported_sentences:
unsupported_sentences = "- " + "\n- ".join(unsupported_sentences)
@@ -187,18 +247,42 @@ def validate_each_sentence(
fix_value="\n".join(supported_sentences),
)
return PassResult(metadata=metadata)

def parse_response(self, response: str, pass_on_invalid: bool) -> bool:
"""Parse the <decision> tag from the LLM response into a boolean,
falling back to `pass_on_invalid` when no decision can be extracted."""
response = response.lower()
# Extract decision
decision_match = re.search(r'<decision>(yes|no)</decision>', response)
decision = decision_match.group(1) if decision_match else None
if decision == 'yes':
return True
if decision == 'no':
return False
# No parseable decision: fall back to the configured behavior
if pass_on_invalid:
warn(
"The LLM returned an invalid response. Considering the sentence as supported..."
)
return True
warn(
"The LLM returned an invalid response. Considering the sentence as unsupported..."
)
return False
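For illustration, assuming `validator` is a constructed `ProvenanceLLM` instance, a response that follows the format requested by the prompt parses as expected:

```python
# Hypothetical LLM outputs; only the <decision> tag matters to the regex.
sample_response = """<reasoning>
- The contexts explicitly state the figure mentioned in the claim.
</reasoning>
<decision>Yes</decision>"""

validator.parse_response(sample_response, pass_on_invalid=False)            # -> True
validator.parse_response("<decision>No</decision>", pass_on_invalid=False)  # -> False
```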

def validate_full_text(
self, value: Any, query_function: Callable, metadata: Dict[str, Any]
) -> ValidationResult:
"""Validate the entire LLM text."""
pass_on_invalid = metadata.get("pass_on_invalid", False) # Default to False

use_vectara = metadata.get("use_vectara", False)
# Self-evaluate LLM with entire text
eval_response = self.evaluate_with_llm(value, query_function)
if eval_response == "yes":
if use_vectara:
passed = self.evaluate_with_vectara(value, pass_on_invalid=pass_on_invalid)
else:
passed = self.evaluate_with_llm(value, query_function, pass_on_invalid=pass_on_invalid)
if passed:
return PassResult(metadata=metadata)
if eval_response == "no":
else:
return FailResult(
metadata=metadata,
error_message=(