Skip to content

Commit

Permalink
Merge pull request #4 from guardrails-ai/jc/new_remote_inference
Browse files Browse the repository at this point in the history
Change validator to respect local/remote inference.
  • Loading branch information
AlejandroEsquivel authored Nov 27, 2024
2 parents f63ccee + b34b137 commit 789a2da
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 83 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/publish_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ name: Publish to Guardrails Hub
on:
workflow_dispatch:
push:
branches:
- main
# Publish when new releases are tagged.
tags:
- '*'

jobs:
setup:
Expand Down
39 changes: 7 additions & 32 deletions app_inference_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Forked from spec:
# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
import os
from typing import Optional
from logging import getLogger
from typing import List

from fastapi import HTTPException
from pydantic import BaseModel
from models_host.base_inference_spec import BaseInferenceSpec

Expand All @@ -22,14 +21,11 @@


class InputRequest(BaseModel):
message: str
threshold: Optional[float] = None
prompts: List[str]


class OutputResponse(BaseModel):
classification: str
score: float
is_jailbreak: bool
scores: List[float]


# Using same nomenclature as in Sagemaker classes
Expand Down Expand Up @@ -64,35 +60,14 @@ def load(self):
self.model = DetectJailbreak(**kwargs)

def process_request(self, input_request: InputRequest):
message = input_request.message
prompts = input_request.prompts
# If needed, sanity check.
# raise HTTPException(status_code=400, detail="Invalid input format")
args = (message,)
args = (prompts,)
kwargs = {}
if input_request.threshold is not None:
kwargs["threshold"] = input_request.threshold
if not 0.0 <= input_request.threshold <= 1.0:
raise HTTPException(
status_code=400,
detail=f"Threshold must be between 0.0 and 1.0. "
f"Got {input_request.threshold}"
)
return args, kwargs

def infer(self, message: str, threshold: Optional[float] = None) -> OutputResponse:
if threshold is None:
threshold = 0.81

score = self.model.predict_jailbreak([message,])[0]
if score > threshold:
classification = "jailbreak"
is_jailbreak = True
else:
classification = "safe"
is_jailbreak = False

def infer(self, prompts: List[str]) -> OutputResponse:
return OutputResponse(
classification=classification,
score=score,
is_jailbreak=is_jailbreak,
scores=self.model.predict_jailbreak(prompts),
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "detect-jailbreak"
version = "0.1.3"
version = "0.1.4"
description = "A prompt-injection and jailbreak detector for LLMs."
authors = [
{name = "Guardrails AI", email = "[email protected]"},
Expand Down
127 changes: 79 additions & 48 deletions validator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import math
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Union, Any

import torch
from torch.nn import functional as F
Expand Down Expand Up @@ -65,57 +66,64 @@ def __init__(
device: str = "cpu",
on_fail: Optional[Callable] = None,
model_path_override: str = "",
**kwargs,
):
super().__init__(on_fail=on_fail)
super().__init__(on_fail=on_fail, **kwargs)
self.device = device
self.threshold = threshold
self.saturation_attack_detector = None
self.text_classifier = None
self.embedding_tokenizer = None
self.embedding_model = None
self.known_malicious_embeddings = []

if not model_path_override:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
)
self.text_classifier = pipeline(
"text-classification",
DetectJailbreak.TEXT_CLASSIFIER_NAME,
max_length=512, # HACK: Fix classifier size.
truncation=True,
device=device,
)
# There are a large number of fairly low-effort prompts people will use.
# The embedding detectors do checks to roughly match those.
self.embedding_tokenizer = AutoTokenizer.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
)
self.embedding_model = AutoModel.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
).to(device)
else:
# Saturation:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
model_path_override=model_path_override
)
# Known attacks:
embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
model_path_override,
"embedding",
AutoTokenizer,
AutoModel
)
self.embedding_tokenizer = embedding_tokenizer
self.embedding_model = embedding_model.to(device)
# Other text attacks:
self.text_classifier = get_pipeline_by_path(
model_path_override,
"text-classifier",
"text-classification",
max_length=512,
truncation=True,
device=device
)
if self.use_local:
if not model_path_override:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
)
self.text_classifier = pipeline(
"text-classification",
DetectJailbreak.TEXT_CLASSIFIER_NAME,
max_length=512, # HACK: Fix classifier size.
truncation=True,
device=device,
)
# There are a large number of fairly low-effort prompts people will use.
# The embedding detectors do checks to roughly match those.
self.embedding_tokenizer = AutoTokenizer.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
)
self.embedding_model = AutoModel.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
).to(device)
else:
# Saturation:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
model_path_override=model_path_override
)
# Known attacks:
embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
model_path_override,
"embedding",
AutoTokenizer,
AutoModel
)
self.embedding_tokenizer = embedding_tokenizer
self.embedding_model = embedding_model.to(device)
# Other text attacks:
self.text_classifier = get_pipeline_by_path(
model_path_override,
"text-classifier",
"text-classification",
max_length=512,
truncation=True,
device=device
)

# Quick compute on startup:
self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
# Quick compute on startup:
self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)

# These _are_ modifyable, but not explicitly advertised.
self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
Expand Down Expand Up @@ -233,6 +241,9 @@ def predict_jailbreak(
prompts: List[str],
reduction_function: Optional[Callable] = max,
) -> Union[List[float], List[dict]]:
"""predict_jailbreak will return an array of floats by default, one per prompt.
If reduction_function is set to 'none' it will return a dict with the different
sub-validators and their scores. Useful for debugging and tuning."""
if isinstance(prompts, str):
print("WARN: predict_jailbreak should be called with a list of strings.")
prompts = [prompts, ]
Expand Down Expand Up @@ -271,7 +282,9 @@ def validate(
if isinstance(value, str):
value = [value, ]

scores = self.predict_jailbreak(value)
# _inference is to support local/remote. It is equivalent to this:
# scores = self.predict_jailbreak(value)
scores = self._inference(value)

failed_prompts = list()
failed_scores = list() # To help people calibrate their thresholds.
Expand All @@ -289,3 +302,21 @@ def validate(
error_message=failure_message
)
return PassResult()

# The rest of these methods are made for validator compatibility and may have some
# strange properties,

def _inference_local(self, model_input: List[str]) -> Any:
return self.predict_jailbreak(model_input)

def _inference_remote(self, model_input: List[str]) -> Any:
# This needs to be kept in-sync with app_inference_spec.
request_body = {"prompts": model_input}
response = self._hub_inference_request(
json.dumps(request_body),
self.validation_endpoint
)
if not response or "scores" not in response:
raise ValueError("Invalid response from remote inference", response)

return response["scores"]

0 comments on commit 789a2da

Please sign in to comment.