Skip to content

Commit

Permalink
fix runtime errors, since schema annotation breaks python runtime!
Browse files Browse the repository at this point in the history
- it mangles __type and __type into <classs_name>__type, so these are two properties, so the constructor fails
- added narrow_type extra schema, it refines a type to a const enum for the schema (the resulting schema is teh same as before)
  • Loading branch information
Totto16 committed Aug 9, 2023
1 parent f417d1d commit 604a8a3
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 34 deletions.
5 changes: 2 additions & 3 deletions src/content/collection_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
NameParser,
ScannedFile,
Summary,
deduplicate_required,
narrow_type,
)
from content.series_content import SeriesContent

Expand All @@ -34,10 +34,9 @@ class CollectionContentDict(ContentDict):
series: list[SeriesContent]


@schema(extra=deduplicate_required)
@schema(extra=narrow_type(("type", Literal[ContentType.collection])))
@dataclass(slots=True, repr=True)
class CollectionContent(Content):
__type: Literal[ContentType.collection] = field(metadata=alias("type"))
__description: CollectionDescription = field(metadata=alias("description"))
__series: list[SeriesContent] = field(metadata=alias("series"))

Expand Down
5 changes: 2 additions & 3 deletions src/content/episode_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
NameParser,
ScannedFile,
Summary,
deduplicate_required,
narrow_type,
)


Expand Down Expand Up @@ -48,10 +48,9 @@ def itr_print_percent() -> None:
print(f"{percent:.02f} %")


@schema(extra=deduplicate_required)
@schema(extra=narrow_type(("type", Literal[ContentType.episode])))
@dataclass(slots=True, repr=True)
class EpisodeContent(Content):
__type: Literal[ContentType.episode] = field(metadata=alias("type"))
__description: EpisodeDescription = field(metadata=alias("description"))
__language: Language = field(metadata=alias("language"))

Expand Down
65 changes: 57 additions & 8 deletions src/content/general.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from collections.abc import Callable, Mapping, MutableMapping
from dataclasses import dataclass, field
from enum import Enum
from hashlib import sha256
from pathlib import Path
from typing import Any, Generic, Optional, Self, TypedDict, TypeVar
from typing import (
Any,
Generic,
Optional,
Self,
TypedDict,
TypeVar,
cast,
)

from apischema import schema
from apischema.json_schema import (
deserialization_schema,
serialization_schema,
)
from classifier import Language
from enlighten import Manager

Expand Down Expand Up @@ -489,10 +502,46 @@ def safe_index(ls: list[SF], item: SF) -> Optional[int]:
return None


def deduplicate_required(schema: dict[str, Any]) -> None:
if schema.get("required") is not None and isinstance(schema["required"], list):
result: list[Any] = []
for element in schema["required"]:
if element not in result:
result.append(element)
schema["required"] = result
def get_schema(
any_type: Any,
*,
additional_properties: Optional[bool] = None,
all_refs: Optional[bool] = None,
) -> MutableMapping[str, Any]:
result: Mapping[str, Any] = deserialization_schema(
any_type,
additional_properties=additional_properties,
all_refs=all_refs,
)

result2 = serialization_schema(
any_type,
additional_properties=additional_properties,
all_refs=all_refs,
)

if result != result2:
msg = "Deserialization and Serialization scheme mismatch"
raise RuntimeError(msg)

return cast(MutableMapping[str, Any], result)


def narrow_type(replace: tuple[str, Any]) -> Callable[[dict[str, Any]], None]:
name, type_desc = replace

def narrow_schema(schema: dict[str, Any]) -> None:
if schema.get("properties") is not None and isinstance(
schema["properties"],
dict,
):
resulting_type: MutableMapping[str, Any] = get_schema(type_desc)
del resulting_type["$schema"]

if cast(dict[str, Any], schema["properties"]).get(name) is None:
msg = f"Narrowing type failed, type is not present. key '{name}'"
raise RuntimeError(msg)

schema["properties"][name] = resulting_type

return narrow_schema
7 changes: 2 additions & 5 deletions src/content/season_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ScannedFile,
SeasonDescription,
Summary,
deduplicate_required,
narrow_type,
)


Expand All @@ -35,12 +35,9 @@ class SeasonContentDict(ContentDict):
episodes: list[EpisodeContent]


@schema(extra=deduplicate_required)
@schema(extra=narrow_type(("type", Literal[ContentType.season])))
@dataclass(slots=True, repr=True)
class SeasonContent(Content):
__type: Literal[ContentType.season] = field(
metadata=alias("type"),
) # TODO: submit upstream path, to allow this: (to not add "type" in the required field twice)
__description: SeasonDescription = field(metadata=alias("description"))
__episodes: list[EpisodeContent] = field(metadata=alias("episodes"))

Expand Down
7 changes: 4 additions & 3 deletions src/content/series_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
ScannedFile,
SeriesDescription,
Summary,
deduplicate_required,
narrow_type,
)
from content.season_content import SeasonContent

Expand All @@ -35,10 +35,11 @@ class SeriesContentDict(ContentDict):
seasons: list[SeasonContent]


@schema(extra=deduplicate_required)


@schema(extra=narrow_type(("type", Literal[ContentType.series])))
@dataclass(slots=True, repr=True)
class SeriesContent(Content):
__type: Literal[ContentType.series] = field(metadata=alias("type"))
__description: SeriesDescription = field(metadata=alias("description"))
__seasons: list[SeasonContent] = field(metadata=alias("seasons"))

Expand Down
20 changes: 8 additions & 12 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from typing import TYPE_CHECKING, Annotated, Any, Optional, Self, TypedDict

from apischema import deserialize, schema, serialize
from apischema.json_schema import (
deserialization_schema,
serialization_schema,
)
from classifier import Classifier
from content.base_class import Content, ContentCharacteristic, process_folder
from content.collection_content import CollectionContent
from content.episode_content import EpisodeContent
from content.general import Callback, ContentType, NameParser, ScannedFileType
from content.general import (
Callback,
ContentType,
NameParser,
ScannedFileType,
get_schema,
)
from content.scan_helpers import content_from_scan
from content.season_content import SeasonContent
from content.series_content import SeriesContent
Expand Down Expand Up @@ -285,18 +287,12 @@ def parse_contents(


def generate_json_schema(file_path: Path, any_type: Any) -> None:
result: Mapping[str, Any] = deserialization_schema(
result: Mapping[str, Any] = get_schema(
any_type,
additional_properties=False,
all_refs=True,
)

result2 = serialization_schema(any_type, additional_properties=False, all_refs=True)

if result != result2:
msg = "Deserialization and Serialization scheme mismatch"
raise RuntimeError(msg)

if not file_path.parent.exists():
Path(file_path).parent.mkdir(parents=True)

Expand Down

0 comments on commit 604a8a3

Please sign in to comment.