Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reverted Pydantic v2 migration #114

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 47 additions & 57 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import datetime
import json
import logging
from copy import copy
from hashlib import sha256
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Extra, Field, validator
from typing_extensions import TypeAlias

from .abstract import BaseContent, HashableModel
Expand All @@ -17,10 +16,6 @@
from .execution.program import ProgramContent
from .item_hash import ItemHash, ItemType

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


__all__ = [
"AggregateContent",
"AggregateMessage",
Expand Down Expand Up @@ -59,7 +54,8 @@ class MongodbId(BaseModel):

oid: str = Field(alias="$oid")

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid


class ChainRef(BaseModel):
Expand All @@ -80,7 +76,8 @@ class MessageConfirmationHash(BaseModel):
binary: str = Field(alias="$binary")
type: str = Field(alias="$type")

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid


class MessageConfirmation(BaseModel):
Expand All @@ -96,13 +93,15 @@ class MessageConfirmation(BaseModel):
default=None, description="The address that published the transaction."
)

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid


class AggregateContentKey(BaseModel):
name: str

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid


class PostContent(BaseContent):
Expand All @@ -117,15 +116,16 @@ class PostContent(BaseContent):
)
type: str = Field(description="User-generated 'content-type' of a POST message")

@field_validator("type")
@validator("type")
def check_type(cls, v, values):
if v == "amend":
ref = values.data.get("ref")
ref = values.get("ref")
if not ref:
raise ValueError("A 'ref' is required for POST type 'amend'")
return v

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid


class AggregateContent(BaseContent):
Expand All @@ -136,7 +136,8 @@ class AggregateContent(BaseContent):
)
content: Dict = Field(description="The content of an aggregate must be a dict")

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid


class StoreContent(BaseContent):
Expand All @@ -147,11 +148,10 @@ class StoreContent(BaseContent):
size: Optional[int] = None # Generated by the node on storage
content_type: Optional[str] = None # Generated by the node on storage
ref: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Metadata of the VM"
)
metadata: Optional[Dict[str, Any]] = Field(description="Metadata of the VM")

model_config = ConfigDict(extra="allow")
class Config:
extra = Extra.allow


class ForgetContent(BaseContent):
Expand Down Expand Up @@ -214,9 +214,9 @@ class BaseMessage(BaseModel):

forgotten_by: Optional[List[str]]

@field_validator("item_content")
@validator("item_content")
def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
item_type = values.data.get("item_type")
item_type = values["item_type"]
if v is None:
return None
elif item_type == ItemType.inline:
Expand All @@ -232,14 +232,14 @@ def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
)
return v

@field_validator("item_hash")
@validator("item_hash")
def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
item_type = values.data.get("item_type")
item_type = values["item_type"]
if item_type == ItemType.inline:
item_content: str = values.data.get("item_content")
item_content: str = values["item_content"]

# Double check that the hash function is supported
hash_type = values.data.get("hash_type") or HashType.sha256
hash_type = values["hash_type"] or HashType.sha256
assert hash_type.value == HashType.sha256

computed_hash: str = sha256(item_content.encode()).hexdigest()
Expand All @@ -255,56 +255,49 @@ def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
assert item_type == ItemType.storage
return v

@field_validator("confirmed")
@validator("confirmed")
def check_confirmed(cls, v, values):
confirmations = values.data.get("confirmations")
confirmations = values["confirmations"]
if v is True and not bool(confirmations):
raise ValueError("Message cannot be 'confirmed' without 'confirmations'")
return v

@field_validator("time")
@validator("time")
def convert_float_to_datetime(cls, v, values):
if isinstance(v, float):
v = datetime.datetime.fromtimestamp(v)
assert isinstance(v, datetime.datetime)
return v

model_config = ConfigDict(extra="forbid")

def custom_dump(self):
"""Exclude MongoDB identifiers from dumps for historical reasons."""
return self.model_dump(exclude={"id_", "_id"})
class Config:
extra = Extra.forbid
exclude = {"id_", "_id"}


class PostMessage(BaseMessage):
"""Unique data posts (unique data points, events, ...)"""

type: Literal[MessageType.post]
content: PostContent
forgotten_by: Optional[List[str]] = None


class AggregateMessage(BaseMessage):
"""A key-value storage specific to an address"""

type: Literal[MessageType.aggregate]
content: AggregateContent
forgotten_by: Optional[list] = None


class StoreMessage(BaseMessage):
type: Literal[MessageType.store]
content: StoreContent
forgotten_by: Optional[list] = None
metadata: Optional[Dict[str, Any]] = None


class ForgetMessage(BaseMessage):
type: Literal[MessageType.forget]
content: ForgetContent
forgotten_by: Optional[list] = None

@field_validator("forgotten_by")
@validator("forgotten_by")
def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[str]]:
assert values
if v:
Expand All @@ -315,29 +308,25 @@ def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[st
class ProgramMessage(BaseMessage):
type: Literal[MessageType.program]
content: ProgramContent
forgotten_by: Optional[List[str]] = None

@field_validator("content")
@validator("content")
def check_content(cls, v, values):
"""Ensure that the content of the message is correctly formatted."""
item_type = values.data.get("item_type")
item_type = values["item_type"]
if item_type == ItemType.inline:
# Ensure that the content correct JSON
item_content = json.loads(values.data.get("item_content"))
# Ensure that the content matches the expected structure
if v.model_dump(exclude_none=True) != item_content:
logger.warning(
"Content and item_content differ for message %s",
values.data["item_hash"],
)
item_content = json.loads(values["item_content"])
if v.dict(exclude_none=True) != item_content:
# Print differences
vdict = v.dict(exclude_none=True)
for key, value in item_content.items():
if vdict[key] != value:
print(f"{key}: {vdict[key]} != {value}")
raise ValueError("Content and item_content differ")
return v


class InstanceMessage(BaseMessage):
type: Literal[MessageType.instance]
content: InstanceContent
forgotten_by: Optional[List[str]] = None


AlephMessage: TypeAlias = Union[
Expand Down Expand Up @@ -374,12 +363,12 @@ def parse_message(message_dict: Dict) -> AlephMessage:
message_class.__annotations__["type"].__args__[0]
)
if message_dict["type"] == message_type:
return message_class.model_validate(message_dict)
return message_class.parse_obj(message_dict)
else:
raise ValueError(f"Unknown message type {message_dict['type']}")


def add_item_content_and_hash(message_dict: Dict, inplace: bool = False) -> Dict:
def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):
if not inplace:
message_dict = copy(message_dict)

Expand All @@ -401,7 +390,7 @@ def create_new_message(
"""
message_content = add_item_content_and_hash(message_dict)
if factory:
return cast(T, factory.model_validate(message_content))
return cast(T, factory.parse_obj(message_content))
else:
return cast(T, parse_message(message_content))

Expand All @@ -416,7 +405,7 @@ def create_message_from_json(
message_dict = json.loads(json_data)
message_content = add_item_content_and_hash(message_dict, inplace=True)
if factory:
return factory.model_validate(message_content)
return factory.parse_obj(message_content)
else:
return parse_message(message_content)

Expand All @@ -433,7 +422,7 @@ def create_message_from_file(
message_dict = decoder.load(fd)
message_content = add_item_content_and_hash(message_dict, inplace=True)
if factory:
return factory.model_validate(message_content)
return factory.parse_obj(message_content)
else:
return parse_message(message_content)

Expand All @@ -447,4 +436,5 @@ class MessagesResponse(BaseModel):
pagination_per_page: int
pagination_item: str

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid
5 changes: 3 additions & 2 deletions aleph_message/models/abstract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, Extra


def hashable(obj):
Expand All @@ -24,4 +24,5 @@ class BaseContent(BaseModel):
address: str
time: float

model_config = ConfigDict(extra="forbid")
class Config:
extra = Extra.forbid
2 changes: 1 addition & 1 deletion aleph_message/models/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .abstract import BaseExecutableContent
from .base import Encoding, Interface, MachineType, Payment, PaymentType
from .instance import InstanceContent
from .program import ProgramContent
from .base import Encoding, MachineType, PaymentType, Payment, Interface

__all__ = [
"BaseExecutableContent",
Expand Down
2 changes: 1 addition & 1 deletion aleph_message/models/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Payment(HashableModel):

chain: Chain
"""Which chain to check for funds"""
receiver: Optional[str] = None
receiver: Optional[str]
"""Optional alternative address to send tokens to"""
type: PaymentType
"""Whether to pay by holding $ALEPH or by streaming tokens"""
Expand Down
Loading
Loading