From 5407752daf793bf7dfac0db2a75ad389f348015d Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 27 Jun 2024 13:55:11 -0400 Subject: [PATCH] Generate JSON schema (#16) This schema can be used to validate JSON data. --- .pre-commit-config.yaml | 13 ++++++- ci/generate_json.py | 2 +- ci/generate_json_schema.py | 20 ++++++++++ schemas/rapids-metadata-v1.json | 67 +++++++++++++++++++++++++++++++++ src/rapids_metadata/json.py | 43 +++++++++++++-------- src/rapids_metadata/metadata.py | 52 +++++++++++++++++++++---- tests/test_json.py | 30 ++++++++++++--- 7 files changed, 196 insertions(+), 31 deletions(-) create mode 100755 ci/generate_json_schema.py create mode 100644 schemas/rapids-metadata-v1.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3c6de3e..808a2de 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,11 +7,13 @@ repos: - id: trailing-whitespace exclude: | (?x) - rapids-metadata[.]json$ + rapids-metadata[.]json$| + schemas/rapids-metadata-v[0-9]+[.]json$ - id: end-of-file-fixer exclude: | (?x) - rapids-metadata[.]json$ + rapids-metadata[.]json$| + schemas/rapids-metadata-v[0-9]+[.]json$ - repo: https://github.com/rapidsai/dependency-file-generator rev: v1.13.11 hooks: @@ -53,6 +55,13 @@ repos: pass_filenames: false additional_dependencies: - pydantic + - id: generate-json-schema + name: generate-json-schema + entry: ./ci/generate_json_schema.py + language: python + pass_filenames: false + additional_dependencies: + - pydantic default_language_version: python: python3 diff --git a/ci/generate_json.py b/ci/generate_json.py index 80bdaea..177d9c7 100755 --- a/ci/generate_json.py +++ b/ci/generate_json.py @@ -7,7 +7,7 @@ repo_root = os.path.join(os.path.dirname(__file__), "..") sys.path.append(os.path.join(repo_root, "src")) -from rapids_metadata import json as rapids_json # noqa: E402 +import rapids_metadata.json as rapids_json # noqa: E402 if __name__ == "__main__": rapids_json.main( diff --git 
a/ci/generate_json_schema.py b/ci/generate_json_schema.py new file mode 100755 index 0000000..db8738e --- /dev/null +++ b/ci/generate_json_schema.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, NVIDIA CORPORATION. + +import os.path +import sys + +repo_root = os.path.join(os.path.dirname(__file__), "..") +sys.path.append(os.path.join(repo_root, "src")) + +import rapids_metadata.json as rapids_json # noqa: E402 + +if __name__ == "__main__": + rapids_json.main( + [ + "--output", + os.path.join(repo_root, "schemas/rapids-metadata-v1.json"), + "--pretty", + "--schema", + ] + ) diff --git a/schemas/rapids-metadata-v1.json b/schemas/rapids-metadata-v1.json new file mode 100644 index 0000000..b573096 --- /dev/null +++ b/schemas/rapids-metadata-v1.json @@ -0,0 +1,67 @@ +{ + "$defs": { + "RAPIDSPackage": { + "description": "Package published by a RAPIDS repository. Includes both Python packages and Conda packages.", + "properties": { + "has_cuda_suffix": { + "default": true, + "description": "Whether or not the package has a CUDA suffix.", + "title": "Has Cuda Suffix", + "type": "boolean" + }, + "publishes_prereleases": { + "default": true, + "description": "Whether or not the package publishes prereleases.", + "title": "Publishes Prereleases", + "type": "boolean" + } + }, + "title": "RAPIDSPackage", + "type": "object" + }, + "RAPIDSRepository": { + "description": "RAPIDS Git repository. 
Can publish more than one package.", + "properties": { + "packages": { + "additionalProperties": { + "$ref": "#/$defs/RAPIDSPackage" + }, + "description": "Dictionary of packages in this repository by name.", + "title": "Packages", + "type": "object" + } + }, + "title": "RAPIDSRepository", + "type": "object" + }, + "RAPIDSVersion": { + "description": "Version of RAPIDS, which contains many Git repositories.", + "properties": { + "repositories": { + "additionalProperties": { + "$ref": "#/$defs/RAPIDSRepository" + }, + "description": "Dictionary of repositories in this version by name.", + "title": "Repositories", + "type": "object" + } + }, + "title": "RAPIDSVersion", + "type": "object" + } + }, + "$id": "https://raw.githubusercontent.com/rapidsai/rapids-metadata/main/schemas/rapids-metadata-v1.json", + "description": "All RAPIDS metadata.", + "properties": { + "versions": { + "additionalProperties": { + "$ref": "#/$defs/RAPIDSVersion" + }, + "description": "Dictionary of RAPIDS versions by <major>.<minor> 
version string.", + "title": "Versions", + "type": "object" + } + }, + "title": "RAPIDSMetadata", + "type": "object" +} diff --git a/src/rapids_metadata/json.py b/src/rapids_metadata/json.py index e83dc59..8d0ec9a 100644 --- a/src/rapids_metadata/json.py +++ b/src/rapids_metadata/json.py @@ -16,7 +16,7 @@ import json import os import sys -from typing import Union +from typing import Any, TextIO, Union from pydantic import TypeAdapter @@ -41,6 +41,11 @@ def main(argv: Union[list[str], None] = None): action="store_true", help="Output all versions, ignoring local VERSION file", ) + parser.add_argument( + "--schema", + action="store_true", + help="Output a JSON schema for the data instead of the data itself", + ) parser.add_argument( "--pretty", action="store_true", help="Pretty-print JSON output" ) @@ -52,21 +57,10 @@ def main(argv: Union[list[str], None] = None): ) parsed = parser.parse_args(argv) - metadata = ( - all_metadata - if parsed.all_versions - else RAPIDSMetadata( - versions={ - get_rapids_version(os.getcwd()): all_metadata.get_current_version( - os.getcwd() - ) - } - ) - ) - def write_file(f): + def write_file(data: dict[str, Any], f: TextIO): json.dump( - TypeAdapter(RAPIDSMetadata).dump_python(metadata), + data, f, sort_keys=True, separators=(",", ": ") if parsed.pretty else (",", ":"), @@ -75,11 +69,28 @@ def write_file(f): if parsed.pretty: f.write("\n") + type_adapter = TypeAdapter(RAPIDSMetadata) + if parsed.schema: + data = type_adapter.json_schema() + else: + metadata = ( + all_metadata + if parsed.all_versions + else RAPIDSMetadata( + versions={ + get_rapids_version(os.getcwd()): all_metadata.get_current_version( + os.getcwd() + ) + } + ) + ) + data = type_adapter.dump_python(metadata) + if parsed.output: with open(parsed.output, "w") as f: - write_file(f) + write_file(data, f) else: - write_file(sys.stdout) + write_file(data, sys.stdout) if __name__ == "__main__": diff --git a/src/rapids_metadata/metadata.py b/src/rapids_metadata/metadata.py 
index 1baf66f..7096937 100644 --- a/src/rapids_metadata/metadata.py +++ b/src/rapids_metadata/metadata.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field from os import PathLike from typing import Union +from pydantic import ConfigDict, Field +from pydantic.dataclasses import dataclass + from .rapids_version import get_rapids_version __all__ = [ @@ -28,18 +30,40 @@ @dataclass class RAPIDSPackage: - publishes_prereleases: bool = field(default=True) - has_cuda_suffix: bool = field(default=True) + ( + """Package published by a RAPIDS repository. Includes both Python packages """ + """and Conda packages.""" + ) + + publishes_prereleases: bool = Field( + default=True, + description="""Whether or not the package publishes prereleases.""", + ) + + has_cuda_suffix: bool = Field( + default=True, + description="""Whether or not the package has a CUDA suffix.""", + ) @dataclass class RAPIDSRepository: - packages: dict[str, RAPIDSPackage] = field(default_factory=dict) + """RAPIDS Git repository. 
Can publish more than one package.""" + + packages: dict[str, RAPIDSPackage] = Field( + default_factory=dict, + description="""Dictionary of packages in this repository by name.""", + ) @dataclass class RAPIDSVersion: - repositories: dict[str, RAPIDSRepository] = field(default_factory=dict) + """Version of RAPIDS, which contains many Git repositories.""" + + repositories: dict[str, RAPIDSRepository] = Field( + default_factory=dict, + description="""Dictionary of repositories in this version by name.""", + ) @property def all_packages(self) -> set[str]: @@ -68,9 +92,23 @@ def cuda_suffixed_packages(self) -> set[str]: } -@dataclass +@dataclass( + config=ConfigDict( + json_schema_extra={ + "$id": "https://raw.githubusercontent.com/rapidsai/rapids-metadata/main/schemas/rapids-metadata-v1.json", + }, + ) +) class RAPIDSMetadata: - versions: dict[str, RAPIDSVersion] = field(default_factory=dict) + """All RAPIDS metadata.""" + + versions: dict[str, RAPIDSVersion] = Field( + default_factory=dict, + description=( + """Dictionary of RAPIDS versions by <major>.<minor> 
""" + """version string.""" + ), + ) def get_current_version( self, directory: Union[str, PathLike[str]] diff --git a/tests/test_json.py b/tests/test_json.py index 6b7d719..0287594 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -14,8 +14,9 @@ import contextlib import os.path +import re from textwrap import dedent -from typing import Generator +from typing import Generator, Union from unittest.mock import patch import pytest @@ -256,9 +257,22 @@ def test_metadata_encoder(unencoded, encoded): """ ), ), + ( + None, + ["--schema"], + re.compile( + r'"\$id":"https://raw.githubusercontent.com/rapidsai/rapids-metadata/main/schemas/rapids-metadata-v1.json"' + ), + ), ], ) -def test_main(capsys, tmp_path, version, args, expected_json): +def test_main( + capsys: pytest.CaptureFixture[str], + tmp_path: str, + version: Union[str, None], + args: list[str], + expected_json: Union[str, re.Pattern], +): mock_metadata = RAPIDSMetadata( versions={ "24.08": RAPIDSVersion( @@ -285,17 +299,23 @@ def test_main(capsys, tmp_path, version, args, expected_json): with open(os.path.join(tmp_path, "VERSION"), "w") as f: f.write(f"{version}\n") + def check_output(output: str): + if isinstance(expected_json, re.Pattern): + assert expected_json.search(output) + else: + assert output == expected_json + with set_cwd(tmp_path), patch("sys.argv", ["rapids-metadata-json", *args]), patch( "rapids_metadata.json.all_metadata", mock_metadata ): rapids_json.main() captured = capsys.readouterr() - assert captured.out == expected_json + check_output(captured.out) with set_cwd(tmp_path), patch("rapids_metadata.json.all_metadata", mock_metadata): rapids_json.main(args) captured = capsys.readouterr() - assert captured.out == expected_json + check_output(captured.out) with set_cwd(tmp_path), patch( "sys.argv", ["rapids-metadata-json", *args, "-o", "rapids-metadata.json"] @@ -305,4 +325,4 @@ def test_main(capsys, tmp_path, version, args, expected_json): written_json = f.read() captured = 
capsys.readouterr() assert captured.out == "" - assert written_json == expected_json + check_output(written_json)