Skip to content

Commit

Permalink
Fix: Solved some of new mypy issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
Andres D. Molins committed Jul 9, 2024
1 parent ca5826b commit ea916ee
Show file tree
Hide file tree
Showing 15 changed files with 93 additions and 77 deletions.
5 changes: 3 additions & 2 deletions src/aleph/chains/chain_data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union, Self

import aio_pika.abc
from aleph_message.models import StoreContent, ItemType, Chain, MessageType
from aleph_message.models import StoreContent, ItemType, Chain, MessageType, ItemHash
from configmanager import Config
from pydantic import ValidationError

Expand Down Expand Up @@ -187,7 +187,8 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An
address=payload.address,
time=payload.timestamp_seconds,
item_type=ItemType.ipfs,
item_hash=payload.content,
item_hash=ItemHash(payload.content),
metadata=None,
)
item_content = content.json(exclude_none=True)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/chains/ethereum.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def _request_transactions(
try:
jdata = json.loads(message)
context = TxContext(
chain=CHAIN_NAME,
chain=Chain(CHAIN_NAME),
hash=event_data.transactionHash.hex(),
time=timestamp,
height=event_data.blockNumber,
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/chains/nuls2.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _request_transactions(
jdata = json.loads(ddata)

context = TxContext(
chain=CHAIN_NAME,
chain=Chain(CHAIN_NAME),
hash=tx["hash"],
height=tx["height"],
time=tx["createTime"],
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/db/accessors/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def reject_existing_pending_message(

# The message may already be processed and someone is sending invalid copies.
# Just drop the pending message.
message_status = get_message_status(session=session, item_hash=item_hash)
message_status = get_message_status(session=session, item_hash=ItemHash(item_hash))
if message_status:
if message_status.status not in (MessageStatus.PENDING, MessageStatus.REJECTED):
delete_pending_message(session=session, pending_message=pending_message)
Expand Down
6 changes: 3 additions & 3 deletions src/aleph/handlers/content/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def _insert_aggregate_element(session: DbSession, message: MessageDb):
content = cast(AggregateContent, message.parsed_content)
aggregate_element = AggregateElementDb(
item_hash=message.item_hash,
key=content.key,
key=str(content.key),
owner=content.address,
content=content.content,
creation_datetime=timestamp_to_datetime(message.parsed_content.time),
Expand Down Expand Up @@ -228,10 +228,10 @@ async def forget_message(self, session: DbSession, message: MessageDb) -> Set[st
key = content.key

LOGGER.debug("Deleting aggregate element %s...", message.item_hash)
delete_aggregate(session=session, owner=owner, key=key)
delete_aggregate(session=session, owner=owner, key=str(key))
delete_aggregate_element(session=session, item_hash=message.item_hash)

LOGGER.debug("Refreshing aggregate %s/%s...", owner, key)
refresh_aggregate(session=session, owner=owner, key=key)
refresh_aggregate(session=session, owner=owner, key=str(key))

return set()
4 changes: 2 additions & 2 deletions src/aleph/handlers/content/forget.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def _forget_message(
async def _forget_item_hash(
self, session: DbSession, item_hash: str, forgotten_by: MessageDb
):
message_status = get_message_status(session=session, item_hash=item_hash)
message_status = get_message_status(session=session, item_hash=ItemHash(item_hash))
if not message_status:
raise ForgetTargetNotFound(target_hash=item_hash)

Expand All @@ -187,7 +187,7 @@ async def _forget_item_hash(
)
raise ForgetTargetNotFound(item_hash)

message = get_message_by_item_hash(session=session, item_hash=item_hash)
message = get_message_by_item_hash(session=session, item_hash=ItemHash(item_hash))
if not message:
raise ForgetTargetNotFound(item_hash)

Expand Down
3 changes: 2 additions & 1 deletion src/aleph/handlers/content/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ async def process_post(self, session: DbSession, message: MessageDb):
if (
content.type == self.balances_post_type
and content.address in self.balances_addresses
and content.content
):
LOGGER.info("Updating balances...")
update_balances(session=session, content=content.content)
Expand All @@ -150,7 +151,7 @@ async def forget_message(self, session: DbSession, message: MessageDb) -> Set[st
delete_post(session=session, item_hash=message.item_hash)

if content.type == "amend":
original_post = get_original_post(session, content.ref)
original_post = get_original_post(session, str(content.ref))
if original_post is None:
raise InternalError(
f"Could not find original post ({content.ref} for amend ({message.item_hash})."
Expand Down
65 changes: 34 additions & 31 deletions src/aleph/handlers/content/vm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import logging
import math
from typing import List, Set, overload, Protocol, Optional
from typing import List, Set, overload, Protocol, Optional, Union

from aleph_message.models import (
ProgramContent,
ExecutableContent,
InstanceContent,
MessageType,
)
from aleph_message.models.execution.instance import RootfsVolume
from aleph_message.models.execution.program import (
MachineType,
ProgramContent,
)
from aleph_message.models.execution.volume import (
AbstractVolume,
ImmutableVolume,
Expand Down Expand Up @@ -72,19 +77,13 @@

def _get_vm_content(message: MessageDb) -> ExecutableContent:
content = message.parsed_content
if not isinstance(content, ExecutableContent):
if not isinstance(content, (InstanceContent, ProgramContent)):
raise InvalidMessageFormat(
f"Unexpected content type for program message: {message.item_hash}"
)
return content


from aleph_message.models.execution.program import (
MachineType,
ProgramContent,
)


@overload
def _map_content_to_db_model(
item_hash: str, content: InstanceContent
Expand All @@ -95,7 +94,7 @@ def _map_content_to_db_model(
# This seems linked to multiple inheritance of Pydantic base models, a deeper investigation
# is required.
@overload
def _map_content_to_db_model(item_hash: str, content: ProgramContent) -> ProgramDb: # type: ignore[misc]
def _map_content_to_db_model(item_hash: str, content: ProgramContent) -> ProgramDb:
...


Expand Down Expand Up @@ -186,7 +185,7 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb:
content = _get_vm_content(message)
vm = _map_content_to_db_model(message.item_hash, content)

if isinstance(vm, ProgramDb):
if isinstance(vm, ProgramContent):
vm.program_type = content.type
vm.persistent = bool(content.on.persistent)
vm.http_trigger = content.on.http
Expand Down Expand Up @@ -222,13 +221,16 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb:

elif isinstance(content, InstanceContent):
parent = content.rootfs.parent
vm.rootfs = RootfsVolumeDb(
parent_ref=parent.ref,
parent_use_latest=parent.use_latest,
size_mib=content.rootfs.size_mib,
persistence=content.rootfs.persistence,
)
vm.authorized_keys = content.authorized_keys
if isinstance(vm, VmInstanceDb):
vm.rootfs = RootfsVolumeDb(
parent_ref=parent.ref,
parent_use_latest=parent.use_latest,
size_mib=content.rootfs.size_mib,
persistence=content.rootfs.persistence,
)
vm.authorized_keys = content.authorized_keys
else:
raise TypeError(f"Unexpected VM message content type: {type(vm)}")

else:
raise TypeError(f"Unexpected VM message content type: {type(content)}")
Expand Down Expand Up @@ -279,18 +281,18 @@ def check_parent_volumes_size_requirements(
) -> None:
def _get_parent_volume_file(_parent: ParentVolume) -> StoredFileDb:
if _parent.use_latest:
file_tag = get_file_tag(session=session, tag=_parent.ref)
file_tag = get_file_tag(session=session, tag=FileTag(_parent.ref))
if file_tag is None:
raise InternalError(
f"Could not find latest version of parent volume {volume.parent.ref}"
f"Could not find latest version of parent volume {_parent.ref}"
)

return file_tag.file

file_pin = get_message_file_pin(session=session, item_hash=_parent.ref)
if file_pin is None:
raise InternalError(
f"Could not find original version of parent volume {volume.parent.ref}"
f"Could not find original version of parent volume {_parent.ref}"
)

return file_pin.file
Expand All @@ -299,7 +301,7 @@ class HasParent(Protocol):
parent: ParentVolume
size_mib: int

volumes_with_parent: List[HasParent] = [
volumes_with_parent: List[Union[PersistentVolume, RootfsVolume]] = [
volume
for volume in content.volumes
if isinstance(volume, PersistentVolume) and volume.parent is not None
Expand All @@ -309,16 +311,17 @@ class HasParent(Protocol):
volumes_with_parent.append(content.rootfs)

for volume in volumes_with_parent:
volume_metadata = _get_parent_volume_file(volume.parent)
volume_size = volume.size_mib * 1024 * 1024
if volume_size < volume_metadata.size:
raise VmVolumeTooSmall(
parent_size=volume_metadata.size,
parent_ref=volume.parent.ref,
parent_file=volume_metadata.hash,
volume_name=getattr(volume, "name", "rootfs"),
volume_size=volume_size,
)
if volume.parent:
volume_metadata = _get_parent_volume_file(volume.parent)
volume_size = volume.size_mib * 1024 * 1024
if volume_size < volume_metadata.size:
raise VmVolumeTooSmall(
parent_size=volume_metadata.size,
parent_ref=volume.parent.ref,
parent_file=volume_metadata.hash,
volume_name=getattr(volume, "name", "rootfs"),
volume_size=volume_size,
)


class VmMessageHandler(ContentHandler):
Expand Down
3 changes: 2 additions & 1 deletion src/aleph/handlers/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import ValidationError
from sqlalchemy import insert

from aleph_message.models import ItemHash
from aleph.chains.signature_verifier import SignatureVerifier
from aleph.db.accessors.files import insert_content_file_pin, upsert_file
from aleph.db.accessors.messages import (
Expand Down Expand Up @@ -377,7 +378,7 @@ async def process(
"""

existing_message = get_message_by_item_hash(
session=session, item_hash=pending_message.item_hash
session=session, item_hash=ItemHash(pending_message.item_hash)
)
if existing_message:
await self.confirm_existing_message(
Expand Down
19 changes: 10 additions & 9 deletions src/aleph/schemas/api/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Dict,
Mapping,
Annotated,
Type,
)

from aleph_message.models import (
Expand Down Expand Up @@ -67,38 +68,38 @@ class Config:


class AggregateMessage(
BaseMessage[Literal[MessageType.aggregate], AggregateContent] # type: ignore
BaseMessage[Literal[MessageType.aggregate], AggregateContent]
):
...


class ForgetMessage(
BaseMessage[Literal[MessageType.forget], ForgetContent] # type: ignore
BaseMessage[Literal[MessageType.forget], ForgetContent]
):
...


class InstanceMessage(BaseMessage[Literal[MessageType.instance], InstanceContent]): # type: ignore
class InstanceMessage(BaseMessage[Literal[MessageType.instance], InstanceContent]):
...


class PostMessage(BaseMessage[Literal[MessageType.post], PostContent]): # type: ignore
class PostMessage(BaseMessage[Literal[MessageType.post], PostContent]):
...


class ProgramMessage(
BaseMessage[Literal[MessageType.program], ProgramContent] # type: ignore
BaseMessage[Literal[MessageType.program], ProgramContent]
):
...


class StoreMessage(
BaseMessage[Literal[MessageType.store], StoreContent] # type: ignore
BaseMessage[Literal[MessageType.store], StoreContent]
):
...


MESSAGE_CLS_DICT = {
MESSAGE_CLS_DICT: Dict[Any, Type[AggregateMessage | ForgetMessage | InstanceMessage | PostMessage | ProgramMessage | StoreMessage]] = {
MessageType.aggregate: AggregateMessage,
MessageType.forget: ForgetMessage,
MessageType.instance: InstanceMessage,
Expand All @@ -125,13 +126,13 @@ def format_message(message: MessageDb) -> AlephMessage:
message_type = message.type

message_cls = MESSAGE_CLS_DICT[message_type]
return message_cls.from_orm(message) # type: ignore[return-value]
return message_cls.from_orm(message)


def format_message_dict(message: Dict[str, Any]) -> AlephMessage:
message_type = message.get("type")
message_cls = MESSAGE_CLS_DICT[message_type]
return message_cls.parse_obj(message) # type: ignore[return-value]
return message_cls.parse_obj(message)


class BaseMessageStatus(BaseModel):
Expand Down
18 changes: 9 additions & 9 deletions src/aleph/schemas/pending_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
in aleph-client.
"""

from typing import Any, Literal, Generic
from typing import Any, Literal, Generic, Dict, Type

from aleph_message.models import (
AggregateContent,
Expand Down Expand Up @@ -93,43 +93,43 @@ def load_content(cls, values):


class PendingAggregateMessage(
BasePendingMessage[Literal[MessageType.aggregate], AggregateContent] # type: ignore
BasePendingMessage[Literal[MessageType.aggregate], AggregateContent]
):
pass


class PendingForgetMessage(
BasePendingMessage[Literal[MessageType.forget], ForgetContent] # type: ignore
BasePendingMessage[Literal[MessageType.forget], ForgetContent]
):
pass


class PendingInstanceMessage(
BasePendingMessage[Literal[MessageType.instance], InstanceContent] # type: ignore
BasePendingMessage[Literal[MessageType.instance], InstanceContent]
):
pass


class PendingPostMessage(BasePendingMessage[Literal[MessageType.post], PostContent]): # type: ignore
class PendingPostMessage(BasePendingMessage[Literal[MessageType.post], PostContent]):
pass


class PendingProgramMessage(
BasePendingMessage[Literal[MessageType.program], ProgramContent] # type: ignore
BasePendingMessage[Literal[MessageType.program], ProgramContent]
):
pass


class PendingStoreMessage(BasePendingMessage[Literal[MessageType.store], StoreContent]): # type: ignore
class PendingStoreMessage(BasePendingMessage[Literal[MessageType.store], StoreContent]):
pass


class PendingInlineStoreMessage(PendingStoreMessage):
item_content: str
item_type: Literal[ItemType.inline] # type: ignore[valid-type]
item_type: Literal[ItemType.inline]


MESSAGE_TYPE_TO_CLASS = {
MESSAGE_TYPE_TO_CLASS: Dict[Any, Type[PendingAggregateMessage | PendingForgetMessage | PendingInstanceMessage | PendingPostMessage | PendingProgramMessage | PendingStoreMessage]] = {
MessageType.aggregate: PendingAggregateMessage,
MessageType.forget: PendingForgetMessage,
MessageType.instance: PendingInstanceMessage,
Expand Down
Loading

0 comments on commit ea916ee

Please sign in to comment.