From 432c238a49d1db3fe78672a686be1442a385ef97 Mon Sep 17 00:00:00 2001 From: qstokkink Date: Mon, 18 Nov 2024 16:57:32 +0100 Subject: [PATCH 1/2] Updated dataclass payloads to use base class --- doc/basics/overlay_tutorial_5.py | 7 +- doc/basics/testbase_tutorial.rst | 8 +- doc/basics/testbase_tutorial_2.py | 15 ++-- doc/reference/serialization.rst | 10 +-- doc/reference/serialization_1.py | 6 +- doc/reference/serialization_4.py | 7 +- doc/reference/serialization_7.py | 8 +- ipv8/messaging/payload_dataclass.py | 88 ++++++++++++------- ipv8/test/messaging/test_payload_dataclass.py | 57 ++++++------ 9 files changed, 118 insertions(+), 88 deletions(-) diff --git a/doc/basics/overlay_tutorial_5.py b/doc/basics/overlay_tutorial_5.py index 13b7f24f0..86957a34f 100644 --- a/doc/basics/overlay_tutorial_5.py +++ b/doc/basics/overlay_tutorial_5.py @@ -1,17 +1,18 @@ import os from asyncio import run +from dataclasses import dataclass from ipv8.community import Community, CommunitySettings from ipv8.configuration import ConfigBuilder, Strategy, WalkerDefinition, default_bootstrap_defs from ipv8.lazy_community import lazy_wrapper -from ipv8.messaging.payload_dataclass import dataclass +from ipv8.messaging.payload_dataclass import DataClassPayload from ipv8.types import Peer from ipv8.util import run_forever from ipv8_service import IPv8 -@dataclass(msg_id=1) # The value 1 identifies this message and must be unique per community -class MyMessage: +@dataclass +class MyMessage(DataClassPayload[1]): # The value 1 identifies this message and must be unique per community clock: int # We add an integer (technically a "long long") field "clock" to this message diff --git a/doc/basics/testbase_tutorial.rst b/doc/basics/testbase_tutorial.rst index 48d4763d2..f05e891c3 100644 --- a/doc/basics/testbase_tutorial.rst +++ b/doc/basics/testbase_tutorial.rst @@ -206,28 +206,28 @@ In the following example peer 0 first sends message 1 and then sends message 2 t The following construction asserts this: .. literalinclude:: testbase_tutorial_2.py - :lines: 65-68 + :lines: 66-69 :dedent: 4 Sometimes, you can't be sure in what order messages are sent. In these cases you can use ``ordered=False``: .. literalinclude:: testbase_tutorial_2.py - :lines: 71-77 + :lines: 72-78 :dedent: 4 In other cases, your overlay may be sending messages which you cannot control and/or which you don't care about. In these cases you can set a filter to only include the messages you want: .. literalinclude:: testbase_tutorial_2.py - :lines: 80-85 + :lines: 81-86 :dedent: 4 It may also be helpful to inspect the contents of each payload. You can simply use the return value of the assert function to perform further inspection: .. literalinclude:: testbase_tutorial_2.py - :lines: 88-95 + :lines: 89-96 :dedent: 4 If you want to use ``assertReceivedBy()``, make sure that: diff --git a/doc/basics/testbase_tutorial_2.py b/doc/basics/testbase_tutorial_2.py index 42f2dbba2..ecda5f01c 100644 --- a/doc/basics/testbase_tutorial_2.py +++ b/doc/basics/testbase_tutorial_2.py @@ -1,26 +1,27 @@ import os import unittest +from dataclasses import dataclass from random import random, shuffle from ipv8.community import Community, CommunitySettings from ipv8.lazy_community import lazy_wrapper, lazy_wrapper_unsigned -from ipv8.messaging.payload_dataclass import dataclass +from ipv8.messaging.payload_dataclass import DataClassPayload from ipv8.test.base import TestBase from ipv8.types import Peer -@dataclass(msg_id=1) -class Message1: +@dataclass +class Message1(DataClassPayload[1]): value: int -@dataclass(msg_id=2) -class Message2: +@dataclass +class Message2(DataClassPayload[2]): value: int -@dataclass(msg_id=3) -class Message3: +@dataclass +class Message3(DataClassPayload[3]): value: int diff --git a/doc/reference/serialization.rst b/doc/reference/serialization.rst index dbf18e4cc..8c1ab1de7 100644 --- a/doc/reference/serialization.rst +++ b/doc/reference/serialization.rst @@ -38,14 +38,14 @@ If the ``dataclass`` had used normal ``int`` types, these would have been two si Each instance will have two fields: ``field1`` and ``field2`` corresponding to the integer and short. .. literalinclude:: serialization_1.py - :lines: 9-61 + :lines: 11-63 To show some of the differences, let's check out the output of the following script using these definitions: .. literalinclude:: serialization_1.py - :lines: 64-75 + :lines: 66-77 .. code-block:: bash @@ -195,7 +195,7 @@ This method involves implementing the methods ``fix_pack_`` and Check out the following example: .. literalinclude:: serialization_4.py - :lines: 11-36 + :lines: 12-37 In both classes we create a message with a single field ``dictionary``. To pack this field, we use ``json.dumps()`` to create a string representation of the dictionary. @@ -250,9 +250,9 @@ You can specify them by using the ``"payload"`` datatype and setting the ``Paylo For a ``VariablePayload`` this looks like the following example. .. literalinclude:: serialization_7.py - :lines: 5-12 + :lines: 7-14 For dataclass payloads this nesting is supported by simply specifying nested classes as follows. .. literalinclude:: serialization_7.py - :lines: 15-24 + :lines: 17-26 diff --git a/doc/reference/serialization_1.py b/doc/reference/serialization_1.py index 47219226f..3dd063da0 100644 --- a/doc/reference/serialization_1.py +++ b/doc/reference/serialization_1.py @@ -1,8 +1,10 @@ from __future__ import annotations +from dataclasses import dataclass + from ipv8.messaging.lazy_payload import VariablePayload, vp_compile from ipv8.messaging.payload import Payload -from ipv8.messaging.payload_dataclass import dataclass, type_from_format +from ipv8.messaging.payload_dataclass import DataClassPayload, type_from_format from ipv8.messaging.serialization import Serializable @@ -56,7 +58,7 @@ class MyCVariablePayload(VariablePayload): @dataclass -class MyDataclassPayload: +class MyDataclassPayload(DataClassPayload): field1: I field2: H diff --git a/doc/reference/serialization_4.py b/doc/reference/serialization_4.py index 855645d6d..d464eabe7 100644 --- a/doc/reference/serialization_4.py +++ b/doc/reference/serialization_4.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +from dataclasses import dataclass from typing import cast from ipv8.messaging.lazy_payload import VariablePayload, vp_compile -from ipv8.messaging.payload_dataclass import dataclass +from ipv8.messaging.payload_dataclass import DataClassPayload from ipv8.messaging.serialization import default_serializer @@ -23,8 +24,8 @@ def fix_unpack_dictionary(cls: type[VPMessageKeepDict], return json.loads(serialized_dictionary.decode()) -@dataclass(msg_id=2) -class DCMessageKeepDict: +@dataclass +class DCMessageKeepDict(DataClassPayload[2]): dictionary: str def fix_pack_dictionary(self, the_dictionary: dict) -> str: diff --git a/doc/reference/serialization_7.py b/doc/reference/serialization_7.py index 75a14fbe3..1c99f9797 100644 --- a/doc/reference/serialization_7.py +++ b/doc/reference/serialization_7.py @@ -1,5 +1,7 @@ +from dataclasses import dataclass + from ipv8.messaging.lazy_payload import VariablePayload -from ipv8.messaging.payload_dataclass import dataclass +from ipv8.messaging.payload_dataclass import DataClassPayload class A(VariablePayload): @@ -12,8 +14,8 @@ class B(VariablePayload): names = ["a", "baz"] -@dataclass(msg_id=1) -class Message: +@dataclass +class Message(DataClassPayload[1]): @dataclass class Item: foo: int diff --git a/ipv8/messaging/payload_dataclass.py b/ipv8/messaging/payload_dataclass.py index 5b8290b8b..7a1f4cbb6 100644 --- a/ipv8/messaging/payload_dataclass.py +++ b/ipv8/messaging/payload_dataclass.py @@ -1,11 +1,12 @@ from __future__ import annotations -from collections.abc import Iterable -from dataclasses import dataclass as ogdataclass -from functools import partial -from typing import Callable, TypeVar, cast, get_args, get_type_hints +import dataclasses +import sys +from typing import Any, TypeVar, cast, get_args, get_type_hints -from .lazy_payload import VariablePayload, vp_compile +from typing_extensions import Self + +from .lazy_payload import VariablePayload, VariablePayloadWID, vp_compile from .serialization import FormatListType, Serializable @@ -41,37 +42,60 @@ def type_map(t: type) -> FormatListType: # noqa: PLR0911 raise NotImplementedError(t, " unknown") -def dataclass(cls: type | None = None, *, # noqa: PLR0913 - init: bool = True, - repr: bool = True, # noqa: A002 - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, - frozen: bool = False, - msg_id: int | None = None) -> partial[type[VariablePayload]] | type[VariablePayload]: - """ - Equivalent to ``@dataclass``, but also makes the wrapped class a ``VariablePayload``. +def convert_to_payload(dataclass_type: type, msg_id: int | None = None) -> None: + if msg_id is not None: + dataclass_type.msg_id = msg_id # type: ignore[attr-defined] + dt_fields = dataclasses.fields(dataclass_type) + type_hints = get_type_hints(dataclass_type) + dataclass_type.names = [field.name for field in dt_fields] # type: ignore[attr-defined] + dataclass_type.format_list = [type_map(type_hints[field.name]) for field in # type: ignore[attr-defined] + dt_fields] + setattr(sys.modules[dataclass_type.__module__], dataclass_type.__name__, vp_compile(dataclass_type)) + - See ``dataclasses.dataclass`` for argument descriptions. +class DataClassPayload(VariablePayload): + """ + A Payload that is defined as a dataclass. """ - if cls is None: - # Forward user parameters. Format: ``@dataclass(foo=bar)``. - return partial(cast(Callable[..., type[VariablePayload]], dataclass), init=init, repr=repr, eq=eq, order=order, - unsafe_hash=unsafe_hash, frozen=frozen, msg_id=msg_id) - # Finally, we have the actual class. Format: ``@dataclass`` or forwarded from partial (see above). - origin: type = ogdataclass(cls, init=init, repr=repr, eq=eq, order=order, # type: ignore[call-overload] - unsafe_hash=unsafe_hash, frozen=frozen) + def __class_getitem__(cls, item: int) -> DataClassPayloadWID: + """ + Syntactic sugar to add a msg_id attribute into the class inheritance structure. - class DataClassPayload(origin, VariablePayload): - names = list(get_type_hints(cls).keys()) - format_list = list(map(type_map, cast(Iterable[type], get_type_hints(cls).values()))) + | - if msg_id is not None: - setattr(DataClassPayload, "msg_id", msg_id) # noqa: B010 - DataClassPayload.__name__ = cls.__name__ - DataClassPayload.__qualname__ = cls.__qualname__ - return vp_compile(DataClassPayload) + .. code-block:: + + class MyPayload(DataClassPayload[12]): + pass + + assert MyPayload().msg_id == 12 + + :param item: The item to get, i.e., the message id + """ + return cast(DataClassPayloadWID, type(cls.__name__, (DataClassPayloadWID, ), {"msg_id": item})) + + def __new__(cls, *args: Any, **kwargs) -> Self: # noqa: ANN401, ARG003 + """ + Allocate memory for a new DataClassPayload class. + """ + out = super().__new__(cls) + convert_to_payload(cls) + return out + + +class DataClassPayloadWID(VariablePayloadWID): + """ + A Payload that is defined as a dataclass and has a message id [0, 255]. + """ + + def __new__(cls, *args: Any, **kwargs) -> Self: # noqa: ANN401, ARG003 + """ + Allocate memory for a new DataClassPayloadWID class. + """ + out = super().__new__(cls) + convert_to_payload(cls, msg_id=cls.msg_id) + return out -__all__ = ['dataclass', 'type_from_format'] +__all__ = ['DataClassPayload', 'type_from_format'] diff --git a/ipv8/test/messaging/test_payload_dataclass.py b/ipv8/test/messaging/test_payload_dataclass.py index 2ca9e5938..4dc896d2e 100644 --- a/ipv8/test/messaging/test_payload_dataclass.py +++ b/ipv8/test/messaging/test_payload_dataclass.py @@ -1,10 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass as ogdataclass -from dataclasses import is_dataclass +from dataclasses import dataclass, is_dataclass from typing import TypeVar -from ...messaging.payload_dataclass import dataclass, type_from_format +from ...messaging.payload_dataclass import DataClassPayload, type_from_format from ...messaging.serialization import default_serializer from ..base import TestBase @@ -14,7 +13,7 @@ @dataclass -class NativeBool: +class NativeBool(DataClassPayload): """ A single boolean payload. """ @@ -23,7 +22,7 @@ class NativeBool: @dataclass -class NativeInt: +class NativeInt(DataClassPayload): """ A single integer payload. """ @@ -32,7 +31,7 @@ class NativeInt: @dataclass -class NativeBytes: +class NativeBytes(DataClassPayload): """ A single bytes payload. """ @@ -41,7 +40,7 @@ class NativeBytes: @dataclass -class NativeStr: +class NativeStr(DataClassPayload): """ A single string payload. """ @@ -50,7 +49,7 @@ class NativeStr: @dataclass -class SerializerType: +class SerializerType(DataClassPayload): """ A ``Serializer`` format payload. """ @@ -59,7 +58,7 @@ class SerializerType: @dataclass -class NestedType: +class NestedType(DataClassPayload): """ A single nested payload. """ @@ -68,31 +67,34 @@ class NestedType: @dataclass -class NestedListType: +class NestedListType(DataClassPayload): """ A single list of nested payload. """ a: list[NativeInt] + @dataclass -class ListIntType: +class ListIntType(DataClassPayload): """ A single list of integers. """ a: list[int] + @dataclass -class ListBoolType: +class ListBoolType(DataClassPayload): """ A single list of booleans. """ a: list[bool] -@ogdataclass -class Unknown: + +@dataclass +class Unknown(DataClassPayload): """ To whomever is reading this and wondering why dict is not supported: use a nested payload instead. """ @@ -101,7 +103,7 @@ class Unknown: @dataclass -class A: +class A(DataClassPayload): """ A payload consisting of two integers. """ @@ -111,7 +113,7 @@ class A: @dataclass -class B: +class B(DataClassPayload): """ A payload consisting of two integers, of which one has a default value. """ @@ -121,7 +123,7 @@ class B: @dataclass(eq=False) -class FwdDataclass: +class FwdDataclass(DataClassPayload): """ A payload to test if the dataclass overwrite forwards its arguments to the "real" dataclass. """ @@ -130,7 +132,7 @@ class FwdDataclass: @dataclass -class StripMsgId: +class StripMsgId(DataClassPayload): """ Payload to make sure that the message id is not seen as a field. """ @@ -142,20 +144,17 @@ class StripMsgId: format_list = [] # Expose secret VariablePayload list -@dataclass(msg_id=1) -class FwdMsgId: +@dataclass +class FwdMsgId(DataClassPayload[1]): """ Payload that specfies the message id as an argument to the dataclass overwrite. """ a: int - names = [] # Expose secret VariablePayload list - format_list = [] # Expose secret VariablePayload list - @dataclass -class EverythingItem: +class EverythingItem(DataClassPayload): """ An item for the following Everything payload. """ @@ -164,8 +163,7 @@ class EverythingItem: @dataclass -@ogdataclass -class Everything: +class Everything(DataClassPayload): """ Dataclass payload that includes all functionality. """ @@ -411,7 +409,8 @@ def test_unknown_payload(self) -> None: """ Check if an unknown type raises an error. """ - self.assertRaises(NotImplementedError, dataclass, Unknown) + self.assertRaises(NotImplementedError, Unknown, {"a": "b"}) + def test_fwd_args(self) -> None: """ @@ -424,7 +423,7 @@ def test_fwd_args(self) -> None: def test_strip_msg_id(self) -> None: """ - Check if the ``msg_id`` field is identifier and stripped. + Check if the ``msg_id`` field is identified and stripped. """ payload = StripMsgId(42) @@ -435,7 +434,7 @@ def test_strip_msg_id(self) -> None: def test_fwd_msg_id(self) -> None: """ - Check if the ``msg_id`` argument is sets the Payload ``msg_id``. + Check if the ``msg_id`` argument sets the Payload ``msg_id``. """ payload = FwdMsgId(42) From d78dbacb7a6418af9424b8c98e72f2ce4b0fa8f9 Mon Sep 17 00:00:00 2001 From: qstokkink Date: Mon, 18 Nov 2024 18:46:33 +0100 Subject: [PATCH 2/2] Fix 3.13 locals() compatibility --- ipv8/messaging/lazy_payload.py | 13 +++++++------ ipv8/test/messaging/test_lazy_payload.py | 20 ++++++++++++-------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/ipv8/messaging/lazy_payload.py b/ipv8/messaging/lazy_payload.py index bd417ce95..515ed7678 100644 --- a/ipv8/messaging/lazy_payload.py +++ b/ipv8/messaging/lazy_payload.py @@ -230,22 +230,23 @@ def vp_compile(vp_definition: type[T]) -> type[T]: """ # We use ``exec`` purposefully here, disable the pylint warning: # ruff: noqa: B010, S102 + local_scope = locals() # Load the function definitions into the local scope. exec(_compile_init(vp_definition.names, { k: v.default for k, v in inspect.signature(vp_definition.__init__).parameters.items() if v.default is not inspect.Parameter.empty - }), globals(), locals()) - exec(_compile_from_unpack_list(vp_definition, vp_definition.names), globals(), locals()) - exec(_compile_to_pack_list(vp_definition, vp_definition.format_list, vp_definition.names), globals(), locals()) + }), globals(), local_scope) + exec(_compile_from_unpack_list(vp_definition, vp_definition.names), globals(), local_scope) + exec(_compile_to_pack_list(vp_definition, vp_definition.format_list, vp_definition.names), globals(), local_scope) # Rewrite the class methods from the locally loaded overwrites. # from_unpack_list is a classmethod, so we need to scope it properly. - setattr(vp_definition, '__init__', locals()['__init__']) + setattr(vp_definition, '__init__', local_scope['__init__']) setattr(vp_definition, '__match_args__', tuple(vp_definition.names)) - setattr(vp_definition, 'from_unpack_list', types.MethodType(locals()['from_unpack_list'], vp_definition)) - setattr(vp_definition, 'to_pack_list', locals()['to_pack_list']) + setattr(vp_definition, 'from_unpack_list', types.MethodType(local_scope['from_unpack_list'], vp_definition)) + setattr(vp_definition, 'to_pack_list', local_scope['to_pack_list']) return vp_definition diff --git a/ipv8/test/messaging/test_lazy_payload.py b/ipv8/test/messaging/test_lazy_payload.py index c677b6738..8156dd276 100644 --- a/ipv8/test/messaging/test_lazy_payload.py +++ b/ipv8/test/messaging/test_lazy_payload.py @@ -511,6 +511,7 @@ def test_plain_mismatch_list(self) -> None: If this test fails, you probably screwed up the class-level sub-pattern. """ payload = BitsPayload(False, True, False, True, False, True, False, True) + local_scope = locals() # The following will crash all interpreters < 3.10 if not contained in a string. exec( # noqa: S102 @@ -520,9 +521,9 @@ def test_plain_mismatch_list(self) -> None: matched = True case _: matched = False -""", '', 'exec'), globals(), locals()) +""", '', 'exec'), globals(), local_scope) - self.assertFalse(locals()["matched"]) + self.assertFalse(local_scope["matched"]) @skipUnlessPython310 def test_compiled_mismatch_list(self) -> None: @@ -533,6 +534,7 @@ def test_compiled_mismatch_list(self) -> None: If this test fails, you probably screwed up the class-level sub-pattern. """ payload = CompiledBitsPayload(False, True, False, True, False, True, False, True) + local_scope = locals() # The following will crash all interpreters < 3.10 if not contained in a string. exec( # noqa: S102 @@ -542,9 +544,9 @@ def test_compiled_mismatch_list(self) -> None: matched = True case _: matched = False -""", '', 'exec'), globals(), locals()) +""", '', 'exec'), globals(), local_scope) - self.assertFalse(locals()["matched"]) + self.assertFalse(local_scope["matched"]) @skipUnlessPython310 def test_plain_match_pattern(self) -> None: @@ -552,6 +554,7 @@ def test_plain_match_pattern(self) -> None: Check if a VariablePayload instance matches its own pattern. """ payload = BitsPayload(False, True, False, True, False, True, False, True) + local_scope = locals() # The following will crash all interpreters < 3.10 if not contained in a string. exec( # noqa: S102 @@ -561,9 +564,9 @@ def test_plain_match_pattern(self) -> None: matched = True case _: matched = False -""", '', 'exec'), globals(), locals()) +""", '', 'exec'), globals(), local_scope) - self.assertTrue(locals()["matched"]) + self.assertTrue(local_scope["matched"]) @skipUnlessPython310 def test_compiled_match_pattern(self) -> None: @@ -571,6 +574,7 @@ def test_compiled_match_pattern(self) -> None: Check if a compiled VariablePayload instance matches its own pattern. """ payload = CompiledBitsPayload(False, True, False, True, False, True, False, True) + local_scope = locals() # The following will crash all interpreters < 3.10 if not contained in a string. exec( # noqa: S102 @@ -580,6 +584,6 @@ def test_compiled_match_pattern(self) -> None: matched = True case _: matched = False -""", '', 'exec'), globals(), locals()) +""", '', 'exec'), globals(), local_scope) - self.assertTrue(locals()["matched"]) + self.assertTrue(local_scope["matched"])