Skip to content

Commit

Permalink
Merge pull request #3 from guardrails-ai/jc/add_3_9_compatibility
Browse files Browse the repository at this point in the history
Add some imports and remove some style elements to make the validator _maybe_ compatible with Python 3.9.
  • Loading branch information
JosephCatrambone authored Nov 25, 2024
2 parents e48835e + fb8a6b4 commit 23026d9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
20 changes: 10 additions & 10 deletions validator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Callable, Optional, Union
from typing import Callable, List, Optional, Union

import torch
from torch.nn import functional as F
Expand Down Expand Up @@ -140,7 +140,7 @@ def _mean_pool(model_output, attention_mask):
input_mask_expanded.sum(1), min=1e-9
)

def _embed(self, prompts: list[str]):
def _embed(self, prompts: List[str]):
"""Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
We use the long-form to avoid a dependency on sentence transformers.
This method returns the maximum of the matches against all known attacks.
Expand All @@ -160,8 +160,8 @@ def _embed(self, prompts: list[str]):

def _match_known_malicious_prompts(
self,
prompts: list[str] | torch.Tensor,
) -> list[float]:
prompts: Union[List[str], torch.Tensor],
) -> List[float]:
"""Returns an array of floats, one per prompt, with the max match to known
attacks. If prompts is a list of strings, embeddings will be generated. If
embeddings are passed, they will be used."""
Expand All @@ -179,7 +179,7 @@ def _match_known_malicious_prompts(
def _predict_and_remap(
self,
model,
prompts: list[str],
prompts: List[str],
label_field: str,
score_field: str,
safe_case: str,
Expand All @@ -199,7 +199,7 @@ def _predict_and_remap(
scores.append(new_score)
return scores

def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
def _predict_jailbreak(self, prompts: List[str]) -> List[float]:
return [
DetectJailbreak._rescale(s, *self.text_attack_scales)
for s in self._predict_and_remap(
Expand All @@ -212,7 +212,7 @@ def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
)
]

def _predict_saturation(self, prompts: list[str]) -> list[float]:
def _predict_saturation(self, prompts: List[str]) -> List[float]:
return [
DetectJailbreak._rescale(
s,
Expand All @@ -230,9 +230,9 @@ def _predict_saturation(self, prompts: list[str]) -> list[float]:

def predict_jailbreak(
self,
prompts: list[str],
prompts: List[str],
reduction_function: Optional[Callable] = max,
) -> Union[list[float], list[dict]]:
) -> Union[List[float], List[dict]]:
if isinstance(prompts, str):
print("WARN: predict_jailbreak should be called with a list of strings.")
prompts = [prompts, ]
Expand All @@ -256,7 +256,7 @@ def predict_jailbreak(

def validate(
self,
value: Union[str, list[str]],
value: Union[str, List[str]],
metadata: Optional[dict] = None,
) -> ValidationResult:
"""Validates that will return a failure if the value is a jailbreak attempt.
Expand Down
10 changes: 5 additions & 5 deletions validator/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import List, Tuple, Optional, Union

import numpy
import torch
Expand All @@ -8,7 +8,7 @@


def string_to_one_hot_tensor(
text: Union[str, list[str], tuple[str]],
text: Union[str, List[str], Tuple[str]],
max_length: int = 2048,
left_truncate: bool = True,
) -> torch.Tensor:
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_current_device(self):

def forward(
self,
x: Union[str, list[str], numpy.ndarray, torch.Tensor]
x: Union[str, List[str], numpy.ndarray, torch.Tensor]
) -> torch.Tensor:
if isinstance(x, str) or isinstance(x, list) or isinstance(x, tuple):
x = string_to_one_hot_tensor(x).to(self.get_current_device())
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_current_device(self):

def forward(
self,
x: Union[str, list[str], numpy.ndarray, torch.Tensor],
x: Union[str, List[str], numpy.ndarray, torch.Tensor],
y: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand Down Expand Up @@ -209,5 +209,5 @@ def __init__(
device=device,
)

def __call__(self, text: Union[str, list[str]]) -> list[dict]:
def __call__(self, text: Union[str, List[str]]) -> List[dict]:
return self.pipe(text)

0 comments on commit 23026d9

Please sign in to comment.