Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support boolean JSON schemas #1015

Merged
merged 7 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ._pydantic import pydantic_to_json_schema
from ._subgrammar import lexeme, subgrammar

JSONSchema = Union[bool, Mapping[str, Any]]

def _to_compact_json(target: Any) -> str:
# See 'Compact Encoding':
Expand Down Expand Up @@ -150,8 +151,8 @@ def _gen_json_string(
def _gen_json_object(
lm,
*,
properties: Mapping[str, Any],
additional_properties: Union[bool, Mapping[str, Any]],
properties: Mapping[str, JSONSchema],
additional_properties: JSONSchema,
required: Sequence[str],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
Expand Down Expand Up @@ -206,16 +207,12 @@ def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool
def _gen_json_array(
lm,
*,
prefix_items_schema: Sequence[Mapping[str, Any]],
item_schema: Union[bool, Mapping[str, Any]],
prefix_items_schema: Sequence[JSONSchema],
item_schema: JSONSchema,
min_items: int,
max_items: Optional[int],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
if item_schema is True:
# True means that anything goes
item_schema = {}

if len(prefix_items_schema) < min_items and item_schema is False:
raise ValueError(
f"PrefixItems has too few elements ({len(prefix_items_schema)}) to"
Expand Down Expand Up @@ -282,7 +279,7 @@ def _gen_json_array(
def _process_anyOf(
lm,
*,
anyof_list: Sequence[Mapping[str, Any]],
anyof_list: Sequence[JSONSchema],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
options = [_gen_json(json_schema=item, definitions=definitions) for item in anyof_list]
Expand Down Expand Up @@ -329,9 +326,14 @@ def _gen_json_any(lm):
@guidance(stateless=True)
def _gen_json(
lm,
json_schema: Mapping[str, Any],
json_schema: JSONSchema,
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
if json_schema is True:
json_schema = {}
elif json_schema is False:
raise ValueError("No valid JSON can be generated from a schema of `False`")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on raising error vs. returning empty string?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intuitively I feel like warning + empty string is more "correct", but I'm easily convinced otherwise.

Copy link
Collaborator Author

@hudson-ai hudson-ai Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The empty string isn't valid against the False schema. I think we need to raise an exception if someone wants to generate against a schema with no valid instances. Another "bad" schema by this standard is

{
  "type": "object",
  "properties": {
    "name": false
  },
  "required": ["name"]
  }

Note that raising an exception upon encountering a false subschema isn't actually correct though. See the following example:

{
 "type": "object",
 "properties": {
   "name": false
 },
 "required": []
 }

This should validate against any object that does not have the key "name". Will not support for now -- I think I'll open an issue against this just to track it...

Edit: issue #1018


validate_json_node_keys(json_schema)

if Keyword.ANYOF in json_schema:
Expand Down Expand Up @@ -403,7 +405,7 @@ def json(
*,
schema: Union[
None,
Mapping[str, Any],
JSONSchema,
Type["pydantic.BaseModel"],
"pydantic.TypeAdapter",
] = None,
Expand Down Expand Up @@ -457,20 +459,25 @@ def json(
If True, the generated JSON will be forced to be compact (no whitespace).
If False, output will be whitespace-flexible (i.e. decided by the model).
"""
if isinstance(schema, Mapping):
if schema is None:
# Default schema is empty, "anything goes" schema
# TODO: consider default being `{"type": "object"}`
schema = {}
elif isinstance(schema, (Mapping, bool)):
# Raises jsonschema.exceptions.SchemaError or ValueError
# if schema is not valid
jsonschema.validators.Draft202012Validator.check_schema(schema)
elif schema is None:
schema = {}
else:
elif isinstance(schema, pydantic.TypeAdapter) or (isinstance(schema, type) and issubclass(schema, pydantic.BaseModel)):
schema = pydantic_to_json_schema(schema)
else:
raise TypeError(f"Unsupported schema type: {type(schema)}")

definitions: Mapping[str, Callable[[], GrammarFunction]] = {}
for dk in DEFS_KEYS:
if dk in schema:
assert len(definitions) == 0, "Found duplicate definitions"
definitions = _build_definitions(schema[dk])
if isinstance(schema, Mapping):
for dk in DEFS_KEYS:
if dk in schema:
assert len(definitions) == 0, "Found duplicate definitions"
definitions = _build_definitions(schema[dk])

return lm + with_temperature(
subgrammar(
Expand All @@ -488,11 +495,11 @@ def json(


def _build_definitions(
raw_definitions: Mapping[str, Any]
raw_definitions: Mapping[str, JSONSchema]
) -> Mapping[str, Callable[[], GrammarFunction]]:
definitions: Dict[str, Callable[[], GrammarFunction]] = {}

def build_definition(json_schema: Mapping[str, Any]) -> Callable[[], GrammarFunction]:
def build_definition(json_schema: JSONSchema) -> Callable[[], GrammarFunction]:
@guidance(stateless=True, dedent=False, cache=True)
def closure(lm):
return lm + _gen_json(json_schema=json_schema, definitions=definitions)
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2218,3 +2218,36 @@ def test_all_required_properties_doesnt_blow_up(self, num_properties):
HITS_MAGIC_NUMBER = 1
expected_hits = 0
assert cache_info.hits <= expected_hits + HITS_MAGIC_NUMBER

class TestBooleanSchema:
@pytest.mark.parametrize(
"target_obj",
[
123,
"hello",
[1, 2, 3],
{"a": 1},
None,
[{"a": 1}],
{"a": [1, 2, 3]},
{"a": {"b": 1}},
False,
True
],
)
def test_true_schema(self, target_obj):
# should be the same as an empty schema
schema_obj = True
generate_and_check(target_obj, schema_obj)

@pytest.mark.parametrize(
"schema_obj",
[
False,
{"type": "object", "properties": {"a": False}},
]
)
def test_false_schema(self, schema_obj):
with pytest.raises(ValueError) as ve:
gen_json(schema=schema_obj)
assert ve.value.args[0] == "No valid JSON can be generated from a schema of `False`"
Loading