Skip to content

Commit

Permalink
Merge pull request #1323 from qstokkink/upd_dc_payload_type_hints
Browse files Browse the repository at this point in the history
Updated dataclass payloads to use a base class
  • Loading branch information
qstokkink authored Nov 19, 2024
2 parents 1592671 + d78dbac commit b99ba77
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 102 deletions.
7 changes: 4 additions & 3 deletions doc/basics/overlay_tutorial_5.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os
from asyncio import run
from dataclasses import dataclass

from ipv8.community import Community, CommunitySettings
from ipv8.configuration import ConfigBuilder, Strategy, WalkerDefinition, default_bootstrap_defs
from ipv8.lazy_community import lazy_wrapper
from ipv8.messaging.payload_dataclass import dataclass
from ipv8.messaging.payload_dataclass import DataClassPayload
from ipv8.types import Peer
from ipv8.util import run_forever
from ipv8_service import IPv8


@dataclass(msg_id=1) # The value 1 identifies this message and must be unique per community
class MyMessage:
@dataclass
class MyMessage(DataClassPayload[1]): # The value 1 identifies this message and must be unique per community
clock: int # We add an integer (technically a "long long") field "clock" to this message


Expand Down
8 changes: 4 additions & 4 deletions doc/basics/testbase_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,28 +206,28 @@ In the following example peer 0 first sends message 1 and then sends message 2 t
The following construction asserts this:

.. literalinclude:: testbase_tutorial_2.py
:lines: 65-68
:lines: 66-69
:dedent: 4

Sometimes, you can't be sure in what order messages are sent.
In these cases you can use ``ordered=False``:

.. literalinclude:: testbase_tutorial_2.py
:lines: 71-77
:lines: 72-78
:dedent: 4

In other cases, your overlay may be sending messages which you cannot control and/or which you don't care about.
In these cases you can set a filter to only include the messages you want:

.. literalinclude:: testbase_tutorial_2.py
:lines: 80-85
:lines: 81-86
:dedent: 4

It may also be helpful to inspect the contents of each payload.
You can simply use the return value of the assert function to perform further inspection:

.. literalinclude:: testbase_tutorial_2.py
:lines: 88-95
:lines: 89-96
:dedent: 4

If you want to use ``assertReceivedBy()``, make sure that:
Expand Down
15 changes: 8 additions & 7 deletions doc/basics/testbase_tutorial_2.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import os
import unittest
from dataclasses import dataclass
from random import random, shuffle

from ipv8.community import Community, CommunitySettings
from ipv8.lazy_community import lazy_wrapper, lazy_wrapper_unsigned
from ipv8.messaging.payload_dataclass import dataclass
from ipv8.messaging.payload_dataclass import DataClassPayload
from ipv8.test.base import TestBase
from ipv8.types import Peer


@dataclass(msg_id=1)
class Message1:
@dataclass
class Message1(DataClassPayload[1]):
value: int


@dataclass(msg_id=2)
class Message2:
@dataclass
class Message2(DataClassPayload[2]):
value: int


@dataclass(msg_id=3)
class Message3:
@dataclass
class Message3(DataClassPayload[3]):
value: int


Expand Down
10 changes: 5 additions & 5 deletions doc/reference/serialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ If the ``dataclass`` had used normal ``int`` types, these would have been two si
Each instance will have two fields: ``field1`` and ``field2`` corresponding to the integer and short.

.. literalinclude:: serialization_1.py
:lines: 9-61
:lines: 11-63


To show some of the differences, let's check out the output of the following script using these definitions:


.. literalinclude:: serialization_1.py
:lines: 64-75
:lines: 66-77


.. code-block:: bash
Expand Down Expand Up @@ -195,7 +195,7 @@ This method involves implementing the methods ``fix_pack_<your field name>`` and
Check out the following example:

.. literalinclude:: serialization_4.py
:lines: 11-36
:lines: 12-37

In both classes we create a message with a single field ``dictionary``.
To pack this field, we use ``json.dumps()`` to create a string representation of the dictionary.
Expand Down Expand Up @@ -250,9 +250,9 @@ You can specify them by using the ``"payload"`` datatype and setting the ``Paylo
For a ``VariablePayload`` this looks like the following example.

.. literalinclude:: serialization_7.py
:lines: 5-12
:lines: 7-14

For dataclass payloads this nesting is supported by simply specifying nested classes as follows.

.. literalinclude:: serialization_7.py
:lines: 15-24
:lines: 17-26
6 changes: 4 additions & 2 deletions doc/reference/serialization_1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass

from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
from ipv8.messaging.payload import Payload
from ipv8.messaging.payload_dataclass import dataclass, type_from_format
from ipv8.messaging.payload_dataclass import DataClassPayload, type_from_format
from ipv8.messaging.serialization import Serializable


Expand Down Expand Up @@ -56,7 +58,7 @@ class MyCVariablePayload(VariablePayload):


@dataclass
class MyDataclassPayload:
class MyDataclassPayload(DataClassPayload):
field1: I
field2: H

Expand Down
7 changes: 4 additions & 3 deletions doc/reference/serialization_4.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import cast

from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
from ipv8.messaging.payload_dataclass import dataclass
from ipv8.messaging.payload_dataclass import DataClassPayload
from ipv8.messaging.serialization import default_serializer


Expand All @@ -23,8 +24,8 @@ def fix_unpack_dictionary(cls: type[VPMessageKeepDict],
return json.loads(serialized_dictionary.decode())


@dataclass(msg_id=2)
class DCMessageKeepDict:
@dataclass
class DCMessageKeepDict(DataClassPayload[2]):
dictionary: str

def fix_pack_dictionary(self, the_dictionary: dict) -> str:
Expand Down
8 changes: 5 additions & 3 deletions doc/reference/serialization_7.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass

from ipv8.messaging.lazy_payload import VariablePayload
from ipv8.messaging.payload_dataclass import dataclass
from ipv8.messaging.payload_dataclass import DataClassPayload


class A(VariablePayload):
Expand All @@ -12,8 +14,8 @@ class B(VariablePayload):
names = ["a", "baz"]


@dataclass(msg_id=1)
class Message:
@dataclass
class Message(DataClassPayload[1]):
@dataclass
class Item:
foo: int
Expand Down
13 changes: 7 additions & 6 deletions ipv8/messaging/lazy_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,22 +230,23 @@ def vp_compile(vp_definition: type[T]) -> type[T]:
"""
# We use ``exec`` purposefully here, disable the pylint warning:
# ruff: noqa: B010, S102
local_scope = locals()

# Load the function definitions into the local scope.
exec(_compile_init(vp_definition.names, {
k: v.default
for k, v in inspect.signature(vp_definition.__init__).parameters.items()
if v.default is not inspect.Parameter.empty
}), globals(), locals())
exec(_compile_from_unpack_list(vp_definition, vp_definition.names), globals(), locals())
exec(_compile_to_pack_list(vp_definition, vp_definition.format_list, vp_definition.names), globals(), locals())
}), globals(), local_scope)
exec(_compile_from_unpack_list(vp_definition, vp_definition.names), globals(), local_scope)
exec(_compile_to_pack_list(vp_definition, vp_definition.format_list, vp_definition.names), globals(), local_scope)

# Rewrite the class methods from the locally loaded overwrites.
# from_unpack_list is a classmethod, so we need to scope it properly.
setattr(vp_definition, '__init__', locals()['__init__'])
setattr(vp_definition, '__init__', local_scope['__init__'])
setattr(vp_definition, '__match_args__', tuple(vp_definition.names))
setattr(vp_definition, 'from_unpack_list', types.MethodType(locals()['from_unpack_list'], vp_definition))
setattr(vp_definition, 'to_pack_list', locals()['to_pack_list'])
setattr(vp_definition, 'from_unpack_list', types.MethodType(local_scope['from_unpack_list'], vp_definition))
setattr(vp_definition, 'to_pack_list', local_scope['to_pack_list'])
return vp_definition


Expand Down
88 changes: 56 additions & 32 deletions ipv8/messaging/payload_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass as ogdataclass
from functools import partial
from typing import Callable, TypeVar, cast, get_args, get_type_hints
import dataclasses
import sys
from typing import Any, TypeVar, cast, get_args, get_type_hints

from .lazy_payload import VariablePayload, vp_compile
from typing_extensions import Self

from .lazy_payload import VariablePayload, VariablePayloadWID, vp_compile
from .serialization import FormatListType, Serializable


Expand Down Expand Up @@ -41,37 +42,60 @@ def type_map(t: type) -> FormatListType: # noqa: PLR0911
raise NotImplementedError(t, " unknown")


def dataclass(cls: type | None = None, *, # noqa: PLR0913
init: bool = True,
repr: bool = True, # noqa: A002
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
msg_id: int | None = None) -> partial[type[VariablePayload]] | type[VariablePayload]:
"""
Equivalent to ``@dataclass``, but also makes the wrapped class a ``VariablePayload``.
def convert_to_payload(dataclass_type: type, msg_id: int | None = None) -> None:
if msg_id is not None:
dataclass_type.msg_id = msg_id # type: ignore[attr-defined]
dt_fields = dataclasses.fields(dataclass_type)
type_hints = get_type_hints(dataclass_type)
dataclass_type.names = [field.name for field in dt_fields] # type: ignore[attr-defined]
dataclass_type.format_list = [type_map(type_hints[field.name]) for field in # type: ignore[attr-defined]
dt_fields]
setattr(sys.modules[dataclass_type.__module__], dataclass_type.__name__, vp_compile(dataclass_type))


See ``dataclasses.dataclass`` for argument descriptions.
class DataClassPayload(VariablePayload):
"""
A Payload that is defined as a dataclass.
"""
if cls is None:
# Forward user parameters. Format: ``@dataclass(foo=bar)``.
return partial(cast(Callable[..., type[VariablePayload]], dataclass), init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen, msg_id=msg_id)

# Finally, we have the actual class. Format: ``@dataclass`` or forwarded from partial (see above).
origin: type = ogdataclass(cls, init=init, repr=repr, eq=eq, order=order, # type: ignore[call-overload]
unsafe_hash=unsafe_hash, frozen=frozen)
def __class_getitem__(cls, item: int) -> DataClassPayloadWID:
"""
Syntactic sugar to add a msg_id attribute into the class inheritance structure.
class DataClassPayload(origin, VariablePayload):
names = list(get_type_hints(cls).keys())
format_list = list(map(type_map, cast(Iterable[type], get_type_hints(cls).values())))
|
if msg_id is not None:
setattr(DataClassPayload, "msg_id", msg_id) # noqa: B010
DataClassPayload.__name__ = cls.__name__
DataClassPayload.__qualname__ = cls.__qualname__
return vp_compile(DataClassPayload)
.. code-block::
class MyPayload(DataClassPayload[12]):
pass
assert MyPayload().msg_id == 12
:param item: The item to get, i.e., the message id
"""
return cast(DataClassPayloadWID, type(cls.__name__, (DataClassPayloadWID, ), {"msg_id": item}))

def __new__(cls, *args: Any, **kwargs) -> Self: # noqa: ANN401, ARG003
"""
Allocate memory for a new DataClassPayload class.
"""
out = super().__new__(cls)
convert_to_payload(cls)
return out


class DataClassPayloadWID(VariablePayloadWID):
"""
A Payload that is defined as a dataclass and has a message id [0, 255].
"""

def __new__(cls, *args: Any, **kwargs) -> Self: # noqa: ANN401, ARG003
"""
Allocate memory for a new DataClassPayloadWID class.
"""
out = super().__new__(cls)
convert_to_payload(cls, msg_id=cls.msg_id)
return out


__all__ = ['dataclass', 'type_from_format']
__all__ = ['DataClassPayload', 'type_from_format']
Loading

0 comments on commit b99ba77

Please sign in to comment.