From c91ca9e26f225fdf3d07414a18a75d954aa01652 Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 15:14:52 -0400 Subject: [PATCH 01/10] remove zoo schema from top-level (now in schemas dir) --- wsinfer-zoo-registry.schema.json | 34 -------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 wsinfer-zoo-registry.schema.json diff --git a/wsinfer-zoo-registry.schema.json b/wsinfer-zoo-registry.schema.json deleted file mode 100644 index e84897f..0000000 --- a/wsinfer-zoo-registry.schema.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-04/schema#", - "title": "Schema for WSInfer Model Zoo registry file (wsinfer-zoo-registry.json)", - "type": "object", - "properties": { - "models": { - "type": "object", - "additionalProperties": { - "type": "object", - "properties": { - "description": { - "type": "string" - }, - "hf_repo_id": { - "type": "string" - }, - "hf_revision": { - "type": "string" - } - }, - "required": [ - "description", - "hf_repo_id", - "hf_revision" - ], - "additionalProperties": false - } - } - }, - "required": [ - "models" - ], - "additionalProperties": false -} From e1ce5dd5a50b5f9ca84fb3805d83c9faaba33662 Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 15:15:13 -0400 Subject: [PATCH 02/10] add jsonschema and tabulate + add schemas as package data --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index 2a14ed0..b6a85e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,9 @@ python_requires = >= 3.7 install_requires = click>=8.0,<9 huggingface_hub + jsonschema requests + tabulate [options.extras_require] dev = @@ -47,6 +49,9 @@ dev = console_scripts = wsinfer_zoo = wsinfer_zoo.cli:cli +[options.package_data] +wsinfer = + schemas/*.json [flake8] max-line-length = 88 From 18aa5b667c308f98974fe10fa6fc21497b951107 Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 15:15:33 -0400 Subject: [PATCH 03/10] use new schemas + add validate cli + prettify ls cli output --- wsinfer_zoo/cli.py | 96 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/wsinfer_zoo/cli.py b/wsinfer_zoo/cli.py index b190488..d1618a9 100644 --- a/wsinfer_zoo/cli.py +++ b/wsinfer_zoo/cli.py @@ -2,11 +2,23 @@ import dataclasses import json +import sys from pathlib import Path import click +import jsonschema +import tabulate -from wsinfer_zoo.client import WSINFER_ZOO_REGISTRY_DEFAULT_PATH, ModelRegistry +from wsinfer_zoo.client import ( + MODEL_CONFIG_SCHEMA, + WSINFER_ZOO_REGISTRY_DEFAULT_PATH, + WSINFER_ZOO_SCHEMA, + InvalidModelConfiguration, + InvalidRegistryConfiguration, + ModelRegistry, +) + +_here = Path(__file__).parent.resolve() @click.group() @@ -24,46 +36,98 @@ def cli(ctx: click.Context, *, registry_file: Path): raise click.ClickException(f"registry file not found: {registry_file}") with open(registry_file) as f: d = json.load(f) + # Raise an error if validation fails. + try: + jsonschema.validate(instance=d, schema=WSINFER_ZOO_SCHEMA) + except jsonschema.ValidationError as e: + raise InvalidRegistryConfiguration( + "Registry schema is invalid. Please contact the developer by" + " creating a new issue on our GitHub page:" + " https://github.com/SBU-BMI/wsinfer-zoo/issues/new." + ) from e registry = ModelRegistry.from_dict(d) ctx.ensure_object(dict) ctx.obj["registry"] = registry @cli.command() -@click.option("--as-json", is_flag=True, help="Print as JSON") +@click.option("--json", "as_json", is_flag=True, help="Print as JSON lines") @click.pass_context def ls(ctx: click.Context, *, as_json: bool): - """List registered models.""" + """List registered models. + + If not a TTY, only model names are printed. If a TTY, a pretty table + of models is printed. + """ registry: ModelRegistry = ctx.obj["registry"] - if not as_json: - click.echo("\n".join(str(m) for m in registry.models)) - else: - for m in registry.models: + if as_json: + for m in registry.models.values(): click.echo(json.dumps(dataclasses.asdict(m))) + else: + if sys.stdout.isatty(): + info = [ + [m.name, m.description, m.hf_repo_id, m.hf_revision] + for m in registry.models.values() + ] + click.echo( + tabulate.tabulate( + info, + headers=["Name", "Description", "HF Repo ID", "Rev"], + tablefmt="grid", + maxcolwidths=[None, 24, 30, None], + ) + ) + else: + # You're being piped or redirected + click.echo("\n".join(str(m) for m in registry.models)) @cli.command() @click.option( - "--model-id", + "--model-name", required=True, - help="Number of the model to get. See `ls` to list model numbers", - type=int, + help="Number of the model to get. See `ls` to list model names", ) @click.pass_context -def get(ctx: click.Context, *, model_id: int): - """Retrieve the model and configuration. +def get(ctx: click.Context, *, model_name: str): + """Retrieve a model and its configuration. Outputs JSON with model configuration, path to the model, and origin of the model. - This downloads the pretrained model if necessary. + The pretrained model is downloaded to a cache and reused if it is already present. """ registry: ModelRegistry = ctx.obj["registry"] - if model_id not in registry.model_ids: + if model_name not in registry.models.keys(): raise click.ClickException( - f"'{model_id}' not found, available models are {registry.model_ids}. Use `wsinfer_zoo ls` to list all models." + f"'{model_name}' not found, available models are" + " {list(registry.models.keys())}. Use `wsinfer_zoo ls` to list all" + " models." ) - registered_model = registry.get_model_by_id(model_id) + registered_model = registry.get_model_by_name(model_name) model = registered_model.load_model() model_dict = dataclasses.asdict(model) model_json = json.dumps(model_dict) click.echo(model_json) + + +@cli.command() +@click.option( + "--input", + help="Config file to validate (default is standard input)", + type=click.File("r"), + default=sys.stdin, +) +def validate_config(*, input): + """Validate a model configuration file against the JSON schema.""" + try: + c = json.load(input) + except Exception as e: + raise click.ClickException(f"Unable to read JSON file. Original error: {e}") + + # Raise an error if the schema is not valid. + try: + jsonschema.validate(instance=c, schema=MODEL_CONFIG_SCHEMA) + except jsonschema.ValidationError as e: + raise InvalidModelConfiguration( + "The configuration is invalid. Please see the traceback above for details." + ) from e From a721969ecb0ebf6007d34735a7e3a2fbfb07340d Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 15:16:59 -0400 Subject: [PATCH 04/10] add schemas and modify code to adhere to those schemas --- wsinfer_zoo/client.py | 139 ++++++++++++++++++++++++++---------------- 1 file changed, 85 insertions(+), 54 deletions(-) diff --git a/wsinfer_zoo/client.py b/wsinfer_zoo/client.py index b7ef7fe..bfe3bf0 100644 --- a/wsinfer_zoo/client.py +++ b/wsinfer_zoo/client.py @@ -4,8 +4,9 @@ import json from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence +import jsonschema import requests from huggingface_hub import hf_hub_download @@ -23,21 +24,50 @@ HF_TORCHSCRIPT_NAME = "torchscript_model.pt" # URL to the latest model registry. -WSINFER_ZOO_REGISTRY_URL = "https://raw.githubusercontent.com/SBU-BMI/wsinfer-zoo/main/wsinfer-zoo-registry.json" +WSINFER_ZOO_REGISTRY_URL = "https://raw.githubusercontent.com/SBU-BMI/wsinfer-zoo/main/wsinfer-zoo-registry.json" # noqa # The path to the registry file. WSINFER_ZOO_REGISTRY_DEFAULT_PATH = Path.home() / ".wsinfer-zoo-registry.json" -@dataclasses.dataclass -class TransformConfiguration: - """Container for the transform configuration for a model. +_here = Path(__file__).parent.resolve() - This is stored in te 'transform' key of 'config.json'. - """ +# Load schema for model config JSON files. +MODEL_CONFIG_SCHEMA_PATH = _here / "schemas" / "model-config.schema.json" +if not MODEL_CONFIG_SCHEMA_PATH.exists(): + raise FileNotFoundError( + f"JSON schema for model configurations not found: {MODEL_CONFIG_SCHEMA_PATH}" + ) +with open(MODEL_CONFIG_SCHEMA_PATH) as f: + MODEL_CONFIG_SCHEMA = json.load(f) + +# Load schema for model zoo file. +WSINFER_ZOO_SCHEMA_PATH = _here / "schemas" / "wsinfer-zoo-registry.schema.json" +if not WSINFER_ZOO_SCHEMA_PATH.exists(): + raise FileNotFoundError( + f"JSON schema for wsinfer zoo not found: {WSINFER_ZOO_SCHEMA_PATH}" + ) +with open(WSINFER_ZOO_SCHEMA_PATH) as f: + WSINFER_ZOO_SCHEMA = json.load(f) + + +class WSInferZooException(Exception): + ... - resize_size: int - mean: Tuple[float, float, float] - std: Tuple[float, float, float] + +class InvalidRegistryConfiguration(WSInferZooException): + ... + + +class InvalidModelConfiguration(WSInferZooException): + ... + + +@dataclasses.dataclass +class TransformConfigurationItem: + """Container for one item in the 'transform' property of the model configuration.""" + + name: str + arguments: Dict[str, Any] @dataclasses.dataclass @@ -52,19 +82,25 @@ class ModelConfiguration: class_names: Sequence[str] patch_size_pixels: int spacing_um_px: float - transform: TransformConfiguration + transform: List[TransformConfigurationItem] @classmethod def from_dict(cls, config: Dict) -> "ModelConfiguration": - # TODO: add validation here... + try: + jsonschema.validate(config, schema=MODEL_CONFIG_SCHEMA) + except jsonschema.ValidationError as e: + raise InvalidModelConfiguration( + "Invalid model configuration. See traceback above for details." + ) from e num_classes = config["num_classes"] patch_size_pixels = config["patch_size_pixels"] spacing_um_px = config["spacing_um_px"] class_names = config["class_names"] - tdict = config["transform"] - transform = TransformConfiguration( - resize_size=tdict["resize_size"], mean=tdict["mean"], std=tdict["std"] - ) + transform_list: List[Dict[str, Any]] = config["transform"] + transform = [ + TransformConfigurationItem(name=t["name"], arguments=t["arguments"]) + for t in transform_list + ] return cls( num_classes=num_classes, patch_size_pixels=patch_size_pixels, @@ -77,8 +113,9 @@ def get_slide_patch_size(self, slide_spacing_um_px: float) -> int: """Get the size of the patches to extract from the slide to be compatible with the patch size and spacing the model expects. - The model expects images of a particular physical size. This can be calculated with - spacing_um_px * patch_size_pixels, and the results units are in micrometers (um). + The model expects images of a particular physical size. This can be calculated + with spacing_um_px * patch_size_pixels, and the results units are in + micrometers (um). The native spacing of a slide can be different than what the model expects, so patches should be extracted at a different size and rescaled to the pixel size @@ -89,6 +126,8 @@ def get_slide_patch_size(self, slide_spacing_um_px: float) -> int: @dataclasses.dataclass class HFInfo: + """Container for information on model's location on HuggingFace Hub.""" + repo_id: str revision: Optional[str] = None @@ -127,10 +166,10 @@ def load_torchscript_model_from_hf( class RegisteredModel: """Container with information about where to find a single model.""" + name: str description: str hf_repo_id: str hf_revision: str - model_id: int def load_model(self) -> Model: return load_torchscript_model_from_hf( @@ -138,51 +177,43 @@ def load_model(self) -> Model: ) def __str__(self) -> str: - return f"{self.model_id:02d} -> {self.description} ({self.hf_repo_id} @ {self.hf_revision})" + return ( + f"{self.name} -> {self.description} ({self.hf_repo_id}" + f" @ {self.hf_revision})" + ) @dataclasses.dataclass class ModelRegistry: """Registry of models that can be used with WSInfer.""" - models: List[RegisteredModel] - - def __post_init__(self): - if len(set(m.model_id for m in self.models)) != len(self.models): - raise ValueError("all model ids must be unique") + models: Dict[str, RegisteredModel] - @property - def model_ids(self) -> List[int]: - return [m.model_id for m in self.models] - - def get_model_by_id(self, model_id: int) -> RegisteredModel: - for m in self.models: - if m.model_id == model_id: - return m - raise ValueError(f"model not found with ID '{model_id}'.") + def get_model_by_name(self, name: str) -> RegisteredModel: + try: + return self.models[name] + except KeyError: + raise KeyError(f"model not found with name '{name}'.") @classmethod def from_dict(cls, config: Dict) -> "ModelRegistry": - assert isinstance(config, dict) - assert "models" in config.keys() - assert isinstance(config["models"], list) - assert config["models"] - - # Test that all model items have required keys. - for cm in config["models"]: - for key in ["description", "hf_repo_id", "hf_revision"]: - if key not in cm.keys(): - raise KeyError(f"required key '{key}' not found in model info") - - models = [ - RegisteredModel( - description=cm["description"], - hf_repo_id=cm["hf_repo_id"], - hf_revision=cm.get("hf_revision"), - model_id=cm.get("model_id", model_id), + """Create a new ModelRegistry instance from a dictionary.""" + try: + jsonschema.validate(instance=config, schema=WSINFER_ZOO_SCHEMA) + except jsonschema.ValidationError as e: + raise InvalidModelConfiguration( + "Model configuration is invalid. Read the traceback above for" + " more information about the case." + ) from e + models = { + name: RegisteredModel( + name=name, + description=kwds["description"], + hf_repo_id=kwds["hf_repo_id"], + hf_revision=kwds["hf_revision"], ) - for model_id, cm in enumerate(config["models"]) - ] + for name, kwds in config["models"].items() + } return cls(models=models) @@ -193,7 +224,7 @@ def _remote_registry_is_newer() -> bool: if not WSINFER_ZOO_REGISTRY_DEFAULT_PATH.exists(): return True - url = "https://api.github.com/repos/SBU-BMI/wsinfer-zoo/commits?path=wsinfer-zoo-registry.json&page=1&per_page=1" + url = "https://api.github.com/repos/SBU-BMI/wsinfer-zoo/commits?path=wsinfer-zoo-registry.json&page=1&per_page=1" # noqa resp = requests.get(url) if not resp.ok: raise requests.RequestException( From 1e9884fb2dd4aaf23fd0f66641f8deeed6c810ec Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 15:27:56 -0400 Subject: [PATCH 05/10] arguments is optional in TransformConfigurationItem --- wsinfer_zoo/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wsinfer_zoo/client.py b/wsinfer_zoo/client.py index bfe3bf0..8cec257 100644 --- a/wsinfer_zoo/client.py +++ b/wsinfer_zoo/client.py @@ -67,7 +67,7 @@ class TransformConfigurationItem: """Container for one item in the 'transform' property of the model configuration.""" name: str - arguments: Dict[str, Any] + arguments: Optional[Dict[str, Any]] @dataclasses.dataclass @@ -98,7 +98,7 @@ def from_dict(cls, config: Dict) -> "ModelConfiguration": class_names = config["class_names"] transform_list: List[Dict[str, Any]] = config["transform"] transform = [ - TransformConfigurationItem(name=t["name"], arguments=t["arguments"]) + TransformConfigurationItem(name=t["name"], arguments=t.get("arguments")) for t in transform_list ] return cls( From 5c9e44b0eac4a3815e96391c1aa4fc1aa26099df Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 22:47:21 -0400 Subject: [PATCH 06/10] add python typing support --- setup.cfg | 1 + wsinfer_zoo/py.typed | 0 2 files changed, 1 insertion(+) create mode 100644 wsinfer_zoo/py.typed diff --git a/setup.cfg b/setup.cfg index b6a85e8..c81a310 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,6 +51,7 @@ console_scripts = [options.package_data] wsinfer = + py.typed schemas/*.json [flake8] diff --git a/wsinfer_zoo/py.typed b/wsinfer_zoo/py.typed new file mode 100644 index 0000000..e69de29 From b2aef639bac7cca94706e9c39725e7dc87d46ba8 Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Tue, 20 Jun 2023 22:47:51 -0400 Subject: [PATCH 07/10] add architecture to model config + add torchscript and weights models --- wsinfer_zoo/__init__.py | 30 ++++++++- wsinfer_zoo/cli.py | 3 +- wsinfer_zoo/client.py | 68 ++++++++++++++++++-- wsinfer_zoo/schemas/model-config.schema.json | 5 ++ 4 files changed, 98 insertions(+), 8 deletions(-) diff --git a/wsinfer_zoo/__init__.py b/wsinfer_zoo/__init__.py index d910837..fa9124a 100644 --- a/wsinfer_zoo/__init__.py +++ b/wsinfer_zoo/__init__.py @@ -1,9 +1,37 @@ import os as _os +import json +import jsonschema from wsinfer_zoo import _version -from wsinfer_zoo.client import _download_registry_if_necessary +from wsinfer_zoo.client import ( + _download_registry_if_necessary, + ModelRegistry, + WSINFER_ZOO_REGISTRY_DEFAULT_PATH, + WSINFER_ZOO_SCHEMA, + InvalidRegistryConfiguration, +) __version__ = _version.get_versions()["version"] if _os.getenv("WSINFER_ZOO_NO_UPDATE_REGISTRY") is None: _download_registry_if_necessary() + + +if not WSINFER_ZOO_REGISTRY_DEFAULT_PATH.exists(): + raise FileNotFoundError( + f"registry file not found: {WSINFER_ZOO_REGISTRY_DEFAULT_PATH}" + ) +with open(WSINFER_ZOO_REGISTRY_DEFAULT_PATH) as f: + d = json.load(f) +try: + jsonschema.validate(instance=d, schema=WSINFER_ZOO_SCHEMA) +except jsonschema.ValidationError as e: + raise InvalidRegistryConfiguration( + "Registry schema is invalid. Please contact the developer by" + " creating a new issue on our GitHub page:" + " https://github.com/SBU-BMI/wsinfer-zoo/issues/new." + ) from e + +registry = ModelRegistry.from_dict(d) + +del d, json, jsonschema diff --git a/wsinfer_zoo/cli.py b/wsinfer_zoo/cli.py index d1618a9..4522ed8 100644 --- a/wsinfer_zoo/cli.py +++ b/wsinfer_zoo/cli.py @@ -104,7 +104,7 @@ def get(ctx: click.Context, *, model_name: str): ) registered_model = registry.get_model_by_name(model_name) - model = registered_model.load_model() + model = registered_model.load_model_torchscript() model_dict = dataclasses.asdict(model) model_json = json.dumps(model_dict) click.echo(model_json) @@ -131,3 +131,4 @@ def validate_config(*, input): raise InvalidModelConfiguration( "The configuration is invalid. Please see the traceback above for details." ) from e + click.secho("Passed", fg="green") diff --git a/wsinfer_zoo/client.py b/wsinfer_zoo/client.py index 8cec257..a14c49e 100644 --- a/wsinfer_zoo/client.py +++ b/wsinfer_zoo/client.py @@ -22,6 +22,10 @@ HF_CONFIG_NAME = "config.json" # The name of the torchscript saved file. HF_TORCHSCRIPT_NAME = "torchscript_model.pt" +# The name of the safetensors file with weights. +HF_WEIGHTS_SAFETENSORS_NAME = "model.safetensors" +# The name of the pytorch (pickle) file with weights. +HF_WEIGHTS_PICKLE_NAME = "pytorch_model.bin" # URL to the latest model registry. WSINFER_ZOO_REGISTRY_URL = "https://raw.githubusercontent.com/SBU-BMI/wsinfer-zoo/main/wsinfer-zoo-registry.json" # noqa @@ -78,12 +82,17 @@ class ModelConfiguration: """ # FIXME: add fields like author, license, training data, publications, etc. + architecture: str num_classes: int class_names: Sequence[str] patch_size_pixels: int spacing_um_px: float transform: List[TransformConfigurationItem] + def __post_init__(self): + if len(self.class_names) != self.num_classes: + raise InvalidModelConfiguration() + @classmethod def from_dict(cls, config: Dict) -> "ModelConfiguration": try: @@ -92,6 +101,7 @@ def from_dict(cls, config: Dict) -> "ModelConfiguration": raise InvalidModelConfiguration( "Invalid model configuration. See traceback above for details." ) from e + architecture = config["architecture"] num_classes = config["num_classes"] patch_size_pixels = config["patch_size_pixels"] spacing_um_px = config["spacing_um_px"] @@ -102,6 +112,7 @@ def from_dict(cls, config: Dict) -> "ModelConfiguration": for t in transform_list ] return cls( + architecture=architecture, num_classes=num_classes, patch_size_pixels=patch_size_pixels, spacing_um_px=spacing_um_px, @@ -134,16 +145,33 @@ class HFInfo: @dataclasses.dataclass class Model: - """Container for the downloaded model path and config.""" - config: ModelConfiguration model_path: str + + +@dataclasses.dataclass +class HFModel(Model): + """Container for a model hosted on HuggingFace.""" + hf_info: HFInfo +@dataclasses.dataclass +class HFModelTorchScript(HFModel): + """Container for the downloaded model path and config.""" + + +# This is here to avoid confusion. We could have used Model directly with +# weights files, but then downstream it would not be clear whether the +# model has torchscript files or weights files. +@dataclasses.dataclass +class HFModelWeightsOnly(HFModel): + """Container for a model with weights only (not a TorchScript model).""" + + def load_torchscript_model_from_hf( repo_id: str, revision: Optional[str] = None -) -> Model: +) -> HFModelTorchScript: """Load a TorchScript model from HuggingFace.""" model_path = hf_hub_download(repo_id, HF_TORCHSCRIPT_NAME, revision=revision) @@ -155,10 +183,33 @@ def load_torchscript_model_from_hf( f"Expected configuration to be a dict but got {type(config_dict)}" ) config = ModelConfiguration.from_dict(config_dict) - del config_dict # FIXME: should we always load on cpu? hf_info = HFInfo(repo_id=repo_id, revision=revision) - model = Model(config=config, model_path=model_path, hf_info=hf_info) + model = HFModelTorchScript(config=config, model_path=model_path, hf_info=hf_info) + return model + + +def load_weights_from_hf( + repo_id: str, revision: Optional[str] = None, safetensors: bool = False +) -> HFModelWeightsOnly: + """Load model weights from HuggingFace (this is not TorchScript).""" + if safetensors: + model_path = hf_hub_download( + repo_id, HF_WEIGHTS_SAFETENSORS_NAME, revision=revision + ) + else: + model_path = hf_hub_download(repo_id, HF_WEIGHTS_PICKLE_NAME, revision=revision) + + config_path = hf_hub_download(repo_id, HF_CONFIG_NAME, revision=revision) + with open(config_path) as f: + config_dict = json.load(f) + if not isinstance(config_dict, dict): + raise TypeError( + f"Expected configuration to be a dict but got {type(config_dict)}" + ) + config = ModelConfiguration.from_dict(config_dict) + hf_info = HFInfo(repo_id=repo_id, revision=revision) + model = HFModelWeightsOnly(config=config, model_path=model_path, hf_info=hf_info) return model @@ -171,11 +222,16 @@ class RegisteredModel: hf_repo_id: str hf_revision: str - def load_model(self) -> Model: + def load_model_torchscript(self) -> HFModelTorchScript: return load_torchscript_model_from_hf( repo_id=self.hf_repo_id, revision=self.hf_revision ) + def load_model_weights(self, safetensors: bool = False) -> HFModelWeightsOnly: + return load_weights_from_hf( + repo_id=self.hf_repo_id, revision=self.hf_revision, safetensors=safetensors + ) + def __str__(self) -> str: return ( f"{self.name} -> {self.description} ({self.hf_repo_id}" diff --git a/wsinfer_zoo/schemas/model-config.schema.json b/wsinfer_zoo/schemas/model-config.schema.json index 81edefe..d0c814b 100644 --- a/wsinfer_zoo/schemas/model-config.schema.json +++ b/wsinfer_zoo/schemas/model-config.schema.json @@ -2,6 +2,10 @@ "$schema": "http://json-schema.org/draft-04/schema", "type": "object", "properties": { + "architecture": { + "type": "string", + "description": "Architecture of the model (Use TIMM names)" + }, "num_classes": { "type": "integer", "description": "The number of classes the model outputs", @@ -50,6 +54,7 @@ } }, "required": [ + "architecture", "num_classes", "patch_size_pixels", "spacing_um_px", From e0decba2addfeebb3be786310a7110fad48aef26 Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Wed, 21 Jun 2023 11:24:07 -0400 Subject: [PATCH 08/10] add validate-repo cli + add funcs to validate JSON files --- wsinfer_zoo/__init__.py | 6 +- wsinfer_zoo/cli.py | 145 ++++++++++++++++--- wsinfer_zoo/client.py | 79 +++++----- wsinfer_zoo/schemas/model-config.schema.json | 8 + 4 files changed, 177 insertions(+), 61 deletions(-) diff --git a/wsinfer_zoo/__init__.py b/wsinfer_zoo/__init__.py index fa9124a..e0e565c 100644 --- a/wsinfer_zoo/__init__.py +++ b/wsinfer_zoo/__init__.py @@ -7,8 +7,8 @@ _download_registry_if_necessary, ModelRegistry, WSINFER_ZOO_REGISTRY_DEFAULT_PATH, - WSINFER_ZOO_SCHEMA, InvalidRegistryConfiguration, + validate_model_zoo_json, ) __version__ = _version.get_versions()["version"] @@ -24,8 +24,8 @@ with open(WSINFER_ZOO_REGISTRY_DEFAULT_PATH) as f: d = json.load(f) try: - jsonschema.validate(instance=d, schema=WSINFER_ZOO_SCHEMA) -except jsonschema.ValidationError as e: + validate_model_zoo_json(d) +except InvalidRegistryConfiguration as e: raise InvalidRegistryConfiguration( "Registry schema is invalid. Please contact the developer by" " creating a new issue on our GitHub page:" diff --git a/wsinfer_zoo/cli.py b/wsinfer_zoo/cli.py index 4522ed8..dc28f95 100644 --- a/wsinfer_zoo/cli.py +++ b/wsinfer_zoo/cli.py @@ -6,20 +6,23 @@ from pathlib import Path import click -import jsonschema +import huggingface_hub +import requests import tabulate from wsinfer_zoo.client import ( - MODEL_CONFIG_SCHEMA, WSINFER_ZOO_REGISTRY_DEFAULT_PATH, - WSINFER_ZOO_SCHEMA, + HF_CONFIG_NAME, + HF_WEIGHTS_SAFETENSORS_NAME, + HF_WEIGHTS_PICKLE_NAME, + HF_TORCHSCRIPT_NAME, InvalidModelConfiguration, InvalidRegistryConfiguration, ModelRegistry, + validate_model_zoo_json, + validate_config_json, ) -_here = Path(__file__).parent.resolve() - @click.group() @click.option( @@ -38,8 +41,8 @@ def cli(ctx: click.Context, *, registry_file: Path): d = json.load(f) # Raise an error if validation fails. try: - jsonschema.validate(instance=d, schema=WSINFER_ZOO_SCHEMA) - except jsonschema.ValidationError as e: + validate_model_zoo_json(d) + except InvalidRegistryConfiguration as e: raise InvalidRegistryConfiguration( "Registry schema is invalid. Please contact the developer by" " creating a new issue on our GitHub page:" @@ -111,14 +114,14 @@ def get(ctx: click.Context, *, model_name: str): @cli.command() -@click.option( - "--input", - help="Config file to validate (default is standard input)", - type=click.File("r"), - default=sys.stdin, -) +@click.argument("input", type=click.File("r")) def validate_config(*, input): - """Validate a model configuration file against the JSON schema.""" + """Validate a model configuration file against the JSON schema. + + INPUT is the config file to validate. + + Use a dash - to read standard input. + """ try: c = json.load(input) except Exception as e: @@ -126,9 +129,117 @@ def validate_config(*, input): # Raise an error if the schema is not valid. try: - jsonschema.validate(instance=c, schema=MODEL_CONFIG_SCHEMA) - except jsonschema.ValidationError as e: + validate_config_json(c) + except InvalidRegistryConfiguration as e: raise InvalidModelConfiguration( "The configuration is invalid. Please see the traceback above for details." ) from e - click.secho("Passed", fg="green") + click.secho("Configuration file is VALID", fg="green") + + +@cli.command() +@click.argument("huggingface_repo_id") +@click.option("-r", "--revision", help="Revision to validate", default="main") +def validate_repo(*, huggingface_repo_id: str, revision: str): + """Validate a repository on HuggingFace. + + This checks that the repository contains all of the necessary files and that + the configuration JSON file is valid. + """ + repo_id = huggingface_repo_id + del huggingface_repo_id + + try: + files_in_repo = list( + huggingface_hub.list_files_info(repo_id=repo_id, revision=revision) + ) + except huggingface_hub.utils.RepositoryNotFoundError: + click.secho( + f"Error: huggingface_repo_id '{repo_id}' not found on the HuggingFace Hub", + fg="red", + ) + sys.exit(1) + except huggingface_hub.utils.RevisionNotFoundError: + click.secho( + f"Error: revision {revision} not found for repository {repo_id}", + fg="red", + ) + sys.exit(1) + except requests.RequestException as e: + click.echo("Error with request: {e}") + click.echo("Please try again.") + sys.exit(2) + + file_info = {f.rfilename: f for f in files_in_repo} + + repo_url = f"https://huggingface.co/{repo_id}/tree/{revision}" + + filenames_and_help = [ + ( + HF_CONFIG_NAME, + "This file is a JSON file with the configuration of the model and includes" + " necessary information for how to apply this model to new data. You can" + " validate this file with the command 'wsinfer_zoo validate-config'.", + ), + ( + HF_TORCHSCRIPT_NAME, + "This file is a TorchScript representation of the model and can be made" + " with 'torch.jit.script(model)' followed by 'torch.jit.save'. This file" + " contains the pre-trained weights as well as a graph of the model." + " Importantly, it does not require a Python runtime to be used." + f" Then, upload the file to the HuggingFace model repo at {repo_url}", + ), + ( + HF_WEIGHTS_PICKLE_NAME, + "This file contains the weights of the pre-trained model in normal PyTorch" + " format. Once you have a trained model, create this file with" + f'\n\n torch.save(model.state_dict(), "{HF_WEIGHTS_PICKLE_NAME}")' + f"\n\n Then, upload the file to the HuggingFace model repo at {repo_url}", + ), + ( + HF_WEIGHTS_SAFETENSORS_NAME, + "This file contains the weights of the pre-trained model in SafeTensors" + " format. The advantage of this file is that it does not have security" + " concerns that Pickle files (pytorch default) have. To create the file:" + "\n\n from safetensors.torch import save_file" + f'\n save_file(model.state_dict(), "{HF_WEIGHTS_SAFETENSORS_NAME}")' + f"\n\n Then, upload the file to the HuggingFace model repo at {repo_url}", + ), + ] + + invalid = False + for name, help_msg in filenames_and_help: + if name not in file_info: + click.secho( + f"Required file '{name}' not found in HuggingFace model repo '{repo_id}'", + fg="red", + ) + click.echo(f" {help_msg}") + click.echo("-" * 40) + invalid = True + + if invalid: + click.secho( + f"Model repository {repo_id} is invalid. See above for details.", fg="red" + ) + sys.exit(1) + + config_path = huggingface_hub.hf_hub_download( + repo_id, HF_CONFIG_NAME, revision=revision + ) + with open(config_path) as f: + config_dict = json.load(f) + try: + validate_config_json(config_dict) + except InvalidModelConfiguration as e: + click.secho( + "Model configuration JSON file is invalid. Use 'wsinfer_zoo validate-config'" + " with the configuration file to debug this further.", + fg="red", + ) + click.secho( + f"Model repository {repo_id} is invalid. See above for details.", fg="red" + ) + sys.exit(1) + + click.secho(f"Repository {repo_id} is VALID.", fg="green") diff --git a/wsinfer_zoo/client.py b/wsinfer_zoo/client.py index a14c49e..fabe5be 100644 --- a/wsinfer_zoo/client.py +++ b/wsinfer_zoo/client.py @@ -10,15 +10,7 @@ import requests from huggingface_hub import hf_hub_download -# TODO: we might consider fetching available models from the web. -# from huggingface_hub import HfApi -# hf_api = HfApi() -# models = hf_api.list_models(author="kaczmarj") -# print("Found these models...") -# print(models) - -# FIXME: consider changing the name of this file because perhaps there will -# be multiple configs? Or add a key inside the json map 'wsinfer_config'. +# The name of the configuration JSON file. HF_CONFIG_NAME = "config.json" # The name of the torchscript saved file. HF_TORCHSCRIPT_NAME = "torchscript_model.pt" @@ -32,27 +24,8 @@ # The path to the registry file. WSINFER_ZOO_REGISTRY_DEFAULT_PATH = Path.home() / ".wsinfer-zoo-registry.json" - _here = Path(__file__).parent.resolve() -# Load schema for model config JSON files. -MODEL_CONFIG_SCHEMA_PATH = _here / "schemas" / "model-config.schema.json" -if not MODEL_CONFIG_SCHEMA_PATH.exists(): - raise FileNotFoundError( - f"JSON schema for model configurations not found: {MODEL_CONFIG_SCHEMA_PATH}" - ) -with open(MODEL_CONFIG_SCHEMA_PATH) as f: - MODEL_CONFIG_SCHEMA = json.load(f) - -# Load schema for model zoo file. -WSINFER_ZOO_SCHEMA_PATH = _here / "schemas" / "wsinfer-zoo-registry.schema.json" -if not WSINFER_ZOO_SCHEMA_PATH.exists(): - raise FileNotFoundError( - f"JSON schema for wsinfer zoo not found: {WSINFER_ZOO_SCHEMA_PATH}" - ) -with open(WSINFER_ZOO_SCHEMA_PATH) as f: - WSINFER_ZOO_SCHEMA = json.load(f) - class WSInferZooException(Exception): ... @@ -66,6 +39,41 @@ class InvalidModelConfiguration(WSInferZooException): ... +def validate_config_json(instance: object): + """Raise an error if the model configuration JSON is invalid. Otherwise return True.""" + schema_path = _here / "schemas" / "model-config.schema.json" + if not schema_path.exists(): + raise FileNotFoundError( + f"JSON schema for model configurations not found: {schema_path}" + ) + with open(schema_path) as f: + schema = json.load(f) + try: + jsonschema.validate(instance, schema=schema) + except jsonschema.ValidationError as e: + raise InvalidModelConfiguration( + "Invalid model configuration. See traceback above for details." + ) from e + + return True + + +def validate_model_zoo_json(instance: object): + """Raise an error if the model zoo registry JSON is invalid. Otherwise return True.""" + schema_path = _here / "schemas" / "wsinfer-zoo-registry.schema.json" + if not schema_path.exists(): + raise FileNotFoundError(f"JSON schema for wsinfer zoo not found: {schema_path}") + with open(schema_path) as f: + schema = json.load(f) + try: + jsonschema.validate(instance, schema=schema) + except jsonschema.ValidationError as e: + raise InvalidRegistryConfiguration( + "Invalid model zoo registry configuration. See traceback above for details." + ) from e + return True + + @dataclasses.dataclass class TransformConfigurationItem: """Container for one item in the 'transform' property of the model configuration.""" @@ -95,12 +103,7 @@ def __post_init__(self): @classmethod def from_dict(cls, config: Dict) -> "ModelConfiguration": - try: - jsonschema.validate(config, schema=MODEL_CONFIG_SCHEMA) - except jsonschema.ValidationError as e: - raise InvalidModelConfiguration( - "Invalid model configuration. See traceback above for details." - ) from e + validate_config_json(config) architecture = config["architecture"] num_classes = config["num_classes"] patch_size_pixels = config["patch_size_pixels"] @@ -254,13 +257,7 @@ def get_model_by_name(self, name: str) -> RegisteredModel: @classmethod def from_dict(cls, config: Dict) -> "ModelRegistry": """Create a new ModelRegistry instance from a dictionary.""" - try: - jsonschema.validate(instance=config, schema=WSINFER_ZOO_SCHEMA) - except jsonschema.ValidationError as e: - raise InvalidModelConfiguration( - "Model configuration is invalid. Read the traceback above for" - " more information about the case." - ) from e + validate_model_zoo_json(config) models = { name: RegisteredModel( name=name, diff --git a/wsinfer_zoo/schemas/model-config.schema.json b/wsinfer_zoo/schemas/model-config.schema.json index d0c814b..a4fc081 100644 --- a/wsinfer_zoo/schemas/model-config.schema.json +++ b/wsinfer_zoo/schemas/model-config.schema.json @@ -2,6 +2,13 @@ "$schema": "http://json-schema.org/draft-04/schema", "type": "object", "properties": { + "spec_version": { + "type": "string", + "description": "Version of the model config spec.", + "enum": [ + "1.0" + ] + }, "architecture": { "type": "string", "description": "Architecture of the model (Use TIMM names)" @@ -54,6 +61,7 @@ } }, "required": [ + "spec_version", "architecture", "num_classes", "patch_size_pixels", From baa32b8509a09dc8bb041fc715671c72fd4d0579 Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Wed, 21 Jun 2023 12:42:06 -0400 Subject: [PATCH 09/10] add patchcam and kather100k models to zoo --- wsinfer-zoo-registry.json | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/wsinfer-zoo-registry.json b/wsinfer-zoo-registry.json index ad02636..a0f73ea 100644 --- a/wsinfer-zoo-registry.json +++ b/wsinfer-zoo-registry.json @@ -19,6 +19,16 @@ "description": "Prostate tumor", "hf_repo_id": "kaczmarj/prostate-tumor-resnet34.tcga-prad", "hf_revision": "main" + }, + "lymphnodes-tiatoolbox-resnet50.patchcamelyon": { + "description": "Lymph node metastasis (PatchCamelyon)", + "hf_repo_id": "kaczmarj/lymphnodes-tiatoolbox-resnet50.patchcamelyon", + "hf_revision": "main" + }, + "colorectal-tiatoolbox-resnet50.kather100k": { + "description": "Colorectal cancer tissue classification (Kather100K)", + "hf_repo_id": "kaczmarj/colorectal-tiatoolbox-resnet50.kather100k", + "hf_revision": "main" } } } From 2750ab6e53eb5fe071b2351b288bc23a447614de Mon Sep 17 00:00:00 2001 From: kaczmarj Date: Wed, 21 Jun 2023 15:18:33 -0400 Subject: [PATCH 10/10] make the cli wsinfer-zoo instead of wsinfer_zoo --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index c81a310..b79f3ee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ dev = [options.entry_points] console_scripts = - wsinfer_zoo = wsinfer_zoo.cli:cli + wsinfer-zoo = wsinfer_zoo.cli:cli [options.package_data] wsinfer =