Skip to content

Commit

Permalink
ci: run mypy w/ pre-commit; fix lingering typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Feb 24, 2024
1 parent cecab98 commit 73fe8e3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ repos:
rev: v0.2.2
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.8.0'
hooks:
- id: mypy
additional_dependencies: [pydantic, types-requests, types-pytz, types-setuptools, types-urllib3, StrEnum]
12 changes: 11 additions & 1 deletion horde_model_reference/legacy/get_all_filesizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
from loguru import logger

from horde_model_reference.legacy.classes.raw_legacy_model_database_records import (
RawLegacy_FileRecord,
RawLegacy_StableDiffusion_ModelRecord,
)
from horde_model_reference.path_consts import AIWORKER_CACHE_HOME


def get_all_file_sizes(sd_db: Path, write_to_path: Path | None = None) -> bool:
def get_all_file_sizes(sd_db: Path, write_to_path: Path | str) -> bool:

if AIWORKER_CACHE_HOME is None:
logger.error("AIWORKER_CACHE_HOME is not set.")
return False

raw_json_sd_db: str
with open(sd_db) as sd_db_file:
raw_json_sd_db = sd_db_file.read()
Expand All @@ -28,6 +34,10 @@ def get_all_file_sizes(sd_db: Path, write_to_path: Path | None = None) -> bool:
}

for _, model_details in parsed_db_records.items():
if not isinstance(model_details.config["files"][0], RawLegacy_FileRecord):
logger.error(f"File {model_details.config['files'][0]} is not a valid file record.")
continue

filename = model_details.config["files"][0].path

full_file_path = Path(AIWORKER_CACHE_HOME) / "compvis" / filename
Expand Down
8 changes: 4 additions & 4 deletions horde_model_reference/model_reference_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ class Generic_ModelReference(RootModel[Mapping[str, Generic_ModelRecord]]):
class StableDiffusion_ModelReference(Generic_ModelReference):
"""The combined metadata and model list."""

_baseline: dict[STABLE_DIFFUSION_BASELINE_CATEGORY, int] = PrivateAttr(default_factory=dict)
_baseline: dict[STABLE_DIFFUSION_BASELINE_CATEGORY | str, int] = PrivateAttr(default_factory=dict)
"""A dictionary of all the baseline types and how many models use them."""
_styles: dict[MODEL_STYLE, int] = PrivateAttr(default_factory=dict)
_styles: dict[MODEL_STYLE | str, int] = PrivateAttr(default_factory=dict)
"""A dictionary of all the styles and how many models use them."""
_tags: dict[str, int] = PrivateAttr(default_factory=dict)
"""A dictionary of all the tags and how many models use them."""
Expand Down Expand Up @@ -200,13 +200,13 @@ def rebuild_metadata(self) -> None:
self._models_hosts[host] = self._models_hosts.get(host, 0) + 1

@property
def baseline(self) -> dict[STABLE_DIFFUSION_BASELINE_CATEGORY, int]:
def baseline(self) -> dict[STABLE_DIFFUSION_BASELINE_CATEGORY | str, int]:
"""Return a dictionary of all the baseline types and how many models use them."""
self.check_was_models_modified()
return self._baseline

@property
def styles(self) -> dict[MODEL_STYLE, int]:
def styles(self) -> dict[MODEL_STYLE | str, int]:
"""Return a dictionary of all the styles and how many models use them."""
self.check_was_models_modified()
return self._styles
Expand Down

0 comments on commit 73fe8e3

Please sign in to comment.