Skip to content

Commit

Permalink
add competitor check validation (#474)
Browse files Browse the repository at this point in the history
* add competitor check validation

* lint

* lint fix continued

* update deps

* typing

* 3.8 typing

* comp check tests, notebook

* do not run comp check nb in workflow

---------

Co-authored-by: zsimjee <[email protected]>
  • Loading branch information
ShreyaR and zsimjee authored Nov 30, 2023
1 parent 9cdd821 commit ff395e8
Show file tree
Hide file tree
Showing 10 changed files with 588 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/scripts/run_notebooks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cd docs/examples
# Function to process a notebook
process_notebook() {
notebook="$1"
if [ "$notebook" != "valid_chess_moves.ipynb" ] && [ "$notebook" != "translation_with_quality_check.ipynb" ]; then
if [ "$notebook" != "valid_chess_moves.ipynb" ] && [ "$notebook" != "translation_with_quality_check.ipynb" ] && [ "$notebook" != "competitors_check.ipynb" ]; then
echo "Processing $notebook..."
poetry run jupyter nbconvert --to notebook --execute "$notebook"
if [ $? -ne 0 ]; then
Expand Down
240 changes: 240 additions & 0 deletions docs/examples/competitors_check.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions guardrails/utils/openai_utils/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def construct_nonchat_response(
) -> LLMResponse:
"""Construct an LLMResponse from an OpenAI response.
Splits execution based on whether the `stream` parameter
is set in the kwargs.
Splits execution based on whether the `stream` parameter is set
in the kwargs.
"""
if stream:
# If stream is defined and set to True,
Expand Down Expand Up @@ -152,8 +152,8 @@ def construct_chat_response(
) -> LLMResponse:
"""Construct an LLMResponse from an OpenAI response.
Splits execution based on whether the `stream` parameter
is set in the kwargs.
Splits execution based on whether the `stream` parameter is set
in the kwargs.
"""
if stream:
# If stream is defined and set to True,
Expand Down Expand Up @@ -296,8 +296,8 @@ async def construct_chat_response(
) -> LLMResponse:
"""Construct an LLMResponse from an OpenAI response.
Splits execution based on whether the `stream` parameter
is set in the kwargs.
Splits execution based on whether the `stream` parameter is set
in the kwargs.
"""
if stream:
# If stream is defined and set to True,
Expand Down
12 changes: 6 additions & 6 deletions guardrails/utils/openai_utils/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def construct_nonchat_response(
) -> LLMResponse:
"""Construct an LLMResponse from an OpenAI response.
Splits execution based on whether the `stream` parameter
is set in the kwargs.
Splits execution based on whether the `stream` parameter is set
in the kwargs.
"""
if stream:
# If stream is defined and set to True,
Expand Down Expand Up @@ -140,8 +140,8 @@ def construct_chat_response(
) -> LLMResponse:
"""Construct an LLMResponse from an OpenAI response.
Splits execution based on whether the `stream` parameter
is set in the kwargs.
Splits execution based on whether the `stream` parameter is set
in the kwargs.
"""
if stream:
# If stream is defined and set to True,
Expand Down Expand Up @@ -298,8 +298,8 @@ async def construct_chat_response(
) -> LLMResponse:
"""Construct an LLMResponse from an OpenAI response.
Splits execution based on whether the `stream` parameter
is set in the kwargs.
Splits execution based on whether the `stream` parameter is set
in the kwargs.
"""
if stream:
# If stream is defined and set to True,
Expand Down
166 changes: 159 additions & 7 deletions guardrails/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@
except LookupError:
nltk.download("punkt")

try:
import spacy
except ImportError:
spacy = None


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -703,7 +708,7 @@ def __init__(
if not _HAS_NUMPY:
raise ImportError(
f"The {self.__class__.__name__} validator requires the numpy package.\n"
"`pip install numpy` to install it."
"`poetry add numpy` to install it."
)

self.client = OpenAIClient()
Expand Down Expand Up @@ -775,7 +780,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
except ImportError:
raise ImportError(
"`is-profanity-free` validator requires the `alt-profanity-check`"
"package. Please install it with `pip install profanity-check`."
"package. Please install it with `poetry add profanity-check`."
)

prediction = predict([value])
Expand Down Expand Up @@ -823,7 +828,7 @@ def __init__(self, *args, **kwargs):
except ImportError:
raise ImportError(
"`is-high-quality-translation` validator requires the `inspiredco`"
"package. Please install it with `pip install inspiredco`."
"package. Please install it with `poetry add inspiredco`."
)

def validate(self, value: Any, metadata: Dict) -> ValidationResult:
Expand Down Expand Up @@ -1122,7 +1127,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
except ImportError:
raise ImportError(
"`thefuzz` library is required for `extractive-summary` validator. "
"Please install it with `pip install thefuzz`."
"Please install it with `poetry add thefuzz`."
)

# Split the value into sentences.
Expand Down Expand Up @@ -1217,7 +1222,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
except ImportError:
raise ImportError(
"`thefuzz` library is required for `remove-redundant-sentences` "
"validator. Please install it with `pip install thefuzz`."
"validator. Please install it with `poetry add thefuzz`."
)

# Split the value into sentences.
Expand Down Expand Up @@ -1613,7 +1618,7 @@ def validate_each_sentence(
if nltk is None:
raise ImportError(
"`nltk` library is required for `provenance-v0` validator. "
"Please install it with `pip install nltk`."
"Please install it with `poetry add nltk`."
)
# Split the value into sentences using nltk sentence tokenizer.
sentences = nltk.sent_tokenize(value)
Expand Down Expand Up @@ -1973,7 +1978,7 @@ def validate_each_sentence(
if nltk is None:
raise ImportError(
"`nltk` library is required for `provenance-v0` validator. "
"Please install it with `pip install nltk`."
"Please install it with `poetry add nltk`."
)
# Split the value into sentences using nltk sentence tokenizer.
sentences = nltk.sent_tokenize(value)
Expand Down Expand Up @@ -2535,3 +2540,150 @@ def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
fix_value=modified_value,
)
return PassResult()


@register_validator(name="competitor-check", data_type="string")
class CompetitorCheck(Validator):
    """Validate that LLM-generated text does not name any listed competitor.

    A sentence is flagged only when a competitor name is found by a
    case-insensitive whole-name text match AND the same name also appears
    inside a named entity detected by the spaCy model. Flagged sentences are
    dropped from the proposed fix value.

    In order to use this validator you need to provide an extensive list of
    the competitors you want to avoid naming, including all common variations.

    Args:
        competitors (List[str]): List of competitors you want to avoid naming.
        on_fail (Optional[Callable]): Forwarded unchanged to the base
            ``Validator`` constructor.

    Raises:
        ImportError: If ``spacy`` is not installed.
    """

    def __init__(
        self,
        competitors: List[str],
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(competitors=competitors, on_fail=on_fail)
        self._competitors = competitors
        model = "en_core_web_trf"
        if spacy is None:
            raise ImportError(
                "You must install spacy in order to use the CompetitorCheck validator."
            )

        # Fetch the spaCy model on first use if it is not installed yet.
        if not spacy.util.is_package(model):
            logger.info(
                "Spacy model %s not installed. "
                "Download should start now and take a few minutes.",
                model,
            )
            spacy.cli.download(model)  # type: ignore

        self.nlp = spacy.load(model)

    @staticmethod
    def _name_pattern(name: str):
        """Compile a case-insensitive, whole-name regex for ``name``.

        Lookarounds are used instead of ``\\b`` so that names starting or
        ending with a non-word character (e.g. "C++", "Yahoo!") can still
        match; ``\\b`` would demand an adjacent word character there. For names
        with word characters at both ends the two forms are equivalent.

        Args:
            name (str): The literal competitor name.

        Returns:
            A compiled regular expression object.
        """
        return re.compile(rf"(?<!\w){re.escape(name)}(?!\w)", re.IGNORECASE)

    def exact_match(self, text: str, competitors: List[str]) -> List[str]:
        """Find which competitors from a list are literally named in a text.

        Matching is case-insensitive and anchored on whole names, so a
        competitor embedded in a longer word is not reported.

        Args:
            text (str): The text to search for competitors.
            competitors (list): A list of competitor entities to match.

        Returns:
            list: The competitors found, in the order they were given.
        """
        return [
            entity
            for entity in competitors
            if self._name_pattern(entity).search(text)
        ]

    def perform_ner(self, text: str, nlp) -> List[str]:
        """Run named entity recognition on a text with the provided model.

        Args:
            text (str): The text to perform named entity recognition on.
            nlp: The NLP model to use for entity recognition; it is called on
                the text and its result's ``ents`` are read.

        Returns:
            entities: The text of every entity found.
        """
        return [ent.text for ent in nlp(text).ents]

    def is_entity_in_list(self, entities: List[str], competitors: List[str]) -> List:
        """Report which competitors appear inside any of the given entities.

        A competitor is reported once for every entity that contains it, so
        the result may contain repeated names.

        Args:
            entities (list): A list of entities to check.
            competitors (list): A list of competitor names to match.

        Returns:
            List: List of found competitors, ordered by entity and then by
            competitor.
        """
        # Compile each competitor pattern once rather than once per entity.
        patterns = [(item, self._name_pattern(item)) for item in competitors]
        return [
            item
            for entity in entities
            for item, pattern in patterns
            if pattern.search(entity)
        ]

    def validate(
        self,
        value: str,
        metadata: Optional[Dict] = None,
    ) -> ValidationResult:
        """Check a text for competitors' names.

        While running, collect the sentences naming competitors and build a
        fixed output in which all flagged sentences are filtered out.

        Args:
            value (str): The value to be validated.
            metadata (Dict, optional): Additional metadata; not used by this
                validator. Defaults to None.

        Returns:
            ValidationResult: ``PassResult`` when no competitor is named,
            otherwise a ``FailResult`` whose ``fix_value`` is the text without
            the flagged sentences.

        Raises:
            ImportError: If ``nltk`` is not installed.
        """
        if nltk is None:
            raise ImportError(
                "`nltk` library is required for `competitor-check` validator. "
                "Please install it with `poetry add nltk`."
            )

        flagged_sentences = []
        filtered_sentences = []
        list_of_competitors_found = []

        for sentence in nltk.sent_tokenize(value):
            # The cheap text match runs first; the NER model is only invoked
            # on sentences that literally contain a competitor name.
            entities = self.exact_match(sentence, self._competitors)
            if not entities:
                filtered_sentences.append(sentence)
                continue

            ner_entities = self.perform_ner(sentence, self.nlp)
            found_competitors = self.is_entity_in_list(ner_entities, entities)
            if not found_competitors:
                # Literal match that NER does not see as an entity: keep it.
                filtered_sentences.append(sentence)
                continue

            flagged_sentences.append((found_competitors, sentence))
            list_of_competitors_found.append(found_competitors)
            logger.debug("Found: %s named in '%s'", found_competitors, sentence)

        if not flagged_sentences:
            return PassResult()

        return FailResult(
            error_message=(
                f"Found the following competitors: {list_of_competitors_found}. "
                "Please avoid naming those competitors next time"
            ),
            fix_value=" ".join(filtered_sentences),
        )
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ nav:
- 'Check key info present in generated summary': examples/text_summarization_quality.ipynb
- 'Detect and limit hallucinations in generated text': examples/provenance.ipynb
- 'Check whether a value is similar to a set of other values': examples/value_within_distribution.ipynb
- 'Check if a competitor is named': examples/competitors_check.ipynb
- 'Integrations':
- 'Azure OpenAI': integrations/azure_openai.ipynb
- 'OpenAI Functions': integrations/openai_functions.ipynb
Expand Down
Loading

0 comments on commit ff395e8

Please sign in to comment.