Skip to content

Commit

Permalink
Fix: Simplify model content validation
Browse files Browse the repository at this point in the history
  • Loading branch information
hoh committed Oct 16, 2024
1 parent 2c79fca commit 8e7435e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 67 deletions.
74 changes: 9 additions & 65 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import logging
from copy import copy
from enum import Enum
from hashlib import sha256
from json import JSONDecodeError
from pathlib import Path
Expand Down Expand Up @@ -318,74 +317,19 @@ class ProgramMessage(BaseMessage):
content: ProgramContent
forgotten_by: Optional[List[str]] = None

@staticmethod
def normalize_content(content: Union[Dict[str, Any], Any]) -> Any:
    """
    Normalize a dumped `content` dictionary so that it can be compared with
    the JSON-decoded `item_content` of a message.

    Anything that is not a dictionary is returned unchanged. For a
    dictionary, a new dictionary is built in which each value is handled
    as follows:

    - `ItemHash` instances are converted to `str`.
    - `Enum` instances are replaced by their `value`.
    - Lists:
        - under the key "volumes", a list made only of strings is returned
          as is;
        - otherwise each element is normalized recursively.
    - Dictionaries:
        - under the key "size_mib", a non-empty dictionary is replaced by
          its first value;
        - otherwise (including an empty "size_mib" dictionary) the
          dictionary is normalized recursively.
    - Any other value is returned unchanged.

    Args:
        content: The dictionary (or any other value) to normalize.

    Returns:
        The normalized content. The input dictionary itself is not modified
        at the top level; a new dictionary is returned.
    """
    if not isinstance(content, dict):
        return content

    def handle_value(key: str, value: Any) -> Any:
        if isinstance(value, ItemHash):
            return str(value)
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, list):
            if key == "volumes" and all(isinstance(v, str) for v in value):
                return value
            return [ProgramMessage.normalize_content(v) for v in value]
        if isinstance(value, dict):
            # Guard on emptiness: indexing the first value of an empty
            # dictionary would raise IndexError instead of letting the
            # caller report a content mismatch.
            if key == "size_mib" and value:
                return next(iter(value.values()))
            return ProgramMessage.normalize_content(value)
        return value

    return {key: handle_value(key, value) for key, value in content.items()}

# NOTE(review): indentation is lost in this view, and the body below appears
# to interleave the older `normalize_content`-based comparison with the newer
# direct `model_dump` comparison — confirm against the repository which lines
# are live before editing.
@field_validator("content")
def check_content(cls, v, values):
"""Ensure that the content of the message is correctly formatted."""
item_type = values.data.get("item_type")
if item_type == ItemType.inline:
# Decode item_content as JSON before comparing it with the content
item_content = json.loads(values.data.get("item_content"))

# Normalizing content to fit the structure of item_content
normalized_content = cls.normalize_content(v.model_dump(exclude_none=True))

if normalized_content != item_content:
# Log both sides at debug level to help debugging
logger.debug("Differences found between content and item_content")
logger.debug(f"Content: {normalized_content}")
logger.debug(f"Item Content: {item_content}")
# Ensure that the dumped content (None fields excluded) equals item_content
if v.model_dump(exclude_none=True) != item_content:
logger.warning(
"Content and item_content differ for message %s",
values.data["item_hash"],
)
raise ValueError("Content and item_content differ")
return v

Expand Down Expand Up @@ -435,7 +379,7 @@ def parse_message(message_dict: Dict) -> AlephMessage:
raise ValueError(f"Unknown message type {message_dict['type']}")


def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):
def add_item_content_and_hash(message_dict: Dict, inplace: bool = False) -> Dict:
if not inplace:
message_dict = copy(message_dict)

Expand Down
5 changes: 5 additions & 0 deletions aleph_message/models/item_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import lru_cache

from pydantic import GetCoreSchemaHandler
from pydantic.functional_serializers import model_serializer
from pydantic_core import core_schema

from ..exceptions import UnknownHashError
Expand Down Expand Up @@ -35,6 +36,10 @@ def is_storage(cls, item_hash: str):
def is_ipfs(cls, item_hash: str):
return cls.from_hash(item_hash) == cls.ipfs

# Return this object's `value` as its string representation.
# NOTE(review): the enclosing class is not visible in this view; confirm that
# pydantic's `model_serializer` decorator has the intended effect on it.
@model_serializer
def __str__(self):
return self.value


class ItemHash(str):
item_type: ItemType
Expand Down
36 changes: 34 additions & 2 deletions aleph_message/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from os import listdir
from os.path import isdir, join
from pathlib import Path
from unittest import mock

import pytest
import requests
Expand All @@ -15,6 +16,7 @@
AggregateMessage,
ForgetMessage,
InstanceMessage,
ItemHash,
ItemType,
MessagesResponse,
MessageType,
Expand All @@ -25,7 +27,7 @@
create_message_from_file,
create_message_from_json,
create_new_message,
parse_message, ItemHash,
parse_message,
)
from aleph_message.models.execution.environment import AMDSEVPolicy
from aleph_message.models.execution.instance import RootfsVolume
Expand Down Expand Up @@ -362,7 +364,9 @@ def test_volume_size_constraints():
# Use partial function to avoid repeating the same code
create_test_rootfs = partial(
RootfsVolume,
parent=ParentVolume(ref=ItemHash("QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW")),
parent=ParentVolume(
ref=ItemHash("QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW")
),
persistence=VolumePersistence.store,
)

Expand All @@ -379,6 +383,34 @@ def test_volume_size_constraints():
_ = create_test_rootfs(size_mib=size_mib_rootfs + 1)


def test_program_message_content_and_item_content_differ():
    """Validation must fail when `content` no longer matches `item_content`."""
    fake_digest = "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe"

    # Load a program message fixture as a JSON-compatible dict and let the
    # helper fill in its item_content / item_hash in place.
    fixture_path = Path(__file__).parent / "messages/machine.json"
    with open(fixture_path) as fixture_file:
        message_dict_original = json.load(fixture_file)
    message_dict: dict = add_item_content_and_hash(
        message_dict_original, inplace=True
    )

    # Stand-in for sha256 that always yields the fake digest; without it an
    # error is raised before the content comparison is reached.
    sha256_stub = mock.MagicMock()
    sha256_stub.hexdigest.return_value = fake_digest
    message_dict["item_hash"] = fake_digest

    # Make the content diverge from the already-serialized item_content.
    message_dict["content"]["replaces"] = "does-not-exist"

    with mock.patch("aleph_message.models.sha256", return_value=sha256_stub):
        with pytest.raises(ValidationError) as excinfo:
            ProgramMessage.model_validate(message_dict)
        # Checked after the raises block so that the assertion actually runs.
        assert "Content and item_content differ" in str(excinfo.value)


@pytest.mark.slow
@pytest.mark.skipif(not isdir(MESSAGES_STORAGE_PATH), reason="No file on disk to test")
def test_messages_from_disk():
Expand Down

0 comments on commit 8e7435e

Please sign in to comment.