From d44057442cfcdf51bdc434b24ac875af578f5fae Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 6 Sep 2023 12:20:14 -0500 Subject: [PATCH 01/52] start validation outcome changes --- guardrails/classes/__init__.py | 7 ++++++ guardrails/classes/validation_outcome.py | 32 ++++++++++++++++++++++++ guardrails/guard.py | 27 +++++++++++--------- 3 files changed, 54 insertions(+), 12 deletions(-) create mode 100644 guardrails/classes/__init__.py create mode 100644 guardrails/classes/validation_outcome.py diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py new file mode 100644 index 000000000..cc8b43fd6 --- /dev/null +++ b/guardrails/classes/__init__.py @@ -0,0 +1,7 @@ +from guardrails.classes.validation_outcome import ValidationOutcome, TextOutcome, StructuredOutcome + +__all__ = [ + "ValidationOutcome", + "TextOutcome", + "StructuredOutcome", +] \ No newline at end of file diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py new file mode 100644 index 000000000..eb9f6647a --- /dev/null +++ b/guardrails/classes/validation_outcome.py @@ -0,0 +1,32 @@ +from typing import Union, Dict +from pydantic import Field +from guardrails.utils.logs_utils import GuardHistory, ArbitraryModel + +class ValidationOutcome(ArbitraryModel): + raw_llm_output: str = Field(description="The raw, unchanged output from the LLM call.") + validated_output: Union[str, Dict, None] = Field(description="The validated, and potentially fixed, output from the LLM call after passing through validation.") + validation_passed: bool = Field(description="A boolean to indicate whether or not the LLM output passed validation. If this is False, the validated_output may be invalid.") + + @classmethod + def from_guard_history (cls, guard_history: GuardHistory): + raw_output = guard_history.output + validated_output = guard_history.validated_output + any_validations_failed = len(guard_history.failed_validations) > 0 + if isinstance(validated_output, str): + return TextOutcome( + raw_llm_output=raw_output, + validated_output=validated_output, + validation_passed=any_validations_failed + ) + else: + return StructuredOutcome( + raw_llm_output=raw_output, + validated_output=validated_output, + validation_passed=any_validations_failed + ) + +class TextOutcome(ValidationOutcome): + validated_output: Union[str, None] = Field(description="The validated, and potentially fixed, output from the LLM call after passing through validation.") + +class StructuredOutcome(ValidationOutcome): + validated_output: Union[Dict, None] = Field(description="The validated, and potentially fixed, output from the LLM call after passing through validation.") \ No newline at end of file diff --git a/guardrails/guard.py b/guardrails/guard.py index c3260438a..ab6f8491a 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -26,6 +26,7 @@ from guardrails.utils.parsing_utils import get_template_variables from guardrails.utils.reask_utils import sub_reasks_with_fixed_values from guardrails.validators import Validator +from guardrails.classes import ValidationOutcome logger = logging.getLogger(__name__) actions_logger = logging.getLogger(f"{__name__}.actions") @@ -265,7 +266,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Union[Tuple[Optional[str], Any], Awaitable[Tuple[Optional[str], Any]]]: + ) -> Union[ValidationOutcome, Awaitable[ValidationOutcome]]: """Call the LLM and validate the output. Pass an async LLM API to return a coroutine. 
@@ -343,7 +344,7 @@ def _call_sync( full_schema_reask: bool, *args, **kwargs, - ) -> Tuple[Optional[str], Any]: + ) -> ValidationOutcome: instructions_obj = instructions or self.instructions prompt_obj = prompt or self.prompt msg_history_obj = msg_history or [] @@ -371,7 +372,7 @@ def _call_sync( full_schema_reask=full_schema_reask, ) guard_history = runner(prompt_params=prompt_params) - return guard_history.output, guard_history.validated_output + return ValidationOutcome.from_guard_history(guard_history) async def _call_async( self, @@ -385,7 +386,7 @@ async def _call_async( full_schema_reask: bool, *args, **kwargs, - ) -> Tuple[Optional[str], Any]: + ) -> ValidationOutcome: """Call the LLM asynchronously and validate the output. Args: @@ -430,7 +431,7 @@ async def _call_async( full_schema_reask=full_schema_reask, ) guard_history = await runner.async_run(prompt_params=prompt_params) - return guard_history.output, guard_history.validated_output + return ValidationOutcome.from_guard_history(guard_history) def __repr__(self): return f"Guard(RAIL={self.rail})" @@ -449,7 +450,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Any: + ) -> ValidationOutcome: ... @overload @@ -463,7 +464,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[Any]: + ) -> Awaitable[ValidationOutcome]: ... @overload @@ -477,7 +478,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Any: + ) -> ValidationOutcome: ... def parse( @@ -490,7 +491,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Union[Any, Awaitable[Any]]: + ) -> Union[ValidationOutcome, Awaitable[ValidationOutcome]]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -558,7 +559,7 @@ def _sync_parse( full_schema_reask: bool, *args, **kwargs, - ) -> Any: + ) -> ValidationOutcome: """Alternate flow to using Guard where the llm_output is known. Args: @@ -587,7 +588,8 @@ def _sync_parse( full_schema_reask=full_schema_reask, ) guard_history = runner(prompt_params=prompt_params) - return sub_reasks_with_fixed_values(guard_history.validated_output) + guard_history.history[-1].validated_output = sub_reasks_with_fixed_values(guard_history.validated_output) + return ValidationOutcome.from_guard_history(guard_history) async def _async_parse( self, @@ -628,4 +630,5 @@ async def _async_parse( full_schema_reask=full_schema_reask, ) guard_history = await runner.async_run(prompt_params=prompt_params) - return sub_reasks_with_fixed_values(guard_history.validated_output) + guard_history.history[-1].validated_output = sub_reasks_with_fixed_values(guard_history.validated_output) + return ValidationOutcome.from_guard_history(guard_history) From 6223bed24f3735dd19f8d03b90ce4302a19e0658 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 7 Sep 2023 14:12:06 -0500 Subject: [PATCH 02/52] fix gather_reasks for non-structured output --- guardrails/utils/reask_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/guardrails/utils/reask_utils.py b/guardrails/utils/reask_utils.py index 368b21686..d297c5bfb 100644 --- a/guardrails/utils/reask_utils.py +++ b/guardrails/utils/reask_utils.py @@ -24,7 +24,7 @@ class NonParseableReAsk(ReAsk): pass -def gather_reasks(validated_output: Optional[Union[Dict, ReAsk]]) -> List[ReAsk]: +def gather_reasks(validated_output: Any) -> List[FieldReAsk]: """Traverse output and gather all ReAsk objects. 
Args: @@ -73,7 +73,10 @@ def _gather_reasks_in_list( _gather_reasks_in_list(item, path + [idx]) return - _gather_reasks_in_dict(validated_output) + if isinstance(validated_output, Dict): + _gather_reasks_in_dict(validated_output) + elif isinstance(validated_output, ReAsk): + reasks = [validated_output] return reasks From e5cd3a03e239210993f67c70ff508d4ea6949377 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 7 Sep 2023 14:16:08 -0500 Subject: [PATCH 03/52] lint fixes --- guardrails/classes/__init__.py | 8 ++++-- guardrails/classes/validation_outcome.py | 35 +++++++++++++++++------- guardrails/guard.py | 10 +++++-- 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py index cc8b43fd6..8a86a4dd9 100644 --- a/guardrails/classes/__init__.py +++ b/guardrails/classes/__init__.py @@ -1,7 +1,11 @@ -from guardrails.classes.validation_outcome import ValidationOutcome, TextOutcome, StructuredOutcome +from guardrails.classes.validation_outcome import ( + StructuredOutcome, + TextOutcome, + ValidationOutcome, +) __all__ = [ "ValidationOutcome", "TextOutcome", "StructuredOutcome", -] \ No newline at end of file +] diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index eb9f6647a..dc0642b4e 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -1,14 +1,23 @@ -from typing import Union, Dict +from typing import Dict, Union + from pydantic import Field -from guardrails.utils.logs_utils import GuardHistory, ArbitraryModel + +from guardrails.utils.logs_utils import ArbitraryModel, GuardHistory + class ValidationOutcome(ArbitraryModel): - raw_llm_output: str = Field(description="The raw, unchanged output from the LLM call.") - validated_output: Union[str, Dict, None] = Field(description="The validated, and potentially fixed, output from the LLM call after passing through validation.") - validation_passed: bool = Field(description="A boolean to indicate whether or not the LLM output passed validation. If this is False, the validated_output may be invalid.") + raw_llm_output: str = Field( + description="The raw, unchanged output from the LLM call." + ) + validated_output: Union[str, Dict, None] = Field( + description="The validated, and potentially fixed, output from the LLM call after passing through validation." + ) + validation_passed: bool = Field( + description="A boolean to indicate whether or not the LLM output passed validation. If this is False, the validated_output may be invalid." 
+ ) @classmethod - def from_guard_history (cls, guard_history: GuardHistory): + def from_guard_history(cls, guard_history: GuardHistory): raw_output = guard_history.output validated_output = guard_history.validated_output any_validations_failed = len(guard_history.failed_validations) > 0 @@ -16,17 +25,23 @@ def from_guard_history (cls, guard_history: GuardHistory): return TextOutcome( raw_llm_output=raw_output, validated_output=validated_output, - validation_passed=any_validations_failed + validation_passed=any_validations_failed, ) else: return StructuredOutcome( raw_llm_output=raw_output, validated_output=validated_output, - validation_passed=any_validations_failed + validation_passed=any_validations_failed, ) + class TextOutcome(ValidationOutcome): - validated_output: Union[str, None] = Field(description="The validated, and potentially fixed, output from the LLM call after passing through validation.") + validated_output: Union[str, None] = Field( + description="The validated, and potentially fixed, output from the LLM call after passing through validation." + ) + class StructuredOutcome(ValidationOutcome): - validated_output: Union[Dict, None] = Field(description="The validated, and potentially fixed, output from the LLM call after passing through validation.") \ No newline at end of file + validated_output: Union[Dict, None] = Field( + description="The validated, and potentially fixed, output from the LLM call after passing through validation." + ) diff --git a/guardrails/guard.py b/guardrails/guard.py index ab6f8491a..4a028097a 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -17,6 +17,7 @@ from eliot import add_destinations, start_action from pydantic import BaseModel +from guardrails.classes import ValidationOutcome from guardrails.llm_providers import get_async_llm_ask, get_llm_ask from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail @@ -26,7 +27,6 @@ from guardrails.utils.parsing_utils import get_template_variables from guardrails.utils.reask_utils import sub_reasks_with_fixed_values from guardrails.validators import Validator -from guardrails.classes import ValidationOutcome logger = logging.getLogger(__name__) actions_logger = logging.getLogger(f"{__name__}.actions") @@ -588,7 +588,9 @@ def _sync_parse( full_schema_reask=full_schema_reask, ) guard_history = runner(prompt_params=prompt_params) - guard_history.history[-1].validated_output = sub_reasks_with_fixed_values(guard_history.validated_output) + guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( + guard_history.validated_output + ) return ValidationOutcome.from_guard_history(guard_history) async def _async_parse( @@ -630,5 +632,7 @@ async def _async_parse( full_schema_reask=full_schema_reask, ) guard_history = await runner.async_run(prompt_params=prompt_params) - guard_history.history[-1].validated_output = sub_reasks_with_fixed_values(guard_history.validated_output) + guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( + guard_history.validated_output + ) return ValidationOutcome.from_guard_history(guard_history) From 89e79e6c9fc9839c64fa15c26508b6b87d8a92ba Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 7 Sep 2023 14:17:46 -0500 Subject: [PATCH 04/52] more lint fixes --- guardrails/classes/validation_outcome.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index dc0642b4e..67065f04b 100644 --- 
a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -10,10 +10,13 @@ class ValidationOutcome(ArbitraryModel): description="The raw, unchanged output from the LLM call." ) validated_output: Union[str, Dict, None] = Field( - description="The validated, and potentially fixed, output from the LLM call after passing through validation." + description="The validated, and potentially fixed," + " output from the LLM call after passing through validation." ) validation_passed: bool = Field( - description="A boolean to indicate whether or not the LLM output passed validation. If this is False, the validated_output may be invalid." + description="A boolean to indicate whether or not" + " the LLM output passed validation." + " If this is False, the validated_output may be invalid." ) @classmethod @@ -37,11 +40,13 @@ def from_guard_history(cls, guard_history: GuardHistory): class TextOutcome(ValidationOutcome): validated_output: Union[str, None] = Field( - description="The validated, and potentially fixed, output from the LLM call after passing through validation." + description="The validated, and potentially fixed," + " output from the LLM call after passing through validation." ) class StructuredOutcome(ValidationOutcome): validated_output: Union[Dict, None] = Field( - description="The validated, and potentially fixed, output from the LLM call after passing through validation." + description="The validated, and potentially fixed," + " output from the LLM call after passing through validation." ) From 0a60c3e0dfddb6fb6a718d25c3d79dc3bd04e221 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 7 Sep 2023 16:09:35 -0500 Subject: [PATCH 05/52] start test fixes, debug types --- guardrails/classes/validation_outcome.py | 11 ++-- guardrails/cli.py | 2 +- guardrails/guard.py | 5 +- guardrails/utils/logs_utils.py | 4 +- tests/integration_tests/test_async.py | 8 +-- .../integration_tests/test_data_validation.py | 46 +++++++++++------ tests/integration_tests/test_guard.py | 51 ++++++++++--------- 7 files changed, 75 insertions(+), 52 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 67065f04b..c7bab623a 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -3,13 +3,14 @@ from pydantic import Field from guardrails.utils.logs_utils import ArbitraryModel, GuardHistory +from guardrails.utils.reask_utils import ReAsk class ValidationOutcome(ArbitraryModel): raw_llm_output: str = Field( description="The raw, unchanged output from the LLM call." ) - validated_output: Union[str, Dict, None] = Field( + validated_output: Union[str, Dict, ReAsk, None] = Field( description="The validated, and potentially fixed," " output from the LLM call after passing through validation." 
) @@ -22,7 +23,10 @@ class ValidationOutcome(ArbitraryModel): @classmethod def from_guard_history(cls, guard_history: GuardHistory): raw_output = guard_history.output + print("from_guard_history - type(guard_history.validated_output): ", type(guard_history.validated_output)) + # validated_output: Union[str, Dict, ReAsk, None] = guard_history.validated_output validated_output = guard_history.validated_output + print("from_guard_history - type(validated_output): ", type(validated_output)) any_validations_failed = len(guard_history.failed_validations) > 0 if isinstance(validated_output, str): return TextOutcome( @@ -31,6 +35,7 @@ def from_guard_history(cls, guard_history: GuardHistory): validation_passed=any_validations_failed, ) else: + # TODO: Why does instantiation collapse validated_output to a dict? return StructuredOutcome( raw_llm_output=raw_output, validated_output=validated_output, @@ -39,14 +44,14 @@ def from_guard_history(cls, guard_history: GuardHistory): class TextOutcome(ValidationOutcome): - validated_output: Union[str, None] = Field( + validated_output: Union[str, ReAsk, None] = Field( description="The validated, and potentially fixed," " output from the LLM call after passing through validation." ) class StructuredOutcome(ValidationOutcome): - validated_output: Union[Dict, None] = Field( + validated_output: Union[Dict, ReAsk, None] = Field( description="The validated, and potentially fixed," " output from the LLM call after passing through validation." ) diff --git a/guardrails/cli.py b/guardrails/cli.py index de24c442f..85df23421 100644 --- a/guardrails/cli.py +++ b/guardrails/cli.py @@ -16,7 +16,7 @@ def validate_llm_output(rail: str, llm_output: str) -> dict: """Validate guardrails.yml file.""" guard = Guard.from_rail(rail) result = guard.parse(llm_output) - return result + return result.validated_output @cli.command() diff --git a/guardrails/guard.py b/guardrails/guard.py index 4a028097a..ad46c906d 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -591,7 +591,10 @@ def _sync_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - return ValidationOutcome.from_guard_history(guard_history) + print("before ValidationOutcome.from_guard_history - type(validated_output): ", type(guard_history.validated_output)) + validation_outcome = ValidationOutcome.from_guard_history(guard_history) + print("after ValidationOutcome.from_guard_history - type(validated_output): ", type(validation_outcome.validated_output)) + return validation_outcome async def _async_parse( self, diff --git a/guardrails/utils/logs_utils.py b/guardrails/utils/logs_utils.py index 34d42c10d..3ea66fe3f 100644 --- a/guardrails/utils/logs_utils.py +++ b/guardrails/utils/logs_utils.py @@ -52,7 +52,7 @@ class GuardLogs(ArbitraryModel): llm_response: Optional[LLMResponse] = None msg_history: Optional[List[Dict[str, Prompt]]] = None parsed_output: Optional[Dict] = None - validated_output: Optional[Union[Dict, ReAsk]] = None + validated_output: Optional[Union[str, Dict, ReAsk, None]] = None reasks: Optional[Sequence[ReAsk]] = None field_validation_logs: Optional[FieldValidationLogs] = None @@ -154,7 +154,7 @@ def tree(self) -> Tree: return tree @property - def validated_output(self) -> Optional[Union[Dict, ReAsk]]: + def validated_output(self) -> Union[str, Dict, ReAsk, None]: """Returns the latest validated output.""" return self.history[-1].validated_output diff --git a/tests/integration_tests/test_async.py b/tests/integration_tests/test_async.py 
index 98d6be42f..0eef70edc 100644 --- a/tests/integration_tests/test_async.py +++ b/tests/integration_tests/test_async.py @@ -29,7 +29,7 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: with patch.object( JsonSchema, "preprocess_prompt", wraps=guard.output_schema.preprocess_prompt ) as mock_preprocess_prompt: - _, final_output = await guard( + final_output = await guard( llm_api=openai.Completion.acreate, prompt_params={"document": content[:6000]}, num_reasks=1, @@ -39,7 +39,7 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: mock_preprocess_prompt.assert_called() # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_REASK_2 + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_REASK_2 guard_history = guard.guard_state.most_recent_call.history @@ -222,7 +222,7 @@ async def test_rail_spec_output_parse(rail_spec, llm_output, validated_output): llm_output, llm_api=openai.Completion.acreate, ) - assert output == validated_output + assert output.validated_output == validated_output @pytest.fixture @@ -262,4 +262,4 @@ async def test_string_rail_spec_output_parse( llm_api=openai.Completion.acreate, num_reasks=0, ) - assert output == validated_string_output + assert output.validated_output == validated_string_output diff --git a/tests/integration_tests/test_data_validation.py b/tests/integration_tests/test_data_validation.py index dd5abd9b6..21eacd9a9 100644 --- a/tests/integration_tests/test_data_validation.py +++ b/tests/integration_tests/test_data_validation.py @@ -9,19 +9,19 @@ from guardrails.validators import ValidChoices test_cases = [ - ('{"choice": {"action": "fight", "fight_move": "kick"}}', False), - ( - '{"choice": {"action": "flight", "flight_direction": "north", "flight_speed": 1}}', - False, - ), + # ('{"choice": {"action": "fight", "fight_move": "kick"}}', False), + # ( + # '{"choice": {"action": "flight", "flight_direction": "north", "flight_speed": 1}}', + # False, + # ), ('{"choice": {"action": "flight", "fight_move": "punch"}}', True), - ( - '{"choice": {"action": "fight", "flight_direction": "north", "flight_speed": 1}}', - True, - ), - ('{"choice": {"action": "random_action"}}', True), - ('{"choice": {"action": "fight", "fight": "random_move"}}', True), - ('{"choice": {"action": "flight", "random_key": "random_value"}', True), + # ( + # '{"choice": {"action": "fight", "flight_direction": "north", "flight_speed": 1}}', + # True, + # ), + # ('{"choice": {"action": "random_action"}}', True), + # ('{"choice": {"action": "fight", "fight": "random_move"}}', True), + # ('{"choice": {"action": "flight", "random_key": "random_value"}', True), ] @@ -55,11 +55,13 @@ def test_choice_validation(llm_output, raises): if raises: with pytest.raises(ValueError): result = guard.parse(llm_output, num_reasks=0) - if result is None or isinstance(result, ReAsk): + validated_output = result.validated_output + if validated_output is None or isinstance(validated_output, ReAsk): raise ValueError("Expected a result, but got None or ReAsk.") else: result = guard.parse(llm_output, num_reasks=0) - assert not isinstance(result, ReAsk) + validated_output = result.validated_output + assert not isinstance(validated_output, ReAsk) @pytest.mark.parametrize("llm_output, raises", test_cases) @@ -92,8 +94,20 @@ class Choice(BaseModel): if raises: with pytest.raises(ValueError): result = guard.parse(llm_output, num_reasks=0) - if result is None or isinstance(result, ReAsk): 
+ validated_output = result.validated_output + print("llm_output: ", llm_output) + print("raises: ", raises) + print("validated_output: ", validated_output) + print("type(validated_output): ", type(validated_output)) + print("validated_output is None: ", validated_output is None) + print("isinstance(validated_output, ReAsk): ", isinstance(validated_output, ReAsk)) + if validated_output is None or isinstance(validated_output, ReAsk): raise ValueError("Expected a result, but got None or ReAsk.") else: result = guard.parse(llm_output, num_reasks=0) - assert not isinstance(result, ReAsk) + validated_output = result.validated_output + print("llm_output: ", llm_output) + print("raises: ", raises) + print("validated_output: ", validated_output) + print("type(validated_output): ", type(validated_output)) + assert not isinstance(validated_output, ReAsk) diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index a48474d30..acce77c4a 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -136,7 +136,7 @@ def test_entity_extraction_with_reask( content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"document": content[:6000]}, num_reasks=1, @@ -145,7 +145,7 @@ def test_entity_extraction_with_reask( ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_REASK_2 + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_REASK_2 guard_history = guard.guard_state.most_recent_call.history @@ -215,14 +215,14 @@ def test_entity_extraction_with_noop(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_NOOP + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_NOOP guard_history = guard.guard_state.most_recent_call.history @@ -251,14 +251,14 @@ def test_entity_extraction_with_filter(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_FILTER + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_FILTER guard_history = guard.guard_state.most_recent_call.history @@ -286,14 +286,14 @@ def test_entity_extraction_with_fix(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. 
- assert final_output == entity_extraction.VALIDATED_OUTPUT_FIX + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_FIX guard_history = guard.guard_state.most_recent_call.history @@ -322,14 +322,14 @@ def test_entity_extraction_with_refrain(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN guard_history = guard.guard_state.most_recent_call.history @@ -365,14 +365,14 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt, instructions) - _, final_output = guard( + final_output = guard( llm_api=openai.ChatCompletion.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_FIX + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_FIX guard_history = guard.guard_state.most_recent_call.history @@ -395,12 +395,13 @@ def test_string_output(mocker): mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable) guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"ingredients": "tomato, cheese, sour cream"}, num_reasks=1, ) - assert final_output == string.LLM_OUTPUT + + assert final_output.validated_output == string.LLM_OUTPUT guard_history = guard.guard_state.most_recent_call.history @@ -417,14 +418,14 @@ def test_string_reask(mocker): mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable) guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING_REASK) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"ingredients": "tomato, cheese, sour cream"}, num_reasks=1, max_tokens=100, ) - assert final_output == string.LLM_OUTPUT_REASK + assert final_output.validated_output == string.LLM_OUTPUT_REASK guard_history = guard.guard_state.most_recent_call.history @@ -450,7 +451,7 @@ def test_skeleton_reask(mocker): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_SKELETON_REASK) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"document": content[:6000]}, max_tokens=1000, @@ -458,7 +459,7 @@ def test_skeleton_reask(mocker): ) # Assertions are made on the guard state object. 
- assert final_output == entity_extraction.VALIDATED_OUTPUT_SKELETON_REASK_2 + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_SKELETON_REASK_2 guard_history = guard.guard_state.most_recent_call.history @@ -574,7 +575,7 @@ def test_entity_extraction_with_reask_with_optional_prompts( content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = Guard.from_rail_string(rail) - _, final_output = guard( + final_output = guard( llm_api=llm_api, prompt=prompt, instructions=instructions, @@ -584,7 +585,7 @@ def test_entity_extraction_with_reask_with_optional_prompts( ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_REASK_2 + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_REASK_2 guard_history = guard.guard_state.most_recent_call.history @@ -649,14 +650,14 @@ def test_string_with_message_history_reask(mocker): ) guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_MSG_HISTORY) - _, final_output = guard( + final_output = guard( llm_api=openai.ChatCompletion.create, msg_history=string.MOVIE_MSG_HISTORY, temperature=0.0, model="gpt-3.5-turbo", ) - assert final_output == string.MSG_LLM_OUTPUT_CORRECT + assert final_output.validated_output == string.MSG_LLM_OUTPUT_CORRECT guard_history = guard.guard_state.most_recent_call.history @@ -685,15 +686,15 @@ def test_pydantic_with_message_history_reask(mocker): ) guard = gd.Guard.from_pydantic(output_class=pydantic.WITH_MSG_HISTORY) - raw_output, guarded_output = guard( + final_output = guard( llm_api=openai.ChatCompletion.create, msg_history=string.MOVIE_MSG_HISTORY, temperature=0.0, model="gpt-3.5-turbo", ) - assert raw_output == pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT - assert guarded_output == json.loads(pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT) + assert final_output.raw_llm_output == pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT + assert final_output.validated_output == json.loads(pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT) guard_history = guard.guard_state.most_recent_call.history From dc6002656c3acdb19e0e75365f9807e3a392e51c Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 09:05:50 -0500 Subject: [PATCH 06/52] fix tests --- tests/integration_tests/test_async.py | 20 ++++++++++---------- tests/integration_tests/test_multi_reask.py | 2 +- tests/integration_tests/test_parsing.py | 8 ++++---- tests/integration_tests/test_pydantic.py | 4 ++-- tests/integration_tests/test_python_rail.py | 12 ++++++------ 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/integration_tests/test_async.py b/tests/integration_tests/test_async.py index 0eef70edc..fc496bf61 100644 --- a/tests/integration_tests/test_async.py +++ b/tests/integration_tests/test_async.py @@ -71,14 +71,14 @@ async def test_entity_extraction_with_noop(mocker): ) content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_NOOP) - _, final_output = await guard( + final_output = await guard( llm_api=openai.Completion.acreate, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. 
- assert final_output == entity_extraction.VALIDATED_OUTPUT_NOOP + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_NOOP guard_history = guard.guard_state.most_recent_call.history @@ -101,14 +101,14 @@ async def test_entity_extraction_with_noop_pydantic(mocker): guard = gd.Guard.from_pydantic( entity_extraction.PYDANTIC_RAIL_WITH_NOOP, entity_extraction.PYDANTIC_PROMPT ) - _, final_output = await guard( + final_output = await guard( llm_api=openai.Completion.acreate, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_NOOP + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_NOOP guard_history = guard.guard_state.most_recent_call.history @@ -131,14 +131,14 @@ async def test_entity_extraction_with_filter(mocker): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FILTER) - _, final_output = await guard( + final_output = await guard( llm_api=openai.Completion.acreate, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_FILTER + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_FILTER guard_history = guard.guard_state.most_recent_call.history @@ -163,14 +163,14 @@ async def test_entity_extraction_with_fix(mocker): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FIX) - _, final_output = await guard( + final_output = await guard( llm_api=openai.Completion.acreate, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. - assert final_output == entity_extraction.VALIDATED_OUTPUT_FIX + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_FIX guard_history = guard.guard_state.most_recent_call.history @@ -193,13 +193,13 @@ async def test_entity_extraction_with_refrain(mocker): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_REFRAIN) - _, final_output = await guard( + final_output = await guard( llm_api=openai.Completion.acreate, prompt_params={"document": content[:6000]}, num_reasks=1, ) # Assertions are made on the guard state object. 
- assert final_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN + assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN guard_history = guard.guard_state.most_recent_call.history diff --git a/tests/integration_tests/test_multi_reask.py b/tests/integration_tests/test_multi_reask.py index 3bf938840..e44997be7 100644 --- a/tests/integration_tests/test_multi_reask.py +++ b/tests/integration_tests/test_multi_reask.py @@ -12,7 +12,7 @@ def test_multi_reask(mocker): guard = gd.Guard.from_rail_string(python_rail.RAIL_SPEC_WITH_VALIDATOR_PARALLELISM) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, engine="text-davinci-003", num_reasks=5, diff --git a/tests/integration_tests/test_parsing.py b/tests/integration_tests/test_parsing.py index 981293414..c63e2cbf5 100644 --- a/tests/integration_tests/test_parsing.py +++ b/tests/integration_tests/test_parsing.py @@ -28,13 +28,13 @@ def test_parsing_reask(mocker): def mock_callable(prompt: str): return - _, final_output = guard( + final_output = guard( llm_api=mock_callable, prompt_params={"document": pydantic.PARSING_DOCUMENT}, num_reasks=1, ) - assert final_output == pydantic.PARSING_EXPECTED_OUTPUT + assert final_output.validated_output == pydantic.PARSING_EXPECTED_OUTPUT guard_history = guard.guard_state.most_recent_call.history @@ -67,13 +67,13 @@ async def test_async_parsing_reask(mocker): async def mock_async_callable(prompt: str): return - _, final_output = await guard( + final_output = await guard( llm_api=mock_async_callable, prompt_params={"document": pydantic.PARSING_DOCUMENT}, num_reasks=1, ) - assert final_output == pydantic.PARSING_EXPECTED_OUTPUT + assert final_output.validated_output == pydantic.PARSING_EXPECTED_OUTPUT guard_history = guard.guard_state.most_recent_call.history diff --git a/tests/integration_tests/test_pydantic.py b/tests/integration_tests/test_pydantic.py index 2109a4f51..e304908ec 100644 --- a/tests/integration_tests/test_pydantic.py +++ b/tests/integration_tests/test_pydantic.py @@ -11,7 +11,7 @@ def test_pydantic_with_reask(mocker): mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable) guard = gd.Guard.from_pydantic(ListOfPeople, prompt=VALIDATED_RESPONSE_REASK_PROMPT) - _, final_output = guard( + final_output = guard( openai.Completion.create, engine="text-davinci-003", max_tokens=512, @@ -21,7 +21,7 @@ def test_pydantic_with_reask(mocker): ) # Assertions are made on the guard state object. 
- assert final_output == pydantic.VALIDATED_OUTPUT_REASK_3 + assert final_output.validated_output == pydantic.VALIDATED_OUTPUT_REASK_3 guard_history = guard.guard_state.most_recent_call.history diff --git a/tests/integration_tests/test_python_rail.py b/tests/integration_tests/test_python_rail.py index 307478f6a..07efb481a 100644 --- a/tests/integration_tests/test_python_rail.py +++ b/tests/integration_tests/test_python_rail.py @@ -107,7 +107,7 @@ class Director(BaseModel): ) # Guardrails runs validation and fixes the first failing output through reasking - _, final_output = guard( + final_output = guard( openai.ChatCompletion.create, prompt_params={"director": "Christopher Nolan"}, num_reasks=2, @@ -118,7 +118,7 @@ class Director(BaseModel): expected_gd_output = json.loads( python_rail.LLM_OUTPUT_2_SUCCEED_GUARDRAILS_BUT_FAIL_PYDANTIC_VALIDATION ) - assert final_output == expected_gd_output + assert final_output.validated_output == expected_gd_output guard_history = guard.guard_state.most_recent_call.history @@ -228,7 +228,7 @@ class Director(BaseModel): ) # Guardrails runs validation and fixes the first failing output through reasking - _, final_output = guard( + final_output = guard( openai.ChatCompletion.create, prompt_params={"director": "Christopher Nolan"}, num_reasks=2, @@ -239,7 +239,7 @@ class Director(BaseModel): expected_gd_output = json.loads( python_rail.LLM_OUTPUT_2_SUCCEED_GUARDRAILS_BUT_FAIL_PYDANTIC_VALIDATION ) - assert final_output == expected_gd_output + assert final_output.validated_output == expected_gd_output guard_history = guard.guard_state.most_recent_call.history @@ -295,14 +295,14 @@ def test_python_string(mocker): guard = gd.Guard.from_string( validators, description, prompt=prompt, instructions=instructions ) - _, final_output = guard( + final_output = guard( llm_api=openai.Completion.create, prompt_params={"ingredients": "tomato, cheese, sour cream"}, num_reasks=1, max_tokens=100, ) - assert final_output == string.LLM_OUTPUT_REASK + assert final_output.validated_output == string.LLM_OUTPUT_REASK guard_history = guard.guard_state.most_recent_call.history From 3b667dd9721c0af7a055527dfdb74124e2f310b9 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 09:28:52 -0500 Subject: [PATCH 07/52] fix types with overloads --- guardrails/classes/validation_outcome.py | 33 +++++++++++++++--- guardrails/guard.py | 6 ++-- .../integration_tests/test_data_validation.py | 34 +++++++------------ 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index c7bab623a..d9878059c 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, Union, overload from pydantic import Field @@ -20,13 +20,32 @@ class ValidationOutcome(ArbitraryModel): " If this is False, the validated_output may be invalid." ) + @overload + def __init__ (self, raw_llm_output: str, validated_output: str, validation_passed: bool): + ... + + @overload + def __init__ (self, raw_llm_output: str, validated_output: Dict, validation_passed: bool): + ... + + @overload + def __init__ (self, raw_llm_output: str, validated_output: ReAsk, validation_passed: bool): + ... + + @overload + def __init__ (self, raw_llm_output: str, validated_output: None, validation_passed: bool): + ... 
+ + def __init__ (self, raw_llm_output: str, validated_output: Union[str, Dict, ReAsk, None], validation_passed: bool): + super().__init__(raw_llm_output=raw_llm_output, validated_output=validated_output, validation_passed=validation_passed) + self.raw_llm_output = raw_llm_output + self.validated_output = validated_output + self.validation_passed = validation_passed + @classmethod def from_guard_history(cls, guard_history: GuardHistory): raw_output = guard_history.output - print("from_guard_history - type(guard_history.validated_output): ", type(guard_history.validated_output)) - # validated_output: Union[str, Dict, ReAsk, None] = guard_history.validated_output validated_output = guard_history.validated_output - print("from_guard_history - type(validated_output): ", type(validated_output)) any_validations_failed = len(guard_history.failed_validations) > 0 if isinstance(validated_output, str): return TextOutcome( @@ -49,9 +68,15 @@ class TextOutcome(ValidationOutcome): " output from the LLM call after passing through validation." ) + def __init__ (self, raw_llm_output: str, validated_output: Union[str, ReAsk, None], validation_passed: bool): + super().__init__(raw_llm_output, validated_output, validation_passed) + class StructuredOutcome(ValidationOutcome): validated_output: Union[Dict, ReAsk, None] = Field( description="The validated, and potentially fixed," " output from the LLM call after passing through validation." ) + + def __init__ (self, raw_llm_output: str, validated_output: Union[Dict, ReAsk, None], validation_passed: bool): + super().__init__(raw_llm_output, validated_output, validation_passed) diff --git a/guardrails/guard.py b/guardrails/guard.py index ad46c906d..8d833db60 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -235,7 +235,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[Tuple[str, Any]]: + ) -> Awaitable[ValidationOutcome]: ... @overload @@ -251,7 +251,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Tuple[str, Any]: + ) -> ValidationOutcome: ... 
def __call__( @@ -591,9 +591,7 @@ def _sync_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - print("before ValidationOutcome.from_guard_history - type(validated_output): ", type(guard_history.validated_output)) validation_outcome = ValidationOutcome.from_guard_history(guard_history) - print("after ValidationOutcome.from_guard_history - type(validated_output): ", type(validation_outcome.validated_output)) return validation_outcome async def _async_parse( diff --git a/tests/integration_tests/test_data_validation.py b/tests/integration_tests/test_data_validation.py index 21eacd9a9..2f76d8e27 100644 --- a/tests/integration_tests/test_data_validation.py +++ b/tests/integration_tests/test_data_validation.py @@ -9,19 +9,19 @@ from guardrails.validators import ValidChoices test_cases = [ - # ('{"choice": {"action": "fight", "fight_move": "kick"}}', False), - # ( - # '{"choice": {"action": "flight", "flight_direction": "north", "flight_speed": 1}}', - # False, - # ), + ('{"choice": {"action": "fight", "fight_move": "kick"}}', False), + ( + '{"choice": {"action": "flight", "flight_direction": "north", "flight_speed": 1}}', + False, + ), ('{"choice": {"action": "flight", "fight_move": "punch"}}', True), - # ( - # '{"choice": {"action": "fight", "flight_direction": "north", "flight_speed": 1}}', - # True, - # ), - # ('{"choice": {"action": "random_action"}}', True), - # ('{"choice": {"action": "fight", "fight": "random_move"}}', True), - # ('{"choice": {"action": "flight", "random_key": "random_value"}', True), + ( + '{"choice": {"action": "fight", "flight_direction": "north", "flight_speed": 1}}', + True, + ), + ('{"choice": {"action": "random_action"}}', True), + ('{"choice": {"action": "fight", "fight": "random_move"}}', True), + ('{"choice": {"action": "flight", "random_key": "random_value"}', True), ] @@ -95,19 +95,9 @@ class Choice(BaseModel): with pytest.raises(ValueError): result = guard.parse(llm_output, num_reasks=0) validated_output = result.validated_output - print("llm_output: ", llm_output) - print("raises: ", raises) - print("validated_output: ", validated_output) - print("type(validated_output): ", type(validated_output)) - print("validated_output is None: ", validated_output is None) - print("isinstance(validated_output, ReAsk): ", isinstance(validated_output, ReAsk)) if validated_output is None or isinstance(validated_output, ReAsk): raise ValueError("Expected a result, but got None or ReAsk.") else: result = guard.parse(llm_output, num_reasks=0) validated_output = result.validated_output - print("llm_output: ", llm_output) - print("raises: ", raises) - print("validated_output: ", validated_output) - print("type(validated_output): ", type(validated_output)) assert not isinstance(validated_output, ReAsk) From adfda531b83f44398f312a3c0a24ce2918977f51 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 09:59:14 -0500 Subject: [PATCH 08/52] fix tests --- tests/unit_tests/test_validators.py | 35 ++++++++++++++++++----------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index 7b98a9c8b..f0a1f4304 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -208,7 +208,7 @@ class MyModel(BaseModel): num_reasks=0, ) - assert output == {"a_field": "hullo"} + assert output.validated_output == {"a_field": "hullo"} # (string, on_fail) tuple fix @@ -223,7 +223,7 @@ class MyModel(BaseModel): 
num_reasks=0, ) - assert output == {"a_field": "hullo"} + assert output.validated_output == {"a_field": "hullo"} # (Validator, on_fail) tuple fix @@ -236,7 +236,7 @@ class MyModel(BaseModel): num_reasks=0, ) - assert output == {"a_field": "hello there"} + assert output.validated_output == {"a_field": "hello there"} # (Validator, on_fail) tuple reask @@ -261,8 +261,11 @@ class MyModel(BaseModel): num_reasks=0, ) - assert output == {"a_field": "hullo"} - assert guard.guard_state.all_histories[0].history[0].reasks[0] == hullo_reask + assert output.validated_output == {"a_field": "hullo"} + assert ( + guard.guard_state.all_histories[0].history[0].parsed_output["a_field"] + == hullo_reask + ) hello_reask = FieldReAsk( incorrect_value="hello there yo", @@ -287,8 +290,11 @@ class MyModel(BaseModel): num_reasks=0, ) - assert output == {"a_field": "hello there"} - assert guard.guard_state.all_histories[0].history[0].reasks[0] == hello_reask + assert output.validated_output == {"a_field": "hello there"} + assert ( + guard.guard_state.all_histories[0].history[0].parsed_output["a_field"] + == hello_reask + ) # (Validator, on_fail) tuple reask @@ -302,8 +308,11 @@ class MyModel(BaseModel): num_reasks=0, ) - assert output == {"a_field": "hello there"} - assert guard.guard_state.all_histories[0].history[0].reasks[0] == hello_reask + assert output.validated_output == {"a_field": "hello there"} + assert ( + guard.guard_state.all_histories[0].history[0].parsed_output["a_field"] + == hello_reask + ) # Fail on string @@ -331,7 +340,7 @@ def test_custom_func_validator(): '{"greeting": "hello"}', num_reasks=0, ) - assert output == {"greeting": "hullo"} + assert output.validated_output == {"greeting": "hullo"} guard_history = guard.guard_state.all_histories[0].history assert len(guard_history) == 1 @@ -392,7 +401,7 @@ def test_provenance_v1(mocker): llm_output=LLM_RESPONSE, metadata={"query_function": mock_chromadb_query_function}, ) - assert output == LLM_RESPONSE + assert output.validated_output == LLM_RESPONSE # 2. Setting the environment variable os.environ["OPENAI_API_KEY"] = API_KEY @@ -400,7 +409,7 @@ def test_provenance_v1(mocker): llm_output=LLM_RESPONSE, metadata={"query_function": mock_chromadb_query_function}, ) - assert output == LLM_RESPONSE + assert output.validated_output == LLM_RESPONSE # 3. Passing the API key as an argument output = string_guard.parse( @@ -409,7 +418,7 @@ def test_provenance_v1(mocker): api_key=API_KEY, api_base="https://api.openai.com", ) - assert output == LLM_RESPONSE + assert output.validated_output == LLM_RESPONSE @pytest.mark.parametrize( From 5fa638a3ce9e7703986536b274dc1cd5f74a983a Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 10:09:51 -0500 Subject: [PATCH 09/52] lint fixes --- guardrails/classes/validation_outcome.py | 49 ++++++++++++++++++------ tests/integration_tests/test_guard.py | 9 ++++- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index d9878059c..712ed3fde 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -21,23 +21,40 @@ class ValidationOutcome(ArbitraryModel): ) @overload - def __init__ (self, raw_llm_output: str, validated_output: str, validation_passed: bool): + def __init__( + self, raw_llm_output: str, validated_output: str, validation_passed: bool + ): ... 
- + @overload - def __init__ (self, raw_llm_output: str, validated_output: Dict, validation_passed: bool): + def __init__( + self, raw_llm_output: str, validated_output: Dict, validation_passed: bool + ): ... - + @overload - def __init__ (self, raw_llm_output: str, validated_output: ReAsk, validation_passed: bool): + def __init__( + self, raw_llm_output: str, validated_output: ReAsk, validation_passed: bool + ): ... @overload - def __init__ (self, raw_llm_output: str, validated_output: None, validation_passed: bool): + def __init__( + self, raw_llm_output: str, validated_output: None, validation_passed: bool + ): ... - - def __init__ (self, raw_llm_output: str, validated_output: Union[str, Dict, ReAsk, None], validation_passed: bool): - super().__init__(raw_llm_output=raw_llm_output, validated_output=validated_output, validation_passed=validation_passed) + + def __init__( + self, + raw_llm_output: str, + validated_output: Union[str, Dict, ReAsk, None], + validation_passed: bool, + ): + super().__init__( + raw_llm_output=raw_llm_output, + validated_output=validated_output, + validation_passed=validation_passed, + ) self.raw_llm_output = raw_llm_output self.validated_output = validated_output self.validation_passed = validation_passed @@ -68,7 +85,12 @@ class TextOutcome(ValidationOutcome): " output from the LLM call after passing through validation." ) - def __init__ (self, raw_llm_output: str, validated_output: Union[str, ReAsk, None], validation_passed: bool): + def __init__( + self, + raw_llm_output: str, + validated_output: Union[str, ReAsk, None], + validation_passed: bool, + ): super().__init__(raw_llm_output, validated_output, validation_passed) @@ -78,5 +100,10 @@ class StructuredOutcome(ValidationOutcome): " output from the LLM call after passing through validation." ) - def __init__ (self, raw_llm_output: str, validated_output: Union[Dict, ReAsk, None], validation_passed: bool): + def __init__( + self, + raw_llm_output: str, + validated_output: Union[Dict, ReAsk, None], + validation_passed: bool, + ): super().__init__(raw_llm_output, validated_output, validation_passed) diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index acce77c4a..d459b961e 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -459,7 +459,10 @@ def test_skeleton_reask(mocker): ) # Assertions are made on the guard state object. 
- assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_SKELETON_REASK_2 + assert ( + final_output.validated_output + == entity_extraction.VALIDATED_OUTPUT_SKELETON_REASK_2 + ) guard_history = guard.guard_state.most_recent_call.history @@ -694,7 +697,9 @@ def test_pydantic_with_message_history_reask(mocker): ) assert final_output.raw_llm_output == pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT - assert final_output.validated_output == json.loads(pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT) + assert final_output.validated_output == json.loads( + pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT + ) guard_history = guard.guard_state.most_recent_call.history From c9f7b157811932f714f8e8c4e435acc194091604 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 10:11:06 -0500 Subject: [PATCH 10/52] lint fixes --- guardrails/guard.py | 14 ++------------ tests/integration_tests/test_multi_reask.py | 2 +- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 8d833db60..af4955f8d 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1,18 +1,8 @@ import asyncio import contextvars import logging -from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - Union, - overload, -) +from string import Formatter +from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union, overload from eliot import add_destinations, start_action from pydantic import BaseModel diff --git a/tests/integration_tests/test_multi_reask.py b/tests/integration_tests/test_multi_reask.py index e44997be7..2537d2b11 100644 --- a/tests/integration_tests/test_multi_reask.py +++ b/tests/integration_tests/test_multi_reask.py @@ -12,7 +12,7 @@ def test_multi_reask(mocker): guard = gd.Guard.from_rail_string(python_rail.RAIL_SPEC_WITH_VALIDATOR_PARALLELISM) - final_output = guard( + guard( llm_api=openai.Completion.create, engine="text-davinci-003", num_reasks=5, From 184891b91c08f711ad9c5fa36a617d8ac6b01b50 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 11:27:27 -0500 Subject: [PATCH 11/52] fix tests --- tests/unit_tests/test_prompt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py index 40f2d1c08..60f89b894 100644 --- a/tests/unit_tests/test_prompt.py +++ b/tests/unit_tests/test_prompt.py @@ -2,6 +2,8 @@ from string import Template from unittest import mock +from guardrails.prompt.instructions import Instructions +from guardrails.prompt.prompt import Prompt import pytest from pydantic import BaseModel, Field @@ -211,9 +213,7 @@ def test_reask_prompt(): def test_reask_instructions(): guard = gd.Guard.from_rail_string(RAIL_WITH_REASK_INSTRUCTIONS) - assert guard.output_schema._reask_instructions_template == Instructions( - INSTRUCTIONS - ) + assert guard.output_schema._reask_instructions_template == Instructions(INSTRUCTIONS) @pytest.mark.parametrize( From 9c985bbd8d27b750fdb44362d2a72e64dbf7b64c Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 11:31:11 -0500 Subject: [PATCH 12/52] lint fixes --- tests/unit_tests/test_prompt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py index 60f89b894..40f2d1c08 100644 --- a/tests/unit_tests/test_prompt.py +++ b/tests/unit_tests/test_prompt.py @@ -2,8 +2,6 @@ from string import Template from unittest import mock -from guardrails.prompt.instructions 
import Instructions -from guardrails.prompt.prompt import Prompt import pytest from pydantic import BaseModel, Field @@ -213,7 +211,9 @@ def test_reask_prompt(): def test_reask_instructions(): guard = gd.Guard.from_rail_string(RAIL_WITH_REASK_INSTRUCTIONS) - assert guard.output_schema._reask_instructions_template == Instructions(INSTRUCTIONS) + assert guard.output_schema._reask_instructions_template == Instructions( + INSTRUCTIONS + ) @pytest.mark.parametrize( From 57357fb00fa336d20d0a0480bf482155391b657e Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 14:33:16 -0500 Subject: [PATCH 13/52] switch to generics for ValidationOutcome --- guardrails/classes/__init__.py | 8 +-- guardrails/classes/validation_outcome.py | 91 ++++-------------------- guardrails/guard.py | 68 +++++++++++------- guardrails/rail.py | 9 ++- 4 files changed, 66 insertions(+), 110 deletions(-) diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py index 8a86a4dd9..656f27eda 100644 --- a/guardrails/classes/__init__.py +++ b/guardrails/classes/__init__.py @@ -1,11 +1,5 @@ -from guardrails.classes.validation_outcome import ( - StructuredOutcome, - TextOutcome, - ValidationOutcome, -) +from guardrails.classes.validation_outcome import ValidationOutcome __all__ = [ "ValidationOutcome", - "TextOutcome", - "StructuredOutcome", ] diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 712ed3fde..a56238c67 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -1,109 +1,46 @@ -from typing import Dict, Union, overload +from typing import Dict, Generic, Optional, TypeVar from pydantic import Field from guardrails.utils.logs_utils import ArbitraryModel, GuardHistory from guardrails.utils.reask_utils import ReAsk +T = TypeVar("T", str, Dict, None) -class ValidationOutcome(ArbitraryModel): + +class ValidationOutcome(Generic[T], ArbitraryModel): raw_llm_output: str = Field( description="The raw, unchanged output from the LLM call." ) - validated_output: Union[str, Dict, ReAsk, None] = Field( + validated_output: Optional[T] = Field( description="The validated, and potentially fixed," " output from the LLM call after passing through validation." ) + reask: Optional[ReAsk] = Field( + description="If validation continuously fails and all allocated" + " reasks are used, this field will contain the final reask that" + " would have been sent to the LLM if additional reasks were available." + ) validation_passed: bool = Field( description="A boolean to indicate whether or not" " the LLM output passed validation." " If this is False, the validated_output may be invalid." ) - @overload - def __init__( - self, raw_llm_output: str, validated_output: str, validation_passed: bool - ): - ... - - @overload - def __init__( - self, raw_llm_output: str, validated_output: Dict, validation_passed: bool - ): - ... - - @overload - def __init__( - self, raw_llm_output: str, validated_output: ReAsk, validation_passed: bool - ): - ... - - @overload - def __init__( - self, raw_llm_output: str, validated_output: None, validation_passed: bool - ): - ... 
- - def __init__( - self, - raw_llm_output: str, - validated_output: Union[str, Dict, ReAsk, None], - validation_passed: bool, - ): - super().__init__( - raw_llm_output=raw_llm_output, - validated_output=validated_output, - validation_passed=validation_passed, - ) - self.raw_llm_output = raw_llm_output - self.validated_output = validated_output - self.validation_passed = validation_passed - @classmethod def from_guard_history(cls, guard_history: GuardHistory): raw_output = guard_history.output validated_output = guard_history.validated_output any_validations_failed = len(guard_history.failed_validations) > 0 - if isinstance(validated_output, str): - return TextOutcome( + if isinstance(validated_output, ReAsk): + return cls[T]( raw_llm_output=raw_output, - validated_output=validated_output, + reask=validated_output, validation_passed=any_validations_failed, ) else: - # TODO: Why does instantiation collapse validated_output to a dict? - return StructuredOutcome( + return cls[T]( raw_llm_output=raw_output, validated_output=validated_output, validation_passed=any_validations_failed, ) - - -class TextOutcome(ValidationOutcome): - validated_output: Union[str, ReAsk, None] = Field( - description="The validated, and potentially fixed," - " output from the LLM call after passing through validation." - ) - - def __init__( - self, - raw_llm_output: str, - validated_output: Union[str, ReAsk, None], - validation_passed: bool, - ): - super().__init__(raw_llm_output, validated_output, validation_passed) - - -class StructuredOutcome(ValidationOutcome): - validated_output: Union[Dict, ReAsk, None] = Field( - description="The validated, and potentially fixed," - " output from the LLM call after passing through validation." - ) - - def __init__( - self, - raw_llm_output: str, - validated_output: Union[Dict, ReAsk, None], - validation_passed: bool, - ): - super().__init__(raw_llm_output, validated_output, validation_passed) diff --git a/guardrails/guard.py b/guardrails/guard.py index af4955f8d..ffd8ce888 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -2,7 +2,19 @@ import contextvars import logging from string import Formatter -from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union, overload +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, + overload, +) from eliot import add_destinations, start_action from pydantic import BaseModel @@ -22,8 +34,10 @@ actions_logger = logging.getLogger(f"{__name__}.actions") add_destinations(actions_logger.debug) +T = TypeVar("T", str, Dict, None) -class Guard: + +class Guard(Generic[T]): """The Guard class. This class is the main entry point for using Guardrails. It is @@ -138,7 +152,9 @@ def configure( ) @classmethod - def from_rail(cls, rail_file: str, num_reasks: Optional[int] = None) -> "Guard": + def from_rail( + cls, rail_file: str, num_reasks: Optional[int] = None + ) -> "Guard[Union[str, Dict]]": """Create a Schema from a `.rail` file. Args: @@ -148,12 +164,13 @@ def from_rail(cls, rail_file: str, num_reasks: Optional[int] = None) -> "Guard": Returns: An instance of the `Guard` class. """ - return cls(Rail.from_file(rail_file), num_reasks=num_reasks) + rail = Rail.from_file(rail_file) + return cls[rail.output_type](rail=rail, num_reasks=num_reasks) @classmethod def from_rail_string( cls, rail_string: str, num_reasks: Optional[int] = None - ) -> "Guard": + ) -> "Guard[Union[str, Dict]]": """Create a Schema from a `.rail` string. 
Args: @@ -163,7 +180,8 @@ def from_rail_string( Returns: An instance of the `Guard` class. """ - return cls(Rail.from_string(rail_string), num_reasks=num_reasks) + rail = Rail.from_string(rail_string) + return cls[rail.output_type](rail=rail, num_reasks=num_reasks) @classmethod def from_pydantic( @@ -172,12 +190,12 @@ def from_pydantic( prompt: Optional[str] = None, instructions: Optional[str] = None, num_reasks: Optional[int] = None, - ) -> "Guard": + ) -> "Guard[Dict]": """Create a Guard instance from a Pydantic model and prompt.""" rail = Rail.from_pydantic( output_class=output_class, prompt=prompt, instructions=instructions ) - return cls(rail, num_reasks=num_reasks, base_model=output_class) + return cls[Dict](rail, num_reasks=num_reasks, base_model=output_class) @classmethod def from_string( @@ -189,7 +207,7 @@ def from_string( reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, num_reasks: Optional[int] = None, - ) -> "Guard": + ) -> "Guard[str]": """Create a Guard instance for a string response with prompt, instructions, and validations. @@ -210,7 +228,7 @@ def from_string( reask_prompt=reask_prompt, reask_instructions=reask_instructions, ) - return cls(rail, num_reasks=num_reasks) + return cls[str](rail, num_reasks=num_reasks) @overload def __call__( @@ -225,7 +243,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[ValidationOutcome]: + ) -> Awaitable[ValidationOutcome[T]]: ... @overload @@ -241,7 +259,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome: + ) -> ValidationOutcome[T]: ... def __call__( @@ -256,7 +274,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Union[ValidationOutcome, Awaitable[ValidationOutcome]]: + ) -> Union[ValidationOutcome[T], Awaitable[ValidationOutcome[T]]]: """Call the LLM and validate the output. Pass an async LLM API to return a coroutine. @@ -334,7 +352,7 @@ def _call_sync( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome: + ) -> ValidationOutcome[T]: instructions_obj = instructions or self.instructions prompt_obj = prompt or self.prompt msg_history_obj = msg_history or [] @@ -362,7 +380,7 @@ def _call_sync( full_schema_reask=full_schema_reask, ) guard_history = runner(prompt_params=prompt_params) - return ValidationOutcome.from_guard_history(guard_history) + return ValidationOutcome[T].from_guard_history(guard_history) async def _call_async( self, @@ -376,7 +394,7 @@ async def _call_async( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome: + ) -> ValidationOutcome[T]: """Call the LLM asynchronously and validate the output. Args: @@ -421,7 +439,7 @@ async def _call_async( full_schema_reask=full_schema_reask, ) guard_history = await runner.async_run(prompt_params=prompt_params) - return ValidationOutcome.from_guard_history(guard_history) + return ValidationOutcome[T].from_guard_history(guard_history) def __repr__(self): return f"Guard(RAIL={self.rail})" @@ -440,7 +458,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome: + ) -> ValidationOutcome[T]: ... @overload @@ -454,7 +472,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[ValidationOutcome]: + ) -> Awaitable[ValidationOutcome[T]]: ... @overload @@ -468,7 +486,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome: + ) -> ValidationOutcome[T]: ... 
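The paired `@overload` declarations being edited above are what let a single entry point advertise a plain result for synchronous LLM APIs and an awaitable for asynchronous ones. A minimal sketch of the same technique, with made-up names:

```python
import asyncio
from typing import Any, Awaitable, Callable, Union, overload


@overload
def call(api: Callable[[], Awaitable[str]]) -> Awaitable[str]: ...


@overload
def call(api: Callable[[], str]) -> str: ...


def call(api: Callable[[], Any]) -> Union[str, Awaitable[str]]:
    # The implementation just forwards: a sync api returns a str here,
    # an async api returns a coroutine the caller must await.
    return api()


def sync_api() -> str:
    return "ok"


async def async_api() -> str:
    return "ok"


assert call(sync_api) == "ok"
assert asyncio.run(call(async_api)) == "ok"
```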
def parse( @@ -481,7 +499,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Union[ValidationOutcome, Awaitable[ValidationOutcome]]: + ) -> Union[ValidationOutcome[T], Awaitable[ValidationOutcome[T]]]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -549,7 +567,7 @@ def _sync_parse( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome: + ) -> ValidationOutcome[T]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -581,7 +599,7 @@ def _sync_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - validation_outcome = ValidationOutcome.from_guard_history(guard_history) + validation_outcome = ValidationOutcome[T].from_guard_history(guard_history) return validation_outcome async def _async_parse( @@ -594,7 +612,7 @@ async def _async_parse( full_schema_reask: bool, *args, **kwargs, - ) -> Any: + ) -> ValidationOutcome[T]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -626,4 +644,4 @@ async def _async_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - return ValidationOutcome.from_guard_history(guard_history) + return ValidationOutcome[T].from_guard_history(guard_history) diff --git a/guardrails/rail.py b/guardrails/rail.py index fcded5170..d7b1913b8 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -1,7 +1,7 @@ """Rail class.""" import warnings from dataclasses import dataclass -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type from lxml import etree as ET from lxml.etree import Element, SubElement @@ -40,6 +40,13 @@ class Rail: prompt: Optional[Prompt] version: str = "0.1" + @property + def output_type(self): + if isinstance(self.output_schema, StringSchema): + return str + else: + return Dict + @classmethod def from_pydantic( cls, From 521fe6ad410d76673609435697a6cc5034518f74 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 18 Sep 2023 15:44:12 -0500 Subject: [PATCH 14/52] allow destructuring --- guardrails/classes/validation_outcome.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index a56238c67..e1b91fa58 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -44,3 +44,15 @@ def from_guard_history(cls, guard_history: GuardHistory): validated_output=validated_output, validation_passed=any_validations_failed, ) + + def __iter__(self): + as_tuple = ( + self.raw_llm_output, + self.validated_output, + self.reask, + self.validation_passed, + ) + return iter(as_tuple) + + def __getitem__(self, keys): + return iter(getattr(self, k) for k in keys) From 8b6c71dbc34382ccd640cfae1727a1c7dd8fa0fd Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 20 Sep 2023 09:16:01 -0500 Subject: [PATCH 15/52] remove None from generic type --- guardrails/classes/validation_outcome.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index e1b91fa58..a13f22862 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -5,7 +5,7 @@ from guardrails.utils.logs_utils import ArbitraryModel, GuardHistory from guardrails.utils.reask_utils import ReAsk -T = TypeVar("T", str, Dict, None) +T = TypeVar("T", str, Dict) class 
ValidationOutcome(Generic[T], ArbitraryModel): From 86fd20ed4bac09824428678c6dceebd5ae2bc761 Mon Sep 17 00:00:00 2001 From: Nefertiti Rogers Date: Wed, 20 Sep 2023 17:00:36 -0700 Subject: [PATCH 16/52] init commit, changes to handle error in guard --- guardrails/classes/validation_outcome.py | 10 ++++++++++ guardrails/guard.py | 20 +++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index a13f22862..9ff1ac3ed 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -26,6 +26,7 @@ class ValidationOutcome(Generic[T], ArbitraryModel): " the LLM output passed validation." " If this is False, the validated_output may be invalid." ) + exception: Optional[str] = Field() @classmethod def from_guard_history(cls, guard_history: GuardHistory): @@ -44,6 +45,14 @@ def from_guard_history(cls, guard_history: GuardHistory): validated_output=validated_output, validation_passed=any_validations_failed, ) + + @classmethod + def from_exception(cls, error_message: str): + return cls[T]( + raw_llm_output='', + validation_passed=False, + exception=error_message + ) def __iter__(self): as_tuple = ( @@ -51,6 +60,7 @@ def __iter__(self): self.validated_output, self.reask, self.validation_passed, + self.exception, ) return iter(as_tuple) diff --git a/guardrails/guard.py b/guardrails/guard.py index ffd8ce888..52aea7770 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -230,7 +230,18 @@ def from_string( ) return cls[str](rail, num_reasks=num_reasks) + def handle_error(func): + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except Exception as e: + error_message = str(e) + return ValidationOutcome[T].from_exception(error_message) + + return wrapper + @overload + @handle_error def __call__( self, llm_api: Callable[[Any], Awaitable[Any]], @@ -247,6 +258,7 @@ def __call__( ... @overload + @handle_error def __call__( self, llm_api: Callable, @@ -262,6 +274,7 @@ def __call__( ) -> ValidationOutcome[T]: ... + @handle_error def __call__( self, llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], @@ -446,8 +459,9 @@ def __repr__(self): def __rich_repr__(self): yield "RAIL", self.rail - + @overload + @handle_error def parse( self, llm_output: str, @@ -462,6 +476,7 @@ def parse( ... @overload + @handle_error def parse( self, llm_output: str, @@ -476,6 +491,7 @@ def parse( ... @overload + @handle_error def parse( self, llm_output: str, @@ -489,6 +505,7 @@ def parse( ) -> ValidationOutcome[T]: ... 
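The `handle_error` decorator added in this commit follows a common shape: trap any exception and fold it into the returned value rather than raising. A self-contained sketch of that shape, using a `(result, error)` tuple in place of the real `ValidationOutcome`:

```python
import functools
from typing import Any, Callable, Optional, Tuple


def capture_errors(
    func: Callable[..., Any]
) -> Callable[..., Tuple[Optional[Any], Optional[str]]]:
    """Wrap func so it returns (result, None) or (None, error_message)."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any):
        try:
            return func(*args, **kwargs), None
        except Exception as e:
            # Surface the message to the caller instead of raising.
            return None, str(e)

    return wrapper


@capture_errors
def risky(n: int) -> int:
    if n < 0:
        raise ValueError("n must be non-negative")
    return n * 2


assert risky(2) == (4, None)
assert risky(-1) == (None, "n must be non-negative")
```

One subtlety: `@overload` stubs never execute at runtime — only the final implementation does — so of the decorators stacked onto the stubs here, only the one on the real `__call__` and `parse` bodies has any runtime effect.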
+ @handle_error def parse( self, llm_output: str, @@ -645,3 +662,4 @@ async def _async_parse( guard_history.validated_output ) return ValidationOutcome[T].from_guard_history(guard_history) + From ca55b4fb610eded5bde2f267f27d227006a12917 Mon Sep 17 00:00:00 2001 From: Nefertiti Rogers Date: Thu, 21 Sep 2023 12:45:21 -0700 Subject: [PATCH 17/52] handle error a layer deeper --- guardrails/applications/text2sql.py | 8 ++---- guardrails/classes/validation_outcome.py | 15 ++++++++--- guardrails/guard.py | 34 ++++++------------------ guardrails/run.py | 30 +++++++++++++++++++-- 4 files changed, 50 insertions(+), 37 deletions(-) diff --git a/guardrails/applications/text2sql.py b/guardrails/applications/text2sql.py index bafca2cf8..b7b9053a8 100644 --- a/guardrails/applications/text2sql.py +++ b/guardrails/applications/text2sql.py @@ -187,7 +187,7 @@ def __call__(self, text: str) -> Optional[str]: ) try: - output = self.guard( + return self.guard( self.llm_api, prompt_params={ "nl_instruction": text, @@ -195,11 +195,7 @@ def __call__(self, text: str) -> Optional[str]: "db_info": str(self.sql_schema), }, **self.llm_api_kwargs, - )[ # type: ignore - 1 - ][ - "generated_sql" - ] + ).validated_output["generated_sql"] except TypeError: output = None diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 9ff1ac3ed..4b00d24ca 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -29,22 +29,31 @@ class ValidationOutcome(Generic[T], ArbitraryModel): exception: Optional[str] = Field() @classmethod - def from_guard_history(cls, guard_history: GuardHistory): + def from_guard_history(cls, guard_history: GuardHistory, error_message: Optional[str]): raw_output = guard_history.output validated_output = guard_history.validated_output any_validations_failed = len(guard_history.failed_validations) > 0 - if isinstance(validated_output, ReAsk): + if(error_message): + return cls[T]( + raw_llm_output=raw_output or "", + validation_passed=False, + exception=error_message, + ) + elif isinstance(validated_output, ReAsk): return cls[T]( raw_llm_output=raw_output, reask=validated_output, validation_passed=any_validations_failed, ) else: - return cls[T]( + print("else") + result = cls[T]( raw_llm_output=raw_output, validated_output=validated_output, validation_passed=any_validations_failed, ) + print(result) + return result @classmethod def from_exception(cls, error_message: str): diff --git a/guardrails/guard.py b/guardrails/guard.py index 52aea7770..037d3592e 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -229,19 +229,8 @@ def from_string( reask_instructions=reask_instructions, ) return cls[str](rail, num_reasks=num_reasks) - - def handle_error(func): - def wrapper(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except Exception as e: - error_message = str(e) - return ValidationOutcome[T].from_exception(error_message) - - return wrapper @overload - @handle_error def __call__( self, llm_api: Callable[[Any], Awaitable[Any]], @@ -258,7 +247,6 @@ def __call__( ... @overload - @handle_error def __call__( self, llm_api: Callable, @@ -274,7 +262,6 @@ def __call__( ) -> ValidationOutcome[T]: ... 
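Because `__iter__` (added a few commits back and extended here with the exception field) yields the fields in a fixed order, callers can unpack an outcome like a tuple. A small usage sketch with illustrative values:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Outcome:
    raw_llm_output: str
    validated_output: Optional[Any] = None
    reask: Optional[Any] = None
    validation_passed: bool = False
    error: Optional[str] = None

    def __iter__(self):
        # A fixed field order is what makes tuple-style unpacking safe.
        return iter(
            (
                self.raw_llm_output,
                self.validated_output,
                self.reask,
                self.validation_passed,
                self.error,
            )
        )


outcome = Outcome(raw_llm_output="{}", validated_output={}, validation_passed=True)
raw, validated, reask, passed, error = outcome
assert passed and error is None
```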
- @handle_error def __call__( self, llm_api: Union[Callable, Callable[[Any], Awaitable[Any]]], @@ -392,8 +379,8 @@ def _call_sync( guard_state=self.guard_state, full_schema_reask=full_schema_reask, ) - guard_history = runner(prompt_params=prompt_params) - return ValidationOutcome[T].from_guard_history(guard_history) + guard_history, error_message = runner(prompt_params=prompt_params) + return ValidationOutcome[T].from_guard_history(guard_history, error_message) async def _call_async( self, @@ -451,8 +438,8 @@ async def _call_async( guard_state=self.guard_state, full_schema_reask=full_schema_reask, ) - guard_history = await runner.async_run(prompt_params=prompt_params) - return ValidationOutcome[T].from_guard_history(guard_history) + guard_history, error_message = await runner.async_run(prompt_params=prompt_params) + return ValidationOutcome[T].from_guard_history(guard_history, error_message) def __repr__(self): return f"Guard(RAIL={self.rail})" @@ -461,7 +448,6 @@ def __rich_repr__(self): yield "RAIL", self.rail @overload - @handle_error def parse( self, llm_output: str, @@ -476,7 +462,6 @@ def parse( ... @overload - @handle_error def parse( self, llm_output: str, @@ -491,7 +476,6 @@ def parse( ... @overload - @handle_error def parse( self, llm_output: str, @@ -504,8 +488,6 @@ def parse( **kwargs, ) -> ValidationOutcome[T]: ... - - @handle_error def parse( self, llm_output: str, @@ -612,11 +594,11 @@ def _sync_parse( guard_state=self.guard_state, full_schema_reask=full_schema_reask, ) - guard_history = runner(prompt_params=prompt_params) + guard_history, error_message = runner(prompt_params=prompt_params) guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - validation_outcome = ValidationOutcome[T].from_guard_history(guard_history) + validation_outcome = ValidationOutcome[T].from_guard_history(guard_history, error_message) return validation_outcome async def _async_parse( @@ -657,9 +639,9 @@ async def _async_parse( guard_state=self.guard_state, full_schema_reask=full_schema_reask, ) - guard_history = await runner.async_run(prompt_params=prompt_params) + guard_history, error_message = await runner.async_run(prompt_params=prompt_params) guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - return ValidationOutcome[T].from_guard_history(guard_history) + return ValidationOutcome[T].from_guard_history(guard_history, error_message) diff --git a/guardrails/run.py b/guardrails/run.py index d052327fe..fb5500726 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -106,7 +106,20 @@ def _reset_guard_history(self): self.guard_history = GuardHistory(history=[]) self.guard_state.push(self.guard_history) - def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory: + def handle_error(func): + def wrapper(self, *args, **kwargs): + error_message = None + try: + guard_history = func(self, *args, **kwargs) + except Exception as e: + error_message = str(e) + guard_history = self.guard_history + return guard_history, error_message + + return wrapper + + @handle_error + def __call__(self, prompt_params: Optional[Dict] = None) -> Tuple[GuardHistory, str]: """Execute the runner by repeatedly calling step until the reask budget is exhausted. 
@@ -500,7 +513,20 @@ def __init__( ) self.api: Optional[AsyncPromptCallableBase] = api - async def async_run(self, prompt_params: Optional[Dict] = None) -> GuardHistory: + def handle_error(func): + async def wrapper(self, *args, **kwargs): + error_message = None + try: + guard_history = await func(self, *args, **kwargs) + except Exception as e: + error_message = str(e) + guard_history = self.guard_history + return guard_history, error_message + + return wrapper + + @handle_error + async def async_run(self, prompt_params: Optional[Dict] = None) -> Tuple[GuardHistory, str]: """Execute the runner by repeatedly calling step until the reask budget is exhausted. From db984099f4423b1f5296268a960e73daca049b9e Mon Sep 17 00:00:00 2001 From: Nefertiti Rogers Date: Thu, 21 Sep 2023 13:51:40 -0700 Subject: [PATCH 18/52] update return in text2sql --- guardrails/applications/text2sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guardrails/applications/text2sql.py b/guardrails/applications/text2sql.py index b7b9053a8..ce93e0683 100644 --- a/guardrails/applications/text2sql.py +++ b/guardrails/applications/text2sql.py @@ -187,7 +187,7 @@ def __call__(self, text: str) -> Optional[str]: ) try: - return self.guard( + output = self.guard( self.llm_api, prompt_params={ "nl_instruction": text, From 2d9c694d2ad1922383813fb4821db604cbdcac2b Mon Sep 17 00:00:00 2001 From: Nefertiti Rogers Date: Thu, 21 Sep 2023 13:52:56 -0700 Subject: [PATCH 19/52] remove extra fx in validation outcome --- guardrails/classes/validation_outcome.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 4b00d24ca..b8697c675 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -54,14 +54,6 @@ def from_guard_history(cls, guard_history: GuardHistory, error_message: Optional ) print(result) return result - - @classmethod - def from_exception(cls, error_message: str): - return cls[T]( - raw_llm_output='', - validation_passed=False, - exception=error_message - ) def __iter__(self): as_tuple = ( From 0127b9f868c5d1e90b3419f894dc062c07b9e2d5 Mon Sep 17 00:00:00 2001 From: Nefertiti Rogers Date: Thu, 21 Sep 2023 13:58:51 -0700 Subject: [PATCH 20/52] use error instead of exception --- guardrails/classes/validation_outcome.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index b8697c675..2d18a9a1e 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -26,7 +26,7 @@ class ValidationOutcome(Generic[T], ArbitraryModel): " the LLM output passed validation." " If this is False, the validated_output may be invalid." 
) - exception: Optional[str] = Field() + error: Optional[str] = Field() @classmethod def from_guard_history(cls, guard_history: GuardHistory, error_message: Optional[str]): @@ -37,7 +37,7 @@ def from_guard_history(cls, guard_history: GuardHistory, error_message: Optional return cls[T]( raw_llm_output=raw_output or "", validation_passed=False, - exception=error_message, + error=error_message, ) elif isinstance(validated_output, ReAsk): return cls[T]( @@ -61,7 +61,7 @@ def __iter__(self): self.validated_output, self.reask, self.validation_passed, - self.exception, + self.error, ) return iter(as_tuple) From ec1231260e35674ee7335772671d4eac60d0d9cc Mon Sep 17 00:00:00 2001 From: Nefertiti Rogers Date: Thu, 21 Sep 2023 14:00:21 -0700 Subject: [PATCH 21/52] remove print statements plus lint --- guardrails/classes/validation_outcome.py | 11 +++++------ guardrails/guard.py | 18 ++++++++++++------ guardrails/run.py | 14 +++++++++----- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 2d18a9a1e..92b434547 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -29,11 +29,13 @@ class ValidationOutcome(Generic[T], ArbitraryModel): error: Optional[str] = Field() @classmethod - def from_guard_history(cls, guard_history: GuardHistory, error_message: Optional[str]): + def from_guard_history( + cls, guard_history: GuardHistory, error_message: Optional[str] + ): raw_output = guard_history.output validated_output = guard_history.validated_output any_validations_failed = len(guard_history.failed_validations) > 0 - if(error_message): + if error_message: return cls[T]( raw_llm_output=raw_output or "", validation_passed=False, @@ -46,14 +48,11 @@ def from_guard_history(cls, guard_history: GuardHistory, error_message: Optional validation_passed=any_validations_failed, ) else: - print("else") - result = cls[T]( + return cls[T]( raw_llm_output=raw_output, validated_output=validated_output, validation_passed=any_validations_failed, ) - print(result) - return result def __iter__(self): as_tuple = ( diff --git a/guardrails/guard.py b/guardrails/guard.py index 037d3592e..965c2aefa 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -229,7 +229,7 @@ def from_string( reask_instructions=reask_instructions, ) return cls[str](rail, num_reasks=num_reasks) - + @overload def __call__( self, @@ -438,7 +438,9 @@ async def _call_async( guard_state=self.guard_state, full_schema_reask=full_schema_reask, ) - guard_history, error_message = await runner.async_run(prompt_params=prompt_params) + guard_history, error_message = await runner.async_run( + prompt_params=prompt_params + ) return ValidationOutcome[T].from_guard_history(guard_history, error_message) def __repr__(self): @@ -446,7 +448,7 @@ def __repr__(self): def __rich_repr__(self): yield "RAIL", self.rail - + @overload def parse( self, @@ -488,6 +490,7 @@ def parse( **kwargs, ) -> ValidationOutcome[T]: ... 
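With `exception` renamed to `error`, consuming code presumably branches on that field before touching anything else. The sketch below shows one such policy; the field names mirror the class above, but `unwrap` and its behavior are illustrative only:

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Outcome:
    raw_llm_output: str = ""
    validated_output: Optional[Dict] = None
    validation_passed: bool = False
    error: Optional[str] = None


def unwrap(outcome: Outcome) -> Dict:
    # When `error` is set the other fields are not meaningful, so it
    # is checked first.
    if outcome.error is not None:
        raise RuntimeError(f"guard execution failed: {outcome.error}")
    if not outcome.validation_passed or outcome.validated_output is None:
        raise ValueError("output failed validation")
    return outcome.validated_output


assert unwrap(Outcome(validated_output={"a": 1}, validation_passed=True)) == {"a": 1}
```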
+ def parse( self, llm_output: str, @@ -598,7 +601,9 @@ def _sync_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - validation_outcome = ValidationOutcome[T].from_guard_history(guard_history, error_message) + validation_outcome = ValidationOutcome[T].from_guard_history( + guard_history, error_message + ) return validation_outcome async def _async_parse( @@ -639,9 +644,10 @@ async def _async_parse( guard_state=self.guard_state, full_schema_reask=full_schema_reask, ) - guard_history, error_message = await runner.async_run(prompt_params=prompt_params) + guard_history, error_message = await runner.async_run( + prompt_params=prompt_params + ) guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) return ValidationOutcome[T].from_guard_history(guard_history, error_message) - diff --git a/guardrails/run.py b/guardrails/run.py index fb5500726..b07f83318 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -106,7 +106,7 @@ def _reset_guard_history(self): self.guard_history = GuardHistory(history=[]) self.guard_state.push(self.guard_history) - def handle_error(func): + def handle_error(func): def wrapper(self, *args, **kwargs): error_message = None try: @@ -117,9 +117,11 @@ def wrapper(self, *args, **kwargs): return guard_history, error_message return wrapper - + @handle_error - def __call__(self, prompt_params: Optional[Dict] = None) -> Tuple[GuardHistory, str]: + def __call__( + self, prompt_params: Optional[Dict] = None + ) -> Tuple[GuardHistory, str]: """Execute the runner by repeatedly calling step until the reask budget is exhausted. @@ -513,7 +515,7 @@ def __init__( ) self.api: Optional[AsyncPromptCallableBase] = api - def handle_error(func): + def handle_error(func): async def wrapper(self, *args, **kwargs): error_message = None try: @@ -526,7 +528,9 @@ async def wrapper(self, *args, **kwargs): return wrapper @handle_error - async def async_run(self, prompt_params: Optional[Dict] = None) -> Tuple[GuardHistory, str]: + async def async_run( + self, prompt_params: Optional[Dict] = None + ) -> Tuple[GuardHistory, str]: """Execute the runner by repeatedly calling step until the reask budget is exhausted. 
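The async wrapper cleaned up above differs from its sync twin only in awaiting the wrapped call before trapping errors; without the `await`, exceptions raised inside the coroutine would escape the `try` block untouched. A reduced sketch under the same `(result, error_message)` convention:

```python
import asyncio
import functools


def capture_errors_async(func):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            # The await is the only difference from the sync wrapper:
            # it forces the coroutine to run inside this try block.
            return await func(*args, **kwargs), None
        except Exception as e:
            return None, str(e)

    return wrapper


@capture_errors_async
async def risky() -> str:
    raise ValueError("boom")


assert asyncio.run(risky()) == (None, "boom")
```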
From 3bd4b0ec3313572ff7387e221c1d447ad798b706 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Tue, 31 Oct 2023 16:31:37 -0500 Subject: [PATCH 22/52] fix type --- guardrails/cli.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/guardrails/cli.py b/guardrails/cli.py index 85df23421..6c246c4c8 100644 --- a/guardrails/cli.py +++ b/guardrails/cli.py @@ -1,4 +1,5 @@ import json +from typing import Dict, Union import typer @@ -12,7 +13,7 @@ def compile_rail(rail: str, out: str) -> None: raise NotImplementedError("Currently compiling rail is not supported.") -def validate_llm_output(rail: str, llm_output: str) -> dict: +def validate_llm_output(rail: str, llm_output: str) -> Union[str, Dict, None]: """Validate guardrails.yml file.""" guard = Guard.from_rail(rail) result = guard.parse(llm_output) From 82bb11ce1f3e7809a1f563a9e9b607072d6e5161 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 2 Nov 2023 12:43:36 -0500 Subject: [PATCH 23/52] fix typing while maintaining type hinting --- guardrails/classes/__init__.py | 2 + guardrails/classes/output_type.py | 3 ++ guardrails/classes/validation_outcome.py | 38 ++++++++------ guardrails/guard.py | 66 +++++++++++++----------- guardrails/rail.py | 4 +- 5 files changed, 64 insertions(+), 49 deletions(-) create mode 100644 guardrails/classes/output_type.py diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py index 656f27eda..7c040a4f4 100644 --- a/guardrails/classes/__init__.py +++ b/guardrails/classes/__init__.py @@ -1,5 +1,7 @@ from guardrails.classes.validation_outcome import ValidationOutcome +from guardrails.classes.output_type import OT __all__ = [ "ValidationOutcome", + "OT" ] diff --git a/guardrails/classes/output_type.py b/guardrails/classes/output_type.py new file mode 100644 index 000000000..1a8924ad2 --- /dev/null +++ b/guardrails/classes/output_type.py @@ -0,0 +1,3 @@ +from typing import Dict, TypeVar, Union + +OT = TypeVar("OT", str, Dict) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 92b434547..7f39ed6db 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -1,32 +1,34 @@ -from typing import Dict, Generic, Optional, TypeVar +from typing import Dict, Generic, Iterator, Optional, Tuple, Union, cast from pydantic import Field from guardrails.utils.logs_utils import ArbitraryModel, GuardHistory from guardrails.utils.reask_utils import ReAsk +from guardrails.classes.output_type import OT -T = TypeVar("T", str, Dict) - -class ValidationOutcome(Generic[T], ArbitraryModel): - raw_llm_output: str = Field( - description="The raw, unchanged output from the LLM call." +class ValidationOutcome(Generic[OT], ArbitraryModel): + raw_llm_output: Optional[str] = Field( + description="The raw, unchanged output from the LLM call.", + default=None ) - validated_output: Optional[T] = Field( + validated_output: Optional[OT] = Field( description="The validated, and potentially fixed," - " output from the LLM call after passing through validation." + " output from the LLM call after passing through validation.", + default=None ) reask: Optional[ReAsk] = Field( description="If validation continuously fails and all allocated" " reasks are used, this field will contain the final reask that" - " would have been sent to the LLM if additional reasks were available." 
+ " would have been sent to the LLM if additional reasks were available.", + default=None ) validation_passed: bool = Field( description="A boolean to indicate whether or not" " the LLM output passed validation." " If this is False, the validated_output may be invalid." ) - error: Optional[str] = Field() + error: Optional[str] = Field(default=None) @classmethod def from_guard_history( @@ -36,26 +38,28 @@ def from_guard_history( validated_output = guard_history.validated_output any_validations_failed = len(guard_history.failed_validations) > 0 if error_message: - return cls[T]( + return cls( raw_llm_output=raw_output or "", validation_passed=False, error=error_message, ) elif isinstance(validated_output, ReAsk): - return cls[T]( + reask: ReAsk = validated_output + return cls( raw_llm_output=raw_output, - reask=validated_output, + reask=reask, validation_passed=any_validations_failed, ) else: - return cls[T]( + output = cast(OT, validated_output) + return cls( raw_llm_output=raw_output, - validated_output=validated_output, + validated_output=output, validation_passed=any_validations_failed, ) - def __iter__(self): - as_tuple = ( + def __iter__(self) -> Iterator[Union[Optional[str], Optional[OT], Optional[ReAsk], bool, Optional[str]]]: + as_tuple: Tuple[Optional[str], Optional[OT], Optional[ReAsk], bool, Optional[str]] = ( self.raw_llm_output, self.validated_output, self.reask, diff --git a/guardrails/guard.py b/guardrails/guard.py index 965c2aefa..98bcd1a86 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -11,15 +11,15 @@ List, Optional, Type, - TypeVar, Union, + cast, overload, ) from eliot import add_destinations, start_action from pydantic import BaseModel -from guardrails.classes import ValidationOutcome +from guardrails.classes import ValidationOutcome, OT from guardrails.llm_providers import get_async_llm_ask, get_llm_ask from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail @@ -34,10 +34,8 @@ actions_logger = logging.getLogger(f"{__name__}.actions") add_destinations(actions_logger.debug) -T = TypeVar("T", str, Dict, None) - -class Guard(Generic[T]): +class Guard(Generic[OT]): """The Guard class. This class is the main entry point for using Guardrails. It is @@ -153,8 +151,10 @@ def configure( @classmethod def from_rail( - cls, rail_file: str, num_reasks: Optional[int] = None - ) -> "Guard[Union[str, Dict]]": + cls, + rail_file: str, + num_reasks: Optional[int] = None + ): """Create a Schema from a `.rail` file. Args: @@ -165,12 +165,16 @@ def from_rail( An instance of the `Guard` class. """ rail = Rail.from_file(rail_file) - return cls[rail.output_type](rail=rail, num_reasks=num_reasks) + if rail.output_type == 'str': + return cast(Guard[str], cls(rail=rail, num_reasks=num_reasks)) + return cast(Guard[Dict], cls(rail=rail, num_reasks=num_reasks)) @classmethod def from_rail_string( - cls, rail_string: str, num_reasks: Optional[int] = None - ) -> "Guard[Union[str, Dict]]": + cls, + rail_string: str, + num_reasks: Optional[int] = None + ): """Create a Schema from a `.rail` string. Args: @@ -181,7 +185,9 @@ def from_rail_string( An instance of the `Guard` class. 
""" rail = Rail.from_string(rail_string) - return cls[rail.output_type](rail=rail, num_reasks=num_reasks) + if rail.output_type == 'str': + return cast(Guard[str], cls(rail=rail, num_reasks=num_reasks)) + return cast(Guard[Dict], cls(rail=rail, num_reasks=num_reasks)) @classmethod def from_pydantic( @@ -190,12 +196,12 @@ def from_pydantic( prompt: Optional[str] = None, instructions: Optional[str] = None, num_reasks: Optional[int] = None, - ) -> "Guard[Dict]": + ): """Create a Guard instance from a Pydantic model and prompt.""" rail = Rail.from_pydantic( output_class=output_class, prompt=prompt, instructions=instructions ) - return cls[Dict](rail, num_reasks=num_reasks, base_model=output_class) + return cast(Guard[Dict], cls(rail, num_reasks=num_reasks, base_model=output_class)) @classmethod def from_string( @@ -207,7 +213,7 @@ def from_string( reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, num_reasks: Optional[int] = None, - ) -> "Guard[str]": + ): """Create a Guard instance for a string response with prompt, instructions, and validations. @@ -228,7 +234,7 @@ def from_string( reask_prompt=reask_prompt, reask_instructions=reask_instructions, ) - return cls[str](rail, num_reasks=num_reasks) + return cast(Guard[str], cls(rail, num_reasks=num_reasks)) @overload def __call__( @@ -243,7 +249,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[ValidationOutcome[T]]: + ) -> Awaitable[ValidationOutcome[OT]]: ... @overload @@ -259,7 +265,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: ... def __call__( @@ -274,7 +280,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Union[ValidationOutcome[T], Awaitable[ValidationOutcome[T]]]: + ) -> Union[ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]]]: """Call the LLM and validate the output. Pass an async LLM API to return a coroutine. @@ -352,7 +358,7 @@ def _call_sync( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: instructions_obj = instructions or self.instructions prompt_obj = prompt or self.prompt msg_history_obj = msg_history or [] @@ -380,7 +386,7 @@ def _call_sync( full_schema_reask=full_schema_reask, ) guard_history, error_message = runner(prompt_params=prompt_params) - return ValidationOutcome[T].from_guard_history(guard_history, error_message) + return ValidationOutcome[OT].from_guard_history(guard_history, error_message) async def _call_async( self, @@ -394,7 +400,7 @@ async def _call_async( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: """Call the LLM asynchronously and validate the output. Args: @@ -441,7 +447,7 @@ async def _call_async( guard_history, error_message = await runner.async_run( prompt_params=prompt_params ) - return ValidationOutcome[T].from_guard_history(guard_history, error_message) + return ValidationOutcome[OT].from_guard_history(guard_history, error_message) def __repr__(self): return f"Guard(RAIL={self.rail})" @@ -460,7 +466,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: ... @overload @@ -474,7 +480,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[ValidationOutcome[T]]: + ) -> Awaitable[ValidationOutcome[OT]]: ... 
@overload @@ -488,7 +494,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: ... def parse( @@ -501,7 +507,7 @@ def parse( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Union[ValidationOutcome[T], Awaitable[ValidationOutcome[T]]]: + ) -> Union[ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]]]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -569,7 +575,7 @@ def _sync_parse( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -601,7 +607,7 @@ def _sync_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - validation_outcome = ValidationOutcome[T].from_guard_history( + validation_outcome = ValidationOutcome[OT].from_guard_history( guard_history, error_message ) return validation_outcome @@ -616,7 +622,7 @@ async def _async_parse( full_schema_reask: bool, *args, **kwargs, - ) -> ValidationOutcome[T]: + ) -> ValidationOutcome[OT]: """Alternate flow to using Guard where the llm_output is known. Args: @@ -650,4 +656,4 @@ async def _async_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - return ValidationOutcome[T].from_guard_history(guard_history, error_message) + return ValidationOutcome[OT].from_guard_history(guard_history, error_message) diff --git a/guardrails/rail.py b/guardrails/rail.py index d7b1913b8..6cfa9ec5c 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -43,9 +43,9 @@ class Rail: @property def output_type(self): if isinstance(self.output_schema, StringSchema): - return str + return 'str' else: - return Dict + return 'dict' @classmethod def from_pydantic( From 7ac019e85a567ccc64acff9585dc5a186050525b Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 2 Nov 2023 13:42:53 -0500 Subject: [PATCH 24/52] fix other type issues --- guardrails/applications/text2sql.py | 36 ++-- guardrails/guard.py | 8 +- guardrails/run.py | 250 +++++++++++++--------------- guardrails/validators.py | 16 +- 4 files changed, 148 insertions(+), 162 deletions(-) diff --git a/guardrails/applications/text2sql.py b/guardrails/applications/text2sql.py index ce93e0683..71e53df75 100644 --- a/guardrails/applications/text2sql.py +++ b/guardrails/applications/text2sql.py @@ -2,7 +2,7 @@ import json import os from string import Template -from typing import Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Type, cast import openai @@ -90,7 +90,7 @@ def __init__( """ self.example_formatter = example_formatter - self.llm_api = llm_api + self.llm_api: Callable = llm_api self.llm_api_kwargs = llm_api_kwargs or {"max_tokens": 512} # Initialize the SQL driver. @@ -185,18 +185,20 @@ def __call__(self, text: str) -> Optional[str]: "Async API is not supported in Text2SQL application. " "Please use a synchronous API." 
) - - try: - output = self.guard( - self.llm_api, - prompt_params={ - "nl_instruction": text, - "examples": similar_examples_prompt, - "db_info": str(self.sql_schema), - }, - **self.llm_api_kwargs, - ).validated_output["generated_sql"] - except TypeError: - output = None - - return output + else: + try: + response = self.guard( + self.llm_api, + prompt_params={ + "nl_instruction": text, + "examples": similar_examples_prompt, + "db_info": str(self.sql_schema), + }, + **self.llm_api_kwargs, + ) + validated_output: Dict = cast(Dict, response.validated_output) + output = validated_output["generated_sql"] + except TypeError: + output = None + + return output diff --git a/guardrails/guard.py b/guardrails/guard.py index 98bcd1a86..363740177 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -239,7 +239,7 @@ def from_string( @overload def __call__( self, - llm_api: Callable[[Any], Awaitable[Any]], + llm_api: Callable, prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = None, prompt: Optional[str] = None, @@ -249,13 +249,13 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> Awaitable[ValidationOutcome[OT]]: + ) -> ValidationOutcome[OT]: ... @overload def __call__( self, - llm_api: Callable, + llm_api: Callable[[Any], Awaitable[Any]], prompt_params: Optional[Dict] = None, num_reasks: Optional[int] = None, prompt: Optional[str] = None, @@ -265,7 +265,7 @@ def __call__( full_schema_reask: Optional[bool] = None, *args, **kwargs, - ) -> ValidationOutcome[OT]: + ) -> Awaitable[ValidationOutcome[OT]]: ... def __call__( diff --git a/guardrails/run.py b/guardrails/run.py index b07f83318..c2be036b9 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -22,7 +22,6 @@ actions_logger = logging.getLogger(f"{__name__}.actions") add_destinations(actions_logger.debug) - class Runner: """Runner class that calls an LLM API with a prompt, and performs input and output validation. @@ -106,22 +105,9 @@ def _reset_guard_history(self): self.guard_history = GuardHistory(history=[]) self.guard_state.push(self.guard_history) - def handle_error(func): - def wrapper(self, *args, **kwargs): - error_message = None - try: - guard_history = func(self, *args, **kwargs) - except Exception as e: - error_message = str(e) - guard_history = self.guard_history - return guard_history, error_message - - return wrapper - - @handle_error def __call__( self, prompt_params: Optional[Dict] = None - ) -> Tuple[GuardHistory, str]: + ) -> Tuple[GuardHistory, Optional[str]]: """Execute the runner by repeatedly calling step until the reask budget is exhausted. @@ -132,69 +118,72 @@ def __call__( Returns: The guard history. """ - if prompt_params is None: - prompt_params = {} + error_message = None + try: + if prompt_params is None: + prompt_params = {} - # check if validator requirements are fulfilled - missing_keys = verify_metadata_requirements( - self.metadata, self.output_schema.root_datatype - ) - if missing_keys: - raise ValueError( - f"Missing required metadata keys: {', '.join(missing_keys)}" + # check if validator requirements are fulfilled + missing_keys = verify_metadata_requirements( + self.metadata, self.output_schema.root_datatype ) + if missing_keys: + raise ValueError( + f"Missing required metadata keys: {', '.join(missing_keys)}" + ) - self._reset_guard_history() + self._reset_guard_history() - # Figure out if we need to include instructions in the prompt. 
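The `cast`-then-index pattern in `text2sql` leans on the surrounding `try/except TypeError` for the cases the cast papers over (a `None` or non-dict payload). In isolation — `generated_sql` is the application's schema field, the rest is illustrative:

```python
from typing import Dict, Optional, Union, cast


def extract_sql(validated: Optional[Union[str, Dict]]) -> Optional[str]:
    try:
        # The application "knows" this guard yields a dict; cast records
        # that assumption for the type checker but changes nothing at
        # runtime.
        output: Dict = cast(Dict, validated)
        return output["generated_sql"]
    except TypeError:
        # None (or an unexpected plain string) lands here, matching the
        # application's existing fallback.
        return None


assert extract_sql({"generated_sql": "SELECT 1"}) == "SELECT 1"
assert extract_sql(None) is None
```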
- include_instructions = not ( - self.instructions is None and self.msg_history is None - ) - - with start_action( - action_type="run", - instructions=self.instructions, - prompt=self.prompt, - api=self.api, - input_schema=self.input_schema, - output_schema=self.output_schema, - num_reasks=self.num_reasks, - metadata=self.metadata, - ): - instructions, prompt, msg_history, input_schema, output_schema = ( - self.instructions, - self.prompt, - self.msg_history, - self.input_schema, - self.output_schema, + # Figure out if we need to include instructions in the prompt. + include_instructions = not ( + self.instructions is None and self.msg_history is None ) - for index in range(self.num_reasks + 1): - # Run a single step. - validated_output, reasks = self.step( - index=index, - api=self.api, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, - prompt_params=prompt_params, - input_schema=input_schema, - output_schema=output_schema, - output=self.output if index == 0 else None, - ) - # Loop again? - if not self.do_loop(index, reasks): - break - # Get new prompt and output schema. - prompt, instructions, output_schema, msg_history = self.prepare_to_loop( - reasks, - validated_output, - output_schema, - prompt_params=prompt_params, - include_instructions=include_instructions, + with start_action( + action_type="run", + instructions=self.instructions, + prompt=self.prompt, + api=self.api, + input_schema=self.input_schema, + output_schema=self.output_schema, + num_reasks=self.num_reasks, + metadata=self.metadata, + ): + instructions, prompt, msg_history, input_schema, output_schema = ( + self.instructions, + self.prompt, + self.msg_history, + self.input_schema, + self.output_schema, ) + for index in range(self.num_reasks + 1): + # Run a single step. + validated_output, reasks = self.step( + index=index, + api=self.api, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + input_schema=input_schema, + output_schema=output_schema, + output=self.output if index == 0 else None, + ) - return self.guard_history + # Loop again? + if not self.do_loop(index, reasks): + break + # Get new prompt and output schema. + prompt, instructions, output_schema, msg_history = self.prepare_to_loop( + reasks, + validated_output, + output_schema, + prompt_params=prompt_params, + include_instructions=include_instructions, + ) + except Exception as e: + error_message = str(e) + return self.guard_history, error_message def step( self, @@ -476,7 +465,6 @@ def prepare_to_loop( msg_history = None # clear msg history for reasking return prompt, instructions, output_schema, msg_history - class AsyncRunner(Runner): def __init__( self, @@ -515,22 +503,9 @@ def __init__( ) self.api: Optional[AsyncPromptCallableBase] = api - def handle_error(func): - async def wrapper(self, *args, **kwargs): - error_message = None - try: - guard_history = await func(self, *args, **kwargs) - except Exception as e: - error_message = str(e) - guard_history = self.guard_history - return guard_history, error_message - - return wrapper - - @handle_error async def async_run( self, prompt_params: Optional[Dict] = None - ) -> Tuple[GuardHistory, str]: + ) -> Tuple[GuardHistory, Optional[str]]: """Execute the runner by repeatedly calling step until the reask budget is exhausted. @@ -541,62 +516,67 @@ async def async_run( Returns: The guard history. 
""" - if prompt_params is None: - prompt_params = {} - self._reset_guard_history() - - # check if validator requirements are fulfilled - missing_keys = verify_metadata_requirements( - self.metadata, self.output_schema.root_datatype - ) - if missing_keys: - raise ValueError( - f"Missing required metadata keys: {', '.join(missing_keys)}" - ) + error_message = None + try: + if prompt_params is None: + prompt_params = {} + self._reset_guard_history() - with start_action( - action_type="run", - instructions=self.instructions, - prompt=self.prompt, - api=self.api, - input_schema=self.input_schema, - output_schema=self.output_schema, - num_reasks=self.num_reasks, - metadata=self.metadata, - ): - instructions, prompt, msg_history, input_schema, output_schema = ( - self.instructions, - self.prompt, - self.msg_history, - self.input_schema, - self.output_schema, + # check if validator requirements are fulfilled + missing_keys = verify_metadata_requirements( + self.metadata, self.output_schema.root_datatype ) - for index in range(self.num_reasks + 1): - # Run a single step. - validated_output, reasks = await self.async_step( - index=index, - api=self.api, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, - prompt_params=prompt_params, - input_schema=input_schema, - output_schema=output_schema, - output=self.output if index == 0 else None, + if missing_keys: + raise ValueError( + f"Missing required metadata keys: {', '.join(missing_keys)}" ) - # Loop again? - if not self.do_loop(index, reasks): - break - # Get new prompt and output schema. - prompt, instructions, output_schema, msg_history = self.prepare_to_loop( - reasks, - validated_output, - output_schema, - prompt_params=prompt_params, + with start_action( + action_type="run", + instructions=self.instructions, + prompt=self.prompt, + api=self.api, + input_schema=self.input_schema, + output_schema=self.output_schema, + num_reasks=self.num_reasks, + metadata=self.metadata, + ): + instructions, prompt, msg_history, input_schema, output_schema = ( + self.instructions, + self.prompt, + self.msg_history, + self.input_schema, + self.output_schema, ) + for index in range(self.num_reasks + 1): + # Run a single step. + validated_output, reasks = await self.async_step( + index=index, + api=self.api, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + input_schema=input_schema, + output_schema=output_schema, + output=self.output if index == 0 else None, + ) + + # Loop again? + if not self.do_loop(index, reasks): + break + # Get new prompt and output schema. 
+ prompt, instructions, output_schema, msg_history = self.prepare_to_loop( + reasks, + validated_output, + output_schema, + prompt_params=prompt_params, + ) - return self.guard_history + except Exception as e: + error_message = str(e) + + return self.guard_history, error_message async def async_step( self, diff --git a/guardrails/validators.py b/guardrails/validators.py index 506516af0..87e4742f9 100644 --- a/guardrails/validators.py +++ b/guardrails/validators.py @@ -13,7 +13,7 @@ import string import warnings from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import openai import rstr @@ -1338,6 +1338,7 @@ def _get_topics(self, text: str, topics: Optional[List[str]] = None) -> List[str guard = Guard.from_rail_string(spec) _, validated_output = guard(llm_api=self.llm_callable) + validated_output = cast(Dict, validated_output) return validated_output["topics"] def validate(self, value: Any, metadata: Dict) -> ValidationResult: @@ -1395,7 +1396,7 @@ def __init__( llm_callable if llm_callable else openai.ChatCompletion.create ) - def _selfeval(self, question: str, answer: str): + def _selfeval(self, question: str, answer: str) -> Dict: from guardrails import Guard spec = """ @@ -1416,13 +1417,15 @@ def _selfeval(self, question: str, answer: str): question=question, answer=answer, ) - guard = Guard.from_rail_string(spec) + guard = Guard[Dict].from_rail_string(spec) - return guard( + response = guard( self.llm_callable, max_tokens=10, temperature=0.1, - )[1] + ) + validated_output = cast(Dict, response.validated_output) + return validated_output def validate(self, value: Any, metadata: Dict) -> ValidationResult: if "question" not in metadata: @@ -1432,7 +1435,8 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult: question = metadata["question"] - relevant = self._selfeval(question, value)["relevant"] + self_evaluation: Dict = self._selfeval(question, value) + relevant = self_evaluation["relevant"] if relevant: return PassResult() From 0b6c73e258f3c4e6028f0984f6b2a8f9cd9fd93b Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 2 Nov 2023 13:46:09 -0500 Subject: [PATCH 25/52] autoformat --- guardrails/classes/__init__.py | 7 ++--- guardrails/classes/validation_outcome.py | 19 ++++++++----- guardrails/guard.py | 34 ++++++++++++------------ guardrails/rail.py | 4 +-- guardrails/run.py | 18 ++++++++++--- 5 files changed, 48 insertions(+), 34 deletions(-) diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py index 7c040a4f4..4cf6f9680 100644 --- a/guardrails/classes/__init__.py +++ b/guardrails/classes/__init__.py @@ -1,7 +1,4 @@ -from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.classes.output_type import OT +from guardrails.classes.validation_outcome import ValidationOutcome -__all__ = [ - "ValidationOutcome", - "OT" -] +__all__ = ["ValidationOutcome", "OT"] diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 7f39ed6db..9c492f5dc 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -2,26 +2,25 @@ from pydantic import Field +from guardrails.classes.output_type import OT from guardrails.utils.logs_utils import ArbitraryModel, GuardHistory from guardrails.utils.reask_utils import ReAsk -from guardrails.classes.output_type import OT class ValidationOutcome(Generic[OT], ArbitraryModel): raw_llm_output: 
Optional[str] = Field( - description="The raw, unchanged output from the LLM call.", - default=None + description="The raw, unchanged output from the LLM call.", default=None ) validated_output: Optional[OT] = Field( description="The validated, and potentially fixed," " output from the LLM call after passing through validation.", - default=None + default=None, ) reask: Optional[ReAsk] = Field( description="If validation continuously fails and all allocated" " reasks are used, this field will contain the final reask that" " would have been sent to the LLM if additional reasks were available.", - default=None + default=None, ) validation_passed: bool = Field( description="A boolean to indicate whether or not" @@ -58,8 +57,14 @@ def from_guard_history( validation_passed=any_validations_failed, ) - def __iter__(self) -> Iterator[Union[Optional[str], Optional[OT], Optional[ReAsk], bool, Optional[str]]]: - as_tuple: Tuple[Optional[str], Optional[OT], Optional[ReAsk], bool, Optional[str]] = ( + def __iter__( + self, + ) -> Iterator[ + Union[Optional[str], Optional[OT], Optional[ReAsk], bool, Optional[str]] + ]: + as_tuple: Tuple[ + Optional[str], Optional[OT], Optional[ReAsk], bool, Optional[str] + ] = ( self.raw_llm_output, self.validated_output, self.reask, diff --git a/guardrails/guard.py b/guardrails/guard.py index 363740177..c17a6376e 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -19,7 +19,7 @@ from eliot import add_destinations, start_action from pydantic import BaseModel -from guardrails.classes import ValidationOutcome, OT +from guardrails.classes import OT, ValidationOutcome from guardrails.llm_providers import get_async_llm_ask, get_llm_ask from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail @@ -150,11 +150,7 @@ def configure( ) @classmethod - def from_rail( - cls, - rail_file: str, - num_reasks: Optional[int] = None - ): + def from_rail(cls, rail_file: str, num_reasks: Optional[int] = None): """Create a Schema from a `.rail` file. Args: @@ -165,16 +161,12 @@ def from_rail( An instance of the `Guard` class. """ rail = Rail.from_file(rail_file) - if rail.output_type == 'str': + if rail.output_type == "str": return cast(Guard[str], cls(rail=rail, num_reasks=num_reasks)) return cast(Guard[Dict], cls(rail=rail, num_reasks=num_reasks)) @classmethod - def from_rail_string( - cls, - rail_string: str, - num_reasks: Optional[int] = None - ): + def from_rail_string(cls, rail_string: str, num_reasks: Optional[int] = None): """Create a Schema from a `.rail` string. Args: @@ -185,7 +177,7 @@ def from_rail_string( An instance of the `Guard` class. 
""" rail = Rail.from_string(rail_string) - if rail.output_type == 'str': + if rail.output_type == "str": return cast(Guard[str], cls(rail=rail, num_reasks=num_reasks)) return cast(Guard[Dict], cls(rail=rail, num_reasks=num_reasks)) @@ -201,7 +193,9 @@ def from_pydantic( rail = Rail.from_pydantic( output_class=output_class, prompt=prompt, instructions=instructions ) - return cast(Guard[Dict], cls(rail, num_reasks=num_reasks, base_model=output_class)) + return cast( + Guard[Dict], cls(rail, num_reasks=num_reasks, base_model=output_class) + ) @classmethod def from_string( @@ -386,7 +380,9 @@ def _call_sync( full_schema_reask=full_schema_reask, ) guard_history, error_message = runner(prompt_params=prompt_params) - return ValidationOutcome[OT].from_guard_history(guard_history, error_message) + return ValidationOutcome[OT].from_guard_history( + guard_history, error_message + ) async def _call_async( self, @@ -447,7 +443,9 @@ async def _call_async( guard_history, error_message = await runner.async_run( prompt_params=prompt_params ) - return ValidationOutcome[OT].from_guard_history(guard_history, error_message) + return ValidationOutcome[OT].from_guard_history( + guard_history, error_message + ) def __repr__(self): return f"Guard(RAIL={self.rail})" @@ -656,4 +654,6 @@ async def _async_parse( guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( guard_history.validated_output ) - return ValidationOutcome[OT].from_guard_history(guard_history, error_message) + return ValidationOutcome[OT].from_guard_history( + guard_history, error_message + ) diff --git a/guardrails/rail.py b/guardrails/rail.py index 6cfa9ec5c..9edfb7c7d 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -43,9 +43,9 @@ class Rail: @property def output_type(self): if isinstance(self.output_schema, StringSchema): - return 'str' + return "str" else: - return 'dict' + return "dict" @classmethod def from_pydantic( diff --git a/guardrails/run.py b/guardrails/run.py index c2be036b9..80d4485b4 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -22,6 +22,7 @@ actions_logger = logging.getLogger(f"{__name__}.actions") add_destinations(actions_logger.debug) + class Runner: """Runner class that calls an LLM API with a prompt, and performs input and output validation. @@ -174,7 +175,12 @@ def __call__( if not self.do_loop(index, reasks): break # Get new prompt and output schema. - prompt, instructions, output_schema, msg_history = self.prepare_to_loop( + ( + prompt, + instructions, + output_schema, + msg_history, + ) = self.prepare_to_loop( reasks, validated_output, output_schema, @@ -465,6 +471,7 @@ def prepare_to_loop( msg_history = None # clear msg history for reasking return prompt, instructions, output_schema, msg_history + class AsyncRunner(Runner): def __init__( self, @@ -566,7 +573,12 @@ async def async_run( if not self.do_loop(index, reasks): break # Get new prompt and output schema. 
- prompt, instructions, output_schema, msg_history = self.prepare_to_loop( + ( + prompt, + instructions, + output_schema, + msg_history, + ) = self.prepare_to_loop( reasks, validated_output, output_schema, @@ -575,7 +587,7 @@ async def async_run( except Exception as e: error_message = str(e) - + return self.guard_history, error_message async def async_step( From 14038b449359a8ea8e66ae14eb07bef803c53ea0 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 2 Nov 2023 13:48:07 -0500 Subject: [PATCH 26/52] lint fixes --- guardrails/applications/text2sql.py | 2 +- guardrails/classes/output_type.py | 2 +- guardrails/classes/validation_outcome.py | 2 +- guardrails/rail.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/guardrails/applications/text2sql.py b/guardrails/applications/text2sql.py index 71e53df75..a1b9469cb 100644 --- a/guardrails/applications/text2sql.py +++ b/guardrails/applications/text2sql.py @@ -2,7 +2,7 @@ import json import os from string import Template -from typing import Any, Callable, Dict, Optional, Type, cast +from typing import Callable, Dict, Optional, Type, cast import openai diff --git a/guardrails/classes/output_type.py b/guardrails/classes/output_type.py index 1a8924ad2..8c1c26135 100644 --- a/guardrails/classes/output_type.py +++ b/guardrails/classes/output_type.py @@ -1,3 +1,3 @@ -from typing import Dict, TypeVar, Union +from typing import Dict, TypeVar OT = TypeVar("OT", str, Dict) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 9c492f5dc..de51ed7fb 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -1,4 +1,4 @@ -from typing import Dict, Generic, Iterator, Optional, Tuple, Union, cast +from typing import Generic, Iterator, Optional, Tuple, Union, cast from pydantic import Field diff --git a/guardrails/rail.py b/guardrails/rail.py index 9edfb7c7d..fe9a42239 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -1,7 +1,7 @@ """Rail class.""" import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Type +from typing import List, Optional, Type from lxml import etree as ET from lxml.etree import Element, SubElement From eccabb08ebba51e407bca0539807a88de44dd60b Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 2 Nov 2023 15:56:18 -0500 Subject: [PATCH 27/52] test fixes --- guardrails/classes/__init__.py | 3 +- guardrails/classes/list_plus_plus.py | 11 +++ guardrails/guard.py | 17 ++-- guardrails/run.py | 5 +- guardrails/utils/logs_utils.py | 43 +++++++-- tests/__init__.py | 0 .../integration_tests/test_assets/__init__.py | 0 .../test_assets/fixtures/__init__.py | 69 ++++++++++++++ tests/integration_tests/test_async.py | 6 +- tests/integration_tests/test_datatypes.py | 10 +- tests/integration_tests/test_guard.py | 6 +- tests/integration_tests/test_pydantic.py | 2 +- tests/integration_tests/test_run.py | 1 + .../test_validator_type_deprecation.py | 8 +- tests/integration_tests/test_validators.py | 74 +++++++-------- tests/unit_tests/test_guard.py | 91 ++++++++++--------- tests/unit_tests/test_validator_suite.py | 6 +- tests/unit_tests/test_validators.py | 6 +- tests/unit_tests/validators/__init__.py | 0 19 files changed, 238 insertions(+), 120 deletions(-) create mode 100644 guardrails/classes/list_plus_plus.py create mode 100644 tests/__init__.py create mode 100644 tests/integration_tests/test_assets/__init__.py create mode 100644 tests/integration_tests/test_assets/fixtures/__init__.py create 
mode 100644 tests/unit_tests/validators/__init__.py diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py index 4cf6f9680..db686adbe 100644 --- a/guardrails/classes/__init__.py +++ b/guardrails/classes/__init__.py @@ -1,4 +1,5 @@ +from guardrails.classes.list_plus_plus import ListPlusPlus from guardrails.classes.output_type import OT from guardrails.classes.validation_outcome import ValidationOutcome -__all__ = ["ValidationOutcome", "OT"] +__all__ = ["ListPlusPlus", "ValidationOutcome", "OT"] diff --git a/guardrails/classes/list_plus_plus.py b/guardrails/classes/list_plus_plus.py new file mode 100644 index 000000000..c244b162b --- /dev/null +++ b/guardrails/classes/list_plus_plus.py @@ -0,0 +1,11 @@ +class ListPlusPlus(list): + def __init__(self, *args): + list.__init__(self, args) + + def at(self, index: int): + value = None + try: + value = self[index] + except IndexError: + pass + return value \ No newline at end of file diff --git a/guardrails/guard.py b/guardrails/guard.py index c17a6376e..1aa6aad39 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from guardrails.classes import OT, ValidationOutcome +from guardrails.classes.list_plus_plus import ListPlusPlus from guardrails.llm_providers import get_async_llm_ask, get_llm_ask from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail @@ -61,7 +62,7 @@ def __init__( """Initialize the Guard.""" self.rail = rail self.num_reasks = num_reasks - self.guard_state = GuardState(all_histories=[]) + self.guard_state = GuardState(all_histories=ListPlusPlus()) self._reask_prompt = None self._reask_instructions = None self.base_model = base_model @@ -602,9 +603,10 @@ def _sync_parse( full_schema_reask=full_schema_reask, ) guard_history, error_message = runner(prompt_params=prompt_params) - guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( - guard_history.validated_output - ) + if (len(guard_history.history) > 0): + guard_history.history.at(-1).validated_output = sub_reasks_with_fixed_values( + guard_history.validated_output + ) validation_outcome = ValidationOutcome[OT].from_guard_history( guard_history, error_message ) @@ -651,9 +653,10 @@ async def _async_parse( guard_history, error_message = await runner.async_run( prompt_params=prompt_params ) - guard_history.history[-1].validated_output = sub_reasks_with_fixed_values( - guard_history.validated_output - ) + if (len(guard_history.history) > 0): + guard_history.history.at(-1).validated_output = sub_reasks_with_fixed_values( + guard_history.validated_output + ) return ValidationOutcome[OT].from_guard_history( guard_history, error_message ) diff --git a/guardrails/run.py b/guardrails/run.py index 80d4485b4..204b7830c 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -4,6 +4,7 @@ from eliot import add_destinations, start_action from pydantic import BaseModel +from guardrails.classes import ListPlusPlus from guardrails.datatypes import verify_metadata_requirements from guardrails.llm_providers import AsyncPromptCallableBase, PromptCallableBase @@ -97,13 +98,13 @@ def __init__( self.output = output self.reask_prompt = reask_prompt self.reask_instructions = reask_instructions - self.guard_history = guard_history or GuardHistory(history=[]) + self.guard_history = guard_history or GuardHistory(history=ListPlusPlus()) self.base_model = base_model self.full_schema_reask = full_schema_reask def _reset_guard_history(self): """Reset the guard history.""" - 
self.guard_history = GuardHistory(history=[])
+        self.guard_history = GuardHistory(history=ListPlusPlus())
         self.guard_state.push(self.guard_history)
 
     def __call__(
diff --git a/guardrails/utils/logs_utils.py b/guardrails/utils/logs_utils.py
index 3ea66fe3f..8b38b68d4 100644
--- a/guardrails/utils/logs_utils.py
+++ b/guardrails/utils/logs_utils.py
@@ -7,6 +7,7 @@
 from rich.pretty import pretty_repr
 from rich.table import Table
 from rich.tree import Tree
+from guardrails.classes.list_plus_plus import ListPlusPlus
 
 from guardrails.prompt import Instructions, Prompt
 from guardrails.utils.reask_utils import (
@@ -136,14 +137,18 @@ def create_msg_history_table(
 
 
 class GuardHistory(ArbitraryModel):
-    history: List[GuardLogs]
+    history: ListPlusPlus[GuardLogs] = Field(default_factory=ListPlusPlus)
+
+    def __init__(self, history: ListPlusPlus[GuardLogs]):
+        super().__init__()
+        self.history = ListPlusPlus(*history)
 
     def push(self, guard_log: GuardLogs) -> None:
         if len(self.history) > 0:
-            last_log = self.history[-1]
+            last_log = self.history.at(-1)
             guard_log._previous_logs = last_log
 
-        self.history += [guard_log]
+        self.history.append(guard_log)
 
     @property
     def tree(self) -> Tree:
@@ -153,20 +158,36 @@ def tree(self) -> Tree:
             tree.add(Panel(log.rich_group, title=f"Step {i}"))
         return tree
 
+    @property
+    def last_entry(self) -> Union[GuardLogs, None]:
+        return self.history.at(-1)
+
     @property
     def validated_output(self) -> Union[str, Dict, ReAsk, None]:
         """Returns the latest validated output."""
-        return self.history[-1].validated_output
+        return (
+            self.last_entry.validated_output
+            if self.last_entry is not None
+            else None
+        )
 
     @property
     def output(self) -> Optional[str]:
         """Returns the latest output."""
-        return self.history[-1].output
+        return (
+            self.last_entry.output
+            if self.last_entry is not None
+            else None
+        )
 
     @property
     def output_as_dict(self) -> Optional[Dict]:
         """Returns the latest output as a dict."""
-        return self.history[-1].parsed_output
+        return (
+            self.last_entry.parsed_output
+            if self.last_entry is not None
+            else None
+        )
 
     @property
     def failed_validations(self) -> List[List[ReAsk]]:
@@ -175,17 +196,21 @@ def failed_validations(self) -> List[List[ReAsk]]:
 
 
 class GuardState(ArbitraryModel):
-    all_histories: List[GuardHistory] = Field(default_factory=list)
+    all_histories: ListPlusPlus[GuardHistory] = Field(default_factory=ListPlusPlus)
+
+    def __init__(self, all_histories: ListPlusPlus[GuardHistory]):
+        super().__init__()
+        self.all_histories = ListPlusPlus(*all_histories)
 
     def push(self, guard_history: GuardHistory) -> None:
-        self.all_histories += [guard_history]
+        self.all_histories.append(guard_history)
 
     @property
     def most_recent_call(self) -> Optional[GuardHistory]:
         """Returns the most recent call."""
         if not len(self.all_histories):
             return None
-        return self.all_histories[-1]
+        return self.all_histories.at(-1)
 
 
 def update_response_by_path(output: dict, path: List[Any], value: Any) -> None:
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/integration_tests/test_assets/__init__.py b/tests/integration_tests/test_assets/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/integration_tests/test_assets/fixtures/__init__.py b/tests/integration_tests/test_assets/fixtures/__init__.py
new file mode 100644
index 000000000..6b08e8913
--- /dev/null
+++ b/tests/integration_tests/test_assets/fixtures/__init__.py
@@ -0,0 +1,69 @@
+import pytest
+
+
+@pytest.fixture(name="rail_spec")
+def fixture_rail_spec():
+    return """
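The heart of this test-fixes patch is the ListPlusPlus helper: its at() accessor returns None for an out-of-range index instead of raising IndexError, which is what the new GuardHistory.last_entry property and the guarded history lookups in guard.py rely on. A quick standalone sketch of the behavior, assuming only the class as defined in the new file above:

from guardrails.classes.list_plus_plus import ListPlusPlus

logs = ListPlusPlus()  # an empty history
assert logs.at(-1) is None  # .at() absorbs the IndexError

logs.append("step 1")
assert logs.at(-1) == "step 1"
assert logs.at(5) is None  # out-of-range lookups also come back as None

# The constructor takes elements as positional args, which is why
# GuardHistory.__init__ copies with unpacking: ListPlusPlus(*history).
assert ListPlusPlus(1, 2, 3) == [1, 2, 3]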
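With the reformatted __iter__ earlier in the series, a ValidationOutcome can be consumed either through its named fields or by tuple destructuring. A minimal call-site sketch of the intended usage: the rail file name and the OpenAI completion callable are illustrative assumptions rather than part of the patch, and the five-part unpacking mirrors the Iterator annotation above:

import openai

from guardrails import Guard

guard = Guard.from_rail("my_spec.rail")  # hypothetical rail file

outcome = guard(
    openai.Completion.create,
    prompt_params={"question": "What is 1 + 1?"},
)

# Unpacks in the order laid out in __iter__: raw LLM output, validated
# output, the final reask (if any), the pass flag, and the error message.
raw_llm_output, validated_output, reask, validation_passed, error = outcome

if validation_passed:
    print(validated_output)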