Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validate hook #1310

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions jsonschema/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ def __call__(
[referencing.jsonschema.Schema],
Iterable[tuple[str, Any]],
]

class ValidateHook(Protocol):
def __call__(
self,
is_valid: bool,
instance: Any,
schema: referencing.jsonschema.Schema,
) -> None:
...
4 changes: 3 additions & 1 deletion jsonschema/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# therefore, only import at type-checking time (to avoid circular references),
# but use `jsonschema` for any types which will otherwise not be resolvable
if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence

import referencing.jsonschema

Expand Down Expand Up @@ -102,6 +102,8 @@ class Validator(Protocol):
#: A function which given a schema returns its ID.
ID_OF: _typing.id_of

VALIDATE_HOOKS: ClassVar[Sequence]

#: The schema that will be used to validate instances
schema: Mapping | bool

Expand Down
46 changes: 44 additions & 2 deletions jsonschema/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def create(
applicable_validators: _typing.ApplicableValidators = methodcaller(
"items",
),
validate_hooks: Sequence[_typing.ValidateHook] = (),
):
"""
Create a new validator class.
Expand Down Expand Up @@ -207,6 +208,16 @@ def create(
implement similar behavior, you can typically ignore this argument
and leave it at its default.

validate_hooks:

A list of callables, will be called after validate.

Each callable should take 4 arguments:

1. is valid or not
2. the instance
3. the schema

Returns:

a new `jsonschema.protocols.Validator` class
Expand All @@ -220,6 +231,10 @@ def create(
default=referencing.Specification.OPAQUE,
)

def _call_validate_hooks(is_valid, instance, schema):
for hook in validate_hooks:
hook(is_valid, instance, schema)

@define
class Validator:

Expand All @@ -228,6 +243,7 @@ class Validator:
TYPE_CHECKER = type_checker
FORMAT_CHECKER = format_checker_arg
ID_OF = staticmethod(id_of)
VALIDATE_HOOKS = list(validate_hooks) # noqa: RUF012

_APPLICABLE_VALIDATORS = applicable_validators
_validators = field(init=False, repr=False, eq=False)
Expand Down Expand Up @@ -368,6 +384,7 @@ def iter_errors(self, instance, _schema=None):
_schema, validators = self.schema, self._validators

if _schema is True:
_call_validate_hooks(True, instance, _schema)
return
elif _schema is False:
yield exceptions.ValidationError(
Expand All @@ -377,8 +394,10 @@ def iter_errors(self, instance, _schema=None):
instance=instance,
schema=_schema,
)
_call_validate_hooks(False, instance, _schema)
return

is_valid = True
for validator, k, v in validators:
errors = validator(self, v, instance, _schema) or ()
for error in errors:
Expand All @@ -392,7 +411,9 @@ def iter_errors(self, instance, _schema=None):
)
if k not in {"if", "$ref"}:
error.schema_path.appendleft(k)
is_valid = False
yield error
_call_validate_hooks(is_valid, instance, _schema)

def descend(
self,
Expand All @@ -403,6 +424,7 @@ def descend(
resolver=None,
):
if schema is True:
_call_validate_hooks(True, instance, schema)
return
elif schema is False:
yield exceptions.ValidationError(
Expand All @@ -412,6 +434,7 @@ def descend(
instance=instance,
schema=schema,
)
_call_validate_hooks(False, instance, schema)
return

if self._ref_resolver is not None:
Expand All @@ -423,6 +446,7 @@ def descend(
)
evolved = self.evolve(schema=schema, _resolver=resolver)

is_valid = True
for k, v in applicable_validators(schema):
validator = evolved.VALIDATORS.get(k)
if validator is None:
Expand All @@ -444,10 +468,15 @@ def descend(
error.path.appendleft(path)
if schema_path is not None:
error.schema_path.appendleft(schema_path)
is_valid = False
yield error
_call_validate_hooks(is_valid, instance, schema)

def validate(self, *args, **kwargs):
for error in self.iter_errors(*args, **kwargs):
def validate(self, instance, _schema=None):
for error in self.iter_errors(instance, _schema):
if _schema is None:
_schema = self.schema
_call_validate_hooks(False, instance, _schema)
raise error

def is_type(self, instance, type):
Expand Down Expand Up @@ -498,6 +527,8 @@ def is_valid(self, instance, _schema=None):
self = self.evolve(schema=_schema)

error = next(self.iter_errors(instance), None)
if error is not None:
_call_validate_hooks(False, instance, self.schema)
return error is None

evolve_fields = [
Expand All @@ -520,6 +551,7 @@ def extend(
version=None,
type_checker=None,
format_checker=None,
validate_hooks=(),
):
"""
Create a new validator class by extending an existing one.
Expand Down Expand Up @@ -565,6 +597,12 @@ def extend(
If unprovided, the format checker of the extended
`jsonschema.protocols.Validator` will be carried along.

validate_hooks (collections.abc.Sequence):

a list of new validate hooks to extend with, whose
structure is as in `create`.


Returns:

a new `jsonschema.protocols.Validator` class extending the one
Expand All @@ -584,6 +622,9 @@ def extend(
all_validators = dict(validator.VALIDATORS)
all_validators.update(validators)

all_validate_hooks = list(validator.VALIDATE_HOOKS)
all_validate_hooks.extend(validate_hooks)

if type_checker is None:
type_checker = validator.TYPE_CHECKER
if format_checker is None:
Expand All @@ -596,6 +637,7 @@ def extend(
format_checker=format_checker,
id_of=validator.ID_OF,
applicable_validators=validator._APPLICABLE_VALIDATORS,
validate_hooks=all_validate_hooks,
)


Expand Down
Loading