From c4af116ffcbdd72a78540493b4346783cbc89c2c Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sat, 10 Sep 2022 15:37:48 -0600 Subject: [PATCH 1/8] part 1: Correctly handle Generic types when the TypeVar is the annotation on the class attribute. --- structured/structured.py | 40 ++++++++++++++++++++++++++++++++++++++++ tests/test_generics.py | 23 +++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 tests/test_generics.py diff --git a/structured/structured.py b/structured/structured.py index e8c7497..e0f1de5 100644 --- a/structured/structured.py +++ b/structured/structured.py @@ -1,4 +1,7 @@ from __future__ import annotations +import operator +from typing import get_args, get_origin +import typing __all__ = [ 'Structured', @@ -337,6 +340,18 @@ def __init_subclass__( f'{base.__name__} ({base.byte_order.name}). ' 'If this is intentional, use `byte_order_mode=OVERRIDE`.' ) + if base: + # determine correct typehints for any Generic's in the base class + orig_bases = getattr(cls, '__orig_bases__', ()) + base_to_origbase = { + origin: orig_base + for orig_base in orig_bases + if (origin := get_origin(orig_base)) and issubclass(origin, Structured) + } + orig_base = base_to_origbase.get(base, None) + if orig_base: + updates = base._specialize(*get_args(orig_base)) + cls.__annotations__.update(updates) # Analyze the class typehints = get_type_hints(cls) serializer, attrs = create_serializer( @@ -347,3 +362,28 @@ def __init_subclass__( cls.attrs = attrs cls.byte_order = byte_order + @classmethod + def _specialize(cls, *args): + supers: dict[type[Structured], Any] = {} + tvars = () + for base in getattr(cls, '__orig_bases__', ()): + if (origin := get_origin(base)) is typing.Generic: + tvars = get_args(base) + elif origin and issubclass(origin, Structured): + supers[origin] = base + tvar_map = dict(zip(tvars, args)) + if not tvar_map: + return {} + annotations = {} + for attr, attr_type in 
get_type_hints(cls).items(): + if attr in cls.__annotations__: + # Attribute's final type hint comes from this class + if remapped_type := tvar_map.get(attr_type, None): + annotations[attr] = remapped_type + all_annotations = [annotations] + for base, alias in supers.items(): + args = get_args(alias) + args = (tvar_map.get(arg, arg) for arg in args) + super_annotations = base._specialize(*args) + all_annotations.append(super_annotations) + return reduce(operator.or_, reversed(all_annotations)) diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..9c9ccd5 --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,23 @@ +import struct +from structured import * +from typing import Generic, TypeVar, Union + +_Byte = TypeVar('_Byte', bound=Union[uint8, int8]) +_String = TypeVar('_String', bound=Union[char, pascal, unicode]) + + +class Base(Generic[_Byte, _String], Structured): + a: _Byte + b: _String + + +class UnsignedUnicode(Base[uint8, unicode[uint8]]): + pass + + +def test_generics() -> None: + obj = UnsignedUnicode(10, 'Hello') + target_data = struct.pack('BB5s', 10, 5, b'Hello') + + # pack/unpack + assert obj.pack() == target_data From e072541a12cbfdf4d1aa9f6f1ad6c34b529bd9d5 Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sat, 10 Sep 2022 16:55:37 -0600 Subject: [PATCH 2/8] first go at handling TypVars in arguments to complex types --- structured/complex_types/array_headers.py | 6 +- structured/complex_types/arrays.py | 3 + structured/complex_types/strings.py | 23 +++++- structured/structured.py | 4 + structured/utils.py | 25 ++++++ tests/test_generics.py | 95 +++++++++++++++++++++-- 6 files changed, 147 insertions(+), 9 deletions(-) diff --git a/structured/complex_types/array_headers.py b/structured/complex_types/array_headers.py index 7130280..d049446 100644 --- a/structured/complex_types/array_headers.py +++ b/structured/complex_types/array_headers.py @@ -6,8 +6,9 @@ from __future__ 
import annotations from functools import cache +from typing import TypeVar -from ..utils import specialized +from ..utils import StructuredAlias, specialized from ..structured import Structured from ..basic_types import uint8, uint16, uint32, uint64 from ..type_checking import ClassVar, Union @@ -278,6 +279,9 @@ def __class_getitem__(cls, key) -> type[Header]: raise TypeError(f'{cls.__name__}[] expected two arguments') else: count, size_check = key + # TypeVar checks: + if isinstance(count, TypeVar) or isinstance(size_check, TypeVar): + return StructuredAlias(cls, (count, size_check)) try: if size_check is None: args = (count,) diff --git a/structured/complex_types/arrays.py b/structured/complex_types/arrays.py index f47191a..9f29798 100644 --- a/structured/complex_types/arrays.py +++ b/structured/complex_types/arrays.py @@ -74,6 +74,9 @@ def __class_getitem__( else: header = Header[args[:-1]] array_type = args[-1] + # TypeVar checks: + if isinstance(header, StructuredAlias) or isinstance(array_type, TypeVar): + return StructuredAlias(cls, (header, array_type)) if (not isinstance(header, type) or not issubclass(header, HeaderBase) or header is HeaderBase or diff --git a/structured/complex_types/strings.py b/structured/complex_types/strings.py index 8cf9061..69f25c9 100644 --- a/structured/complex_types/strings.py +++ b/structured/complex_types/strings.py @@ -10,8 +10,9 @@ from functools import cache, partial import struct +from typing import TypeVar -from ..utils import specialized +from ..utils import StructuredAlias, specialized from ..base_types import ( Serializer, StructSerializer, requires_indexing, ByteOrder, struct_cache, structured_type, counted, @@ -59,6 +60,7 @@ def __class_getitem__(cls, args) -> type[structured_type]: return cls._create(*args) @classmethod + @cache def _create( cls, count: Union[int, type[SizeTypes], type[NET]], @@ -69,6 +71,8 @@ def _create( new_cls = _dynamic_char[count] elif count is NET: new_cls = _net_char + elif isinstance(count, 
TypeVar): + return StructuredAlias(cls, (count,)) + else: raise TypeError( f'{cls.__qualname__}[] count must be an int, NET, or uint* ' @@ -117,11 +121,15 @@ class unicode(str, requires_indexing): :type encoding: Union[str, type[EncoderDecoder]] """ @classmethod - @cache def __class_getitem__(cls, args) -> type[Serializer]: """Create the specialization.""" if not isinstance(args, tuple): args = (args, ) + # Cache doesn't play nice with default args, + # _create(uint8) + # _create(uint8, 'utf8') + # technically are different call types, so the cache isn't hit. + # Pass through an intermediary to take care of this. return cls.create(*args) @classmethod @@ -130,12 +138,23 @@ def create( count: Union[int, type[SizeTypes], type[NET]], encoding: Union[str, type[EncoderDecoder]] = 'utf8', ) -> type[Serializer]: + return cls._create(count, encoding) + + @classmethod + @cache + def _create( + cls, + count: Union[int, type[SizeTypes], type[NET]], + encoding: Union[str, type[EncoderDecoder]], + ) -> type[Serializer]: """Create the specialization. :param count: Size of the *encoded* string. :param encoding: Encoding method to use. :return: The specialized class. 
""" + if isinstance(count, TypeVar): + return StructuredAlias(cls, (count, encoding)) if isinstance(encoding, str): encoder = partial(str.encode, encoding=encoding) decoder = partial(bytes.decode, encoding=encoding) diff --git a/structured/structured.py b/structured/structured.py index e0f1de5..8dc6a0e 100644 --- a/structured/structured.py +++ b/structured/structured.py @@ -3,6 +3,8 @@ from typing import get_args, get_origin import typing +from structured.utils import StructuredAlias + __all__ = [ 'Structured', 'ByteOrder', 'ByteOrderMode', @@ -380,6 +382,8 @@ def _specialize(cls, *args): # Attribute's final type hint comes from this class if remapped_type := tvar_map.get(attr_type, None): annotations[attr] = remapped_type + elif isinstance(attr_type, StructuredAlias): + annotations[attr] = attr_type.resolve(tvar_map) all_annotations = [annotations] for base, alias in supers.items(): args = get_args(alias) diff --git a/structured/utils.py b/structured/utils.py index 7c99137..4912a53 100644 --- a/structured/utils.py +++ b/structured/utils.py @@ -1,6 +1,7 @@ """ Various utility methods. """ +from typing import TypeVar from .type_checking import _T, NoReturn, Any, Callable @@ -31,3 +32,27 @@ def wrapper(cls: type[_T]) -> type[_T]: cls.__name__ = f'{base_cls.__name__}[{name}]' return cls return wrapper + + +class StructuredAlias: + cls: type + args: tuple + + def __init__(self, cls, args): + self.cls = cls + self.args = args + + def resolve(self, tvar_map: dict[TypeVar, type]): + resolved = [] + for arg in self.args: + arg = tvar_map.get(arg, arg) + if isinstance(arg, StructuredAlias): + arg = arg.resolve(tvar_map) + resolved.append(arg) + resolved = tuple(resolved) + if any((isinstance(arg, (TypeVar, StructuredAlias)) for arg in resolved)): + # Act as immutable, so create a new instance, since these objects + # are often cached in type factory indexing methods. 
+ return StructuredAlias(self.cls, resolved) + else: + return self.cls[resolved] diff --git a/tests/test_generics.py b/tests/test_generics.py index 9c9ccd5..2142f7c 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,9 +1,15 @@ import struct from structured import * -from typing import Generic, TypeVar, Union +from typing import Generic, TypeVar, Union, get_type_hints + +from structured.utils import StructuredAlias _Byte = TypeVar('_Byte', bound=Union[uint8, int8]) _String = TypeVar('_String', bound=Union[char, pascal, unicode]) +_Size = TypeVar('_Size', bound=Union[uint8, uint16, uint32, uint64]) +T = TypeVar('T', bound=Structured) +U = TypeVar('U') +V = TypeVar('V') class Base(Generic[_Byte, _String], Structured): @@ -15,9 +21,86 @@ class UnsignedUnicode(Base[uint8, unicode[uint8]]): pass -def test_generics() -> None: - obj = UnsignedUnicode(10, 'Hello') - target_data = struct.pack('BB5s', 10, 5, b'Hello') +class TestAliasing: + class Item(Structured): + a: int8 + + tvar_map = { + _Size: uint32, + _Byte: uint8, + T: Item, + } + + def test_unicode(self) -> None: + obj = unicode[_Size] + assert isinstance(obj, StructuredAlias) + assert obj.cls is unicode + assert obj.args == (_Size, 'utf8') + assert obj.resolve(self.tvar_map) is unicode[uint32] + + def test_Header(self) -> None: + obj = Header[1, _Size] + assert isinstance(obj, StructuredAlias) + assert obj.cls is Header + assert obj.args == (1, _Size) + assert obj.resolve(self.tvar_map) is Header[1, uint32] + + def test_char(self) -> None: + obj = char[_Size] + assert isinstance(obj, StructuredAlias) + assert obj.cls is char + assert obj.args == (_Size,) + assert obj.resolve(self.tvar_map) is char[uint32] + + def test_array(self) -> None: + # same typevar + obj = array[Header[_Size], _Size] + assert isinstance(obj, StructuredAlias) + assert obj.cls is array + assert obj.args == (Header[_Size], _Size) + assert obj.resolve(self.tvar_map) is array[Header[uint32], uint32] + + # different typevars + 
obj = array[Header[_Size, _Byte], T] + assert isinstance(obj, StructuredAlias) + assert obj.cls is array + assert obj.args == (Header[_Size, _Byte], T) + + obj1 = obj.resolve({_Size: uint32}) + assert isinstance(obj1, StructuredAlias) + assert isinstance(obj1.args[0], StructuredAlias) + assert obj1.args[0].args == (uint32, _Byte) + assert obj1.args[1] is T + + obj2 = obj1.resolve({_Byte: uint8}) + assert isinstance(obj2, StructuredAlias) + assert obj2.cls is array + assert obj2.args == (Header[uint32, uint8], T) + + obj3 = obj2.resolve({T: self.Item}) + assert obj3 is array[Header[uint32, uint8], self.Item] + assert obj.resolve(self.tvar_map) is array[Header[uint32, uint8], self.Item] + + +def test_automatic_resolution(): + class Item(Structured): + a: int8 + + class Base(Generic[_Size, T, U, V], Structured): + a: _Size + b: unicode[U] + c: array[Header[1, V], T] + + class PartiallySpecialized(Generic[U, T], Base[uint8, T, uint32, U]): pass + class FullySpecialized1(Base[uint8, Item, uint32, uint16]): pass + class FullySpecialized2(PartiallySpecialized[uint16, Item]): pass + + assert PartiallySpecialized.attrs == ('a', 'b') + hints = get_type_hints(PartiallySpecialized) + assert hints['a'] is uint8 + assert hints['b'] is unicode[uint32] + assert isinstance(hints['c'], StructuredAlias) - # pack/unpack - assert obj.pack() == target_data + assert FullySpecialized1.attrs == FullySpecialized2.attrs + assert FullySpecialized1.attrs == ('a', 'b', 'c') + assert get_type_hints(FullySpecialized1) == get_type_hints(FullySpecialized2) From 7a8d88a3fb10b69ffabecd5ff33939959c55ae14 Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sat, 10 Sep 2022 18:04:19 -0600 Subject: [PATCH 3/8] rework array headers to use the new generic capabilities --- structured/complex_types/array_headers.py | 230 ++++++---------------- 1 file changed, 63 insertions(+), 167 deletions(-) diff --git a/structured/complex_types/array_headers.py 
b/structured/complex_types/array_headers.py index d049446..6dbd926 100644 --- a/structured/complex_types/array_headers.py +++ b/structured/complex_types/array_headers.py @@ -6,15 +6,17 @@ from __future__ import annotations from functools import cache -from typing import TypeVar from ..utils import StructuredAlias, specialized from ..structured import Structured from ..basic_types import uint8, uint16, uint32, uint64 -from ..type_checking import ClassVar, Union +from ..type_checking import ClassVar, Union, Generic, Optional, TypeVar +_SizeTypes = (uint8, uint16, uint32, uint64) SizeTypes = Union[uint8, uint16, uint32, uint64] +TSize = TypeVar('TSize', bound=SizeTypes) +TCount = TypeVar('TCount', bound=SizeTypes) class HeaderBase: @@ -54,68 +56,45 @@ def count(self, new_count: int) -> None: f'expected an array of length {self.count}, but got {new_count}' ) - def __class_getitem__( - cls: type[StaticHeader], - count: int, - ) -> type[StaticHeader]: + @classmethod + def specialize(cls, count: int) -> type[StaticHeader]: """Specialize for a specific static size.""" if count <= 0: raise ValueError('count must be positive') - class _StaticHeader(cls): + class _StaticHeader(StaticHeader): _count: ClassVar[int] = count return _StaticHeader -class DynamicHeader(Structured, HeaderBase): +class DynamicHeader(Generic[TCount], Structured, HeaderBase): """Base for dynamically sized arrays, where the array length is just prior to the array data. 
""" - count: int + count: TCount data_size: ClassVar[int] = 0 two_pass: ClassVar[bool] = False - _headers = {} - def __init__(self, count: int, data_size: int) -> None: """Only `count` is packed/unpacked.""" - self.count = count - - def __class_getitem__( - cls, - count_type: type[SizeTypes], - ) -> type[DynamicHeader]: - """Specialize based on the uint* type used to store the array length.""" - return cls._headers[count_type] - -class _DynamicHeader8(DynamicHeader): - count: uint8 -class _DynamicHeader16(DynamicHeader): - count: uint16 -class _DynamicHeader32(DynamicHeader): - count: uint32 -class _DynamicHeader64(DynamicHeader): - count: uint64 -DynamicHeader._headers = { - uint8: _DynamicHeader8, - uint16: _DynamicHeader16, - uint32: _DynamicHeader32, - uint64: _DynamicHeader64, -} - - -class StaticCheckedHeader(Structured, HeaderBase): + self.count = count # type: ignore + + @classmethod + def specialize(cls, count_type: type[SizeTypes]) -> type[DynamicHeader]: + class _DynamicHeader(DynamicHeader[count_type]): pass + return _DynamicHeader + + +class StaticCheckedHeader(Generic[TSize], Structured, HeaderBase): """Statically sized array, with a size check int packed just prior to the array data. 
""" _count: ClassVar[int] = 0 - data_size: int + data_size: TSize two_pass: ClassVar[bool] = True - _headers = {} - def __init__(self, count: int, data_size: int) -> None: """Only `data_size` is packed/unpacked.""" - self.data_size = data_size + self.data_size = data_size # type: ignore def validate_data_size(self, data_size: int) -> None: """Verify correct amount of bytes were read.""" @@ -139,39 +118,20 @@ def count(self, new_count: int) -> None: f'{new_count}' ) - def __class_getitem__( - cls: type[StaticCheckedHeader], - key: tuple[int, type[SizeTypes]], - ) -> type[StaticCheckedHeader]: + @classmethod + def specialize(cls, count: int, size_type: type[SizeTypes]) -> type[StaticCheckedHeader]: """Specialize for the specific static size and check type.""" - count, size_type = key if count <= 0: raise ValueError('count must be positive') - base = cls._headers[size_type] - class _StaticCheckedHeader(base): + class _StaticCheckedHeader(StaticCheckedHeader[size_type]): _count: ClassVar[int] = count return _StaticCheckedHeader -class _StaticCheckedHeader8(StaticCheckedHeader): - data_size: uint8 -class _StaticCheckedHeader16(StaticCheckedHeader): - data_size: uint16 -class _StaticChechedHeader32(StaticCheckedHeader): - data_size: uint32 -class _StaticCheckedHeader64(StaticCheckedHeader): - data_size: uint64 -StaticCheckedHeader._headers = { - uint8: _StaticCheckedHeader8, - uint16: _StaticCheckedHeader16, - uint32: _StaticChechedHeader32, - uint64: _StaticCheckedHeader64, -} - - -class DynamicCheckedHeader(Structured, HeaderBase): + +class DynamicCheckedHeader(Generic[TCount, TSize], Structured, HeaderBase): """Dynamically sized array with a size check.""" - count: int - data_size: int + count: TCount + data_size: TSize two_pass: ClassVar[bool] = True _headers = {} @@ -184,79 +144,10 @@ def validate_data_size(self, data_size: int) -> None: f'got {data_size}' ) - def __class_getitem__( - cls: type[DynamicCheckedHeader], - key: tuple[type[SizeTypes], type[SizeTypes]], - 
) -> type[DynamicCheckedHeader]: - """Specialize based on size type and check type.""" - return cls._headers[key] - -class _DynamicCheckedHeader8_8(DynamicCheckedHeader): - count: uint8 - data_size: uint8 -class _DynamicCheckedHeader8_16(DynamicCheckedHeader): - count: uint8 - data_size: uint16 -class _DynamicCheckedHeader8_32(DynamicCheckedHeader): - count: uint8 - data_size: uint32 -class _DynamicCheckedHeader8_64(DynamicCheckedHeader): - count: uint8 - data_size: uint64 -class _DynamicCheckedHeader16_8(DynamicCheckedHeader): - count: uint16 - data_size: uint8 -class _DynamicCheckedHeader16_16(DynamicCheckedHeader): - count: uint16 - data_size: uint16 -class _DynamicCheckedHeader16_32(DynamicCheckedHeader): - count: uint16 - data_size: uint32 -class _DynamicCheckedHeader16_64(DynamicCheckedHeader): - count: uint16 - data_size: uint64 -class _DynamicCheckedHeader32_8(DynamicCheckedHeader): - count: uint32 - data_size: uint8 -class _DynamicCheckedHeader32_16(DynamicCheckedHeader): - count: uint32 - data_size: uint16 -class _DynamicCheckedHeader32_32(DynamicCheckedHeader): - count: uint32 - data_size: uint32 -class _DynamicCheckedHeader32_64(DynamicCheckedHeader): - count: uint32 - data_size: uint64 -class _DynamicCheckedHeader64_8(DynamicCheckedHeader): - count: uint64 - data_size: uint8 -class _DynamicCheckedHeader64_16(DynamicCheckedHeader): - count: uint64 - data_size: uint16 -class _DynamicCheckedHeader64_32(DynamicCheckedHeader): - count: uint64 - data_size: uint32 -class _DynamicCheckedHeader64_64(DynamicCheckedHeader): - count: uint64 - data_size: uint64 -DynamicCheckedHeader._headers = { - (uint8, uint8): _DynamicCheckedHeader8_8, - (uint8, uint16): _DynamicCheckedHeader8_16, - (uint8, uint32): _DynamicCheckedHeader8_32, - (uint8, uint64): _DynamicCheckedHeader8_64, - (uint16, uint8): _DynamicCheckedHeader16_8, - (uint16, uint16): _DynamicCheckedHeader16_16, - (uint16, uint32): _DynamicCheckedHeader16_32, - (uint16, uint64): _DynamicCheckedHeader16_64, - 
(uint32, uint8): _DynamicCheckedHeader32_8, - (uint32, uint16): _DynamicCheckedHeader32_16, - (uint32, uint32): _DynamicCheckedHeader32_32, - (uint32, uint64): _DynamicCheckedHeader32_64, - (uint64, uint8): _DynamicCheckedHeader64_8, - (uint64, uint16): _DynamicCheckedHeader64_16, - (uint64, uint32): _DynamicCheckedHeader64_32, - (uint64, uint64): _DynamicCheckedHeader64_64, -} + @classmethod + def specialize(cls, count_type: type[SizeTypes], size_type: type[SizeTypes]) -> type[DynamicCheckedHeader]: + class _DynamicCheckedHeader(DynamicCheckedHeader[count_type, size_type]): pass + return _DynamicCheckedHeader class Header(Structured, HeaderBase): @@ -267,37 +158,42 @@ class Header(Structured, HeaderBase): data_size: int two_pass: ClassVar[bool] - @classmethod - @cache def __class_getitem__(cls, key) -> type[Header]: """Main entry point for making Headers. Do type checks and create the appropriate Header type. """ if not isinstance(key, tuple): - count, size_check = key, None - elif len(key) != 2: - raise TypeError(f'{cls.__name__}[] expected two arguments') - else: - count, size_check = key - # TypeVar checks: + key = (key, ) + return cls.create(*key) + + @classmethod + def create(cls, count, size_check=None): + return cls._create(count, size_check) + + @classmethod + @cache + def _create(cls, count: Union[int, type[SizeTypes]], size_check: Optional[type[SizeTypes]]) -> type[Header]: + # TypeVar quick out. 
if isinstance(count, TypeVar) or isinstance(size_check, TypeVar): - return StructuredAlias(cls, (count, size_check)) - try: - if size_check is None: - args = (count,) - if isinstance(count, int): - header = StaticHeader[count] - else: - header = DynamicHeader[count] + return StructuredAlias(cls, (count, size_check)) # type: ignore + # Final type checking + if size_check is not None: + if not isinstance(size_check, type) or not issubclass(size_check, _SizeTypes): + raise TypeError('size check must be a uint* type.') + elif not isinstance(count, int): + if not isinstance(count, type) or not issubclass(count, _SizeTypes): + raise TypeError('array length must be an integer or uint* type.') + # Dispatch + if size_check is None: + args = (count, ) + if isinstance(count, int): + header = StaticHeader.specialize(count) + else: + header = DynamicHeader.specialize(count) + else: + args = (count, size_check) + if isinstance(count, int): + header = StaticCheckedHeader.specialize(count, size_check) else: - args = (count, size_check) - if isinstance(count, int): - header = StaticCheckedHeader[count, size_check] - else: - header = DynamicCheckedHeader[count, size_check] - return specialized(cls, *args)(header) # type: ignore - except KeyError: - raise TypeError( - f'{cls.__name__}[] expected first argument integer or uint* ' - 'type, second argument uint* type or None' - ) from None + header = DynamicCheckedHeader.specialize(count, size_check) + return specialized(cls, *args)(header) # type: ignore From d1a524e98e4fcaf6afb767e9ff8459fdcc2113e3 Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sat, 10 Sep 2022 18:44:38 -0600 Subject: [PATCH 4/8] style cleanup - docstrings - line lengths --- structured/complex_types/array_headers.py | 52 ++++++++++++++++++++--- structured/complex_types/arrays.py | 6 ++- structured/complex_types/strings.py | 20 ++++++--- structured/structured.py | 12 ++++-- structured/utils.py | 10 ++++- 5 files changed, 81 
insertions(+), 19 deletions(-) diff --git a/structured/complex_types/array_headers.py b/structured/complex_types/array_headers.py index 6dbd926..f48f441 100644 --- a/structured/complex_types/array_headers.py +++ b/structured/complex_types/array_headers.py @@ -119,8 +119,17 @@ def count(self, new_count: int) -> None: ) @classmethod - def specialize(cls, count: int, size_type: type[SizeTypes]) -> type[StaticCheckedHeader]: - """Specialize for the specific static size and check type.""" + def specialize( + cls, + count: int, + size_type: type[SizeTypes], + ) -> type[StaticCheckedHeader]: + """Specialize for the specific static size and check type. + + :param count: Static length for the array. + :param size_type: Type of integer to unpack for the array data size. + :return: The specialized Header class. + """ if count <= 0: raise ValueError('count must be positive') class _StaticCheckedHeader(StaticCheckedHeader[size_type]): @@ -145,7 +154,17 @@ def validate_data_size(self, data_size: int) -> None: ) @classmethod - def specialize(cls, count_type: type[SizeTypes], size_type: type[SizeTypes]) -> type[DynamicCheckedHeader]: + def specialize( + cls, + count_type: type[SizeTypes], + size_type: type[SizeTypes] + ) -> type[DynamicCheckedHeader]: + """Specialize for the specific count type and check type. + + :param count_type: Type of integer to unpack for the array length. + :param size_type: Type of integer to unpack for the array data size. + :return: The specialized Header class. + """ class _DynamicCheckedHeader(DynamicCheckedHeader[count_type, size_type]): pass return _DynamicCheckedHeader @@ -168,21 +187,40 @@ def __class_getitem__(cls, key) -> type[Header]: @classmethod def create(cls, count, size_check=None): + """Intermediate method to pass through default args to the real cached + creation method. 
+ """ return cls._create(count, size_check) @classmethod @cache - def _create(cls, count: Union[int, type[SizeTypes]], size_check: Optional[type[SizeTypes]]) -> type[Header]: + def _create( + cls, + count: Union[int, type[SizeTypes]], + size_check: Optional[type[SizeTypes]] + ) -> type[Header]: + """Check header arguments and dispatch to the correct Header + specialization. + + :param count: Static length or integer type to unpack for array length. + :param size_check: Integer type to unpack for array data size, or None + for no integer to unpack. + :return: The applicable Header specialization + """ # TypeVar quick out. if isinstance(count, TypeVar) or isinstance(size_check, TypeVar): return StructuredAlias(cls, (count, size_check)) # type: ignore # Final type checking if size_check is not None: - if not isinstance(size_check, type) or not issubclass(size_check, _SizeTypes): + if not (isinstance(size_check, type) and + issubclass(size_check, _SizeTypes)): raise TypeError('size check must be a uint* type.') elif not isinstance(count, int): - if not isinstance(count, type) or not issubclass(count, _SizeTypes): - raise TypeError('array length must be an integer or uint* type.') + if not (isinstance(count, type) and + issubclass(count, _SizeTypes)): + raise TypeError( + 'array length must be an integer or uint* type.' 
+ ) # Dispatch if size_check is None: args = (count, ) diff --git a/structured/complex_types/arrays.py b/structured/complex_types/arrays.py index 9f29798..eded632 100644 --- a/structured/complex_types/arrays.py +++ b/structured/complex_types/arrays.py @@ -75,8 +75,10 @@ def __class_getitem__( header = Header[args[:-1]] array_type = args[-1] # TypeVar checks: - if isinstance(header, StructuredAlias) or isinstance(array_type, TypeVar): - return StructuredAlias(cls, (header, array_type)) + if (isinstance(header, StructuredAlias) or + isinstance(array_type, TypeVar)): + return StructuredAlias(cls, (header, array_type)) # type: ignore + # Type checking if (not isinstance(header, type) or not issubclass(header, HeaderBase) or header is HeaderBase or diff --git a/structured/complex_types/strings.py b/structured/complex_types/strings.py index 69f25c9..5bdb303 100644 --- a/structured/complex_types/strings.py +++ b/structured/complex_types/strings.py @@ -72,7 +72,7 @@ def _create( elif count is NET: new_cls = _net_char elif isinstance(count, TypeVar): - return StructuredAlias(cls, (count,)) + return StructuredAlias(cls, (count,)) # type: ignore else: raise TypeError( f'{cls.__qualname__}[] count must be an int, NET, or uint* ' @@ -154,11 +154,12 @@ def _create( :return: The specialized class. 
""" if isinstance(count, TypeVar): - return StructuredAlias(cls, (count, encoding)) + return StructuredAlias(cls, (count, encoding)) # type: ignore if isinstance(encoding, str): encoder = partial(str.encode, encoding=encoding) decoder = partial(bytes.decode, encoding=encoding) - elif isinstance(encoding, type) and issubclass(encoding, EncoderDecoder): + elif (isinstance(encoding, type) and + issubclass(encoding, EncoderDecoder)): encoder = encoding.encode decoder = encoding.decode else: @@ -421,7 +422,12 @@ class _unicode(base): def pack(self, *values: Any) -> bytes: return super().pack(self.encoder(values[0])) - def pack_into(self, buffer: WritableBuffer, offset: int, *values: str) -> None: + def pack_into( + self, + buffer: WritableBuffer, + offset: int, + *values: str, + ) -> None: super().pack_into(buffer, offset, self.encoder(values[0])) def pack_write(self, writable: SupportsWrite, *values: str) -> None: @@ -430,7 +436,11 @@ def pack_write(self, writable: SupportsWrite, *values: str) -> None: def unpack(self, buffer: ReadableBuffer) -> tuple[str]: return self.decoder(super().unpack(buffer)[0]).rstrip('\0'), - def unpack_from(self, buffer: ReadableBuffer, offset: int = 0) -> tuple[str]: + def unpack_from( + self, + buffer: ReadableBuffer, + offset: int = 0, + ) -> tuple[str]: return self.decoder( super().unpack_from(buffer, offset)[0]).rstrip('\0'), diff --git a/structured/structured.py b/structured/structured.py index 8dc6a0e..633bee5 100644 --- a/structured/structured.py +++ b/structured/structured.py @@ -34,7 +34,9 @@ def validate_typehint(attr_type: type) -> TypeGuard[type[_Annotation]]: if issubclass(attr_type, (format_type, Serializer)): return True else: - raise TypeError(f'Unknown structured type {attr_type.__qualname__}') + raise TypeError( + f'Unknown structured type {attr_type.__qualname__}' + ) return False @@ -48,7 +50,10 @@ class MyStruct(Structured): return kind -def filter_typehints(typehints: dict[str, Any], classdict: dict[str, Any]) -> 
dict[str, type[_Annotation]]: +def filter_typehints( + typehints: dict[str, Any], + classdict: dict[str, Any], + ) -> dict[str, type[_Annotation]]: filtered = { attr: attr_type for attr, attr_type in typehints.items() @@ -348,7 +353,8 @@ def __init_subclass__( base_to_origbase = { origin: orig_base for orig_base in orig_bases - if (origin := get_origin(orig_base)) and issubclass(origin, Structured) + if (origin := get_origin(orig_base)) + and issubclass(origin, Structured) } orig_base = base_to_origbase.get(base, None) if orig_base: diff --git a/structured/utils.py b/structured/utils.py index 4912a53..0daef40 100644 --- a/structured/utils.py +++ b/structured/utils.py @@ -35,6 +35,9 @@ def wrapper(cls: type[_T]) -> type[_T]: class StructuredAlias: + """Class to hold one of the structured types that takes types as arguments, + which has been passed either another StructuredAlias or a TypeVar. + """ cls: type args: tuple @@ -50,9 +53,12 @@ def resolve(self, tvar_map: dict[TypeVar, type]): arg = arg.resolve(tvar_map) resolved.append(arg) resolved = tuple(resolved) - if any((isinstance(arg, (TypeVar, StructuredAlias)) for arg in resolved)): + if any(( + isinstance(arg, (TypeVar, StructuredAlias)) + for arg in resolved + )): # Act as immutable, so create a new instance, since these objects # are often cached in type factory indexing methods. return StructuredAlias(self.cls, resolved) else: - return self.cls[resolved] + return self.cls[resolved] # type: ignore From 3a17af39b76af66f1f0614544c37a28343facf27 Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sat, 10 Sep 2022 19:07:39 -0600 Subject: [PATCH 5/8] readme update --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 0224591..c6fb8a0 100644 --- a/README.md +++ b/README.md @@ -275,3 +275,16 @@ class MyStruct(Structured): ``` No solution is perfect, and any type checker set to a strict level will complain about a lot of code. 
+ + +## Generic `Structured` classes +You can also create your `Structured` class as a `Generic`. Due to details of how `typing.Generic` works, to get a working specialized version, you must subclass the specialization: + +```python +class MyGeneric(Generic[T, U], Structured): + a: T + b: array[Header[10], U] + + +class ConcreteClass(MyGeneric[uint8, uint32]): pass +``` From 67c25de3b7cb2465974e9adf06970c47d0397c82 Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sun, 11 Sep 2022 08:26:24 -0600 Subject: [PATCH 6/8] add Generic for the alternate method of using `serialized` --- README.md | 2 +- structured/structured.py | 31 ++++++++++++++++++++++--------- tests/test_generics.py | 39 ++++++++++++++++++++++++++++++++------- tests/test_utils.py | 2 ++ 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index c6fb8a0..94881fa 100644 --- a/README.md +++ b/README.md @@ -283,7 +283,7 @@ You can also create your `Structured` class as a `Generic`. Due to details of h ```python class MyGeneric(Generic[T, U], Structured): a: T - b: array[Header[10], U] + b: list[U] = serialized(array[Header[10], U]) class ConcreteClass(MyGeneric[uint8, uint32]): pass diff --git a/structured/structured.py b/structured/structured.py index 633bee5..2339f2c 100644 --- a/structured/structured.py +++ b/structured/structured.py @@ -347,8 +347,9 @@ def __init_subclass__( f'{base.__name__} ({base.byte_order.name}). ' 'If this is intentional, use `byte_order_mode=OVERRIDE`.' 
) + # Evaluate any generics in base class + classdict = cls.__dict__ if base: - # determine correct typehints for any Generic's in the base class orig_bases = getattr(cls, '__orig_bases__', ()) base_to_origbase = { origin: orig_base @@ -358,13 +359,15 @@ def __init_subclass__( } orig_base = base_to_origbase.get(base, None) if orig_base: - updates = base._specialize(*get_args(orig_base)) - cls.__annotations__.update(updates) + annotation_updates, classdict_updates = base._specialize( + *get_args(orig_base) + ) + cls.__annotations__.update(annotation_updates) + # NOTE: cls.__dict__ is a mappingproxy + classdict = dict(classdict) | classdict_updates # Analyze the class typehints = get_type_hints(cls) - serializer, attrs = create_serializer( - typehints, cls.__dict__, byte_order - ) + serializer, attrs = create_serializer(typehints, classdict, byte_order) # And set the updated class attributes cls.serializer = serializer cls.attrs = attrs @@ -381,8 +384,10 @@ def _specialize(cls, *args): supers[origin] = base tvar_map = dict(zip(tvars, args)) if not tvar_map: - return {} + raise TypeError('{cls.__name__} is not a Generic') + # First handle the direct base class annotations = {} + classdict = {} for attr, attr_type in get_type_hints(cls).items(): if attr in cls.__annotations__: # Attribute's final type hint comes from this class @@ -390,10 +395,18 @@ def _specialize(cls, *args): annotations[attr] = remapped_type elif isinstance(attr_type, StructuredAlias): annotations[attr] = attr_type.resolve(tvar_map) + for attr, attr_val in cls.__dict__.items(): + if isinstance(attr_val, StructuredAlias): + classdict[attr] = attr_val.resolve(tvar_map) + # Now any classes higher in the chain all_annotations = [annotations] + all_classdict = [classdict] for base, alias in supers.items(): args = get_args(alias) args = (tvar_map.get(arg, arg) for arg in args) - super_annotations = base._specialize(*args) + super_annotations, super_classdict = base._specialize(*args)
all_annotations.append(super_annotations) - return reduce(operator.or_, reversed(all_annotations)) + all_classdict.append(super_classdict) + final_annotations = reduce(operator.or_, reversed(all_annotations)) + final_classdict = reduce(operator.or_, reversed(all_classdict)) + return final_annotations, final_classdict diff --git a/tests/test_generics.py b/tests/test_generics.py index 2142f7c..fbdfb0b 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,9 +1,12 @@ import struct -from structured import * from typing import Generic, TypeVar, Union, get_type_hints +import pytest + +from structured import * from structured.utils import StructuredAlias + _Byte = TypeVar('_Byte', bound=Union[uint8, int8]) _String = TypeVar('_String', bound=Union[char, pascal, unicode]) _Size = TypeVar('_Size', bound=Union[uint8, uint16, uint32, uint64]) @@ -21,10 +24,11 @@ class UnsignedUnicode(Base[uint8, unicode[uint8]]): pass -class TestAliasing: - class Item(Structured): - a: int8 +class Item(Structured): + a: int8 + +class TestAliasing: tvar_map = { _Size: uint32, _Byte: uint8, @@ -77,9 +81,9 @@ def test_array(self) -> None: assert obj2.cls is array assert obj2.args == (Header[uint32, uint8], T) - obj3 = obj2.resolve({T: self.Item}) - assert obj3 is array[Header[uint32, uint8], self.Item] - assert obj.resolve(self.tvar_map) is array[Header[uint32, uint8], self.Item] + obj3 = obj2.resolve({T: Item}) + assert obj3 is array[Header[uint32, uint8], Item] + assert obj.resolve(self.tvar_map) is array[Header[uint32, uint8], Item] def test_automatic_resolution(): @@ -104,3 +108,24 @@ class FullySpecialized2(PartiallySpecialized[uint16, Item]): pass assert FullySpecialized1.attrs == FullySpecialized2.attrs assert FullySpecialized1.attrs == ('a', 'b', 'c') assert get_type_hints(FullySpecialized1) == get_type_hints(FullySpecialized2) + + +def test_serialized_generics() -> None: + class Base(Generic[_Size], Structured): + a: list[_Size] = serialized(array[Header[3], _Size]) + + 
class Concrete(Base[uint32]): + pass + + assert Concrete.attrs == ('a',) + target_data = struct.pack(f'3I', 1, 2, 3) + target_obj = Concrete.create_unpack(target_data) + assert target_obj.a == [1, 2, 3] + + +def test_errors() -> None: + class NotGeneric(Structured): + a: uint8 + + with pytest.raises(TypeError): + NotGeneric._specialize(uint8) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 5a8d15a..1bc89e8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1 +1,3 @@ from structured import * + +# No tests at the moment From eb5f032a000717fece2747edbf4d2bdd11bc95fe Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Sun, 11 Sep 2022 10:20:09 -0600 Subject: [PATCH 7/8] fix for py 3.9 and `__annotations__` --- structured/structured.py | 25 +++++++++++++++++++------ structured/type_checking.py | 19 +++++++++++++++++++ tests/test_generics.py | 2 +- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/structured/structured.py b/structured/structured.py index 2339f2c..179e763 100644 --- a/structured/structured.py +++ b/structured/structured.py @@ -19,6 +19,7 @@ from .type_checking import ( Any, ClassVar, Optional, ReadableBuffer, SupportsRead, SupportsWrite, WritableBuffer, get_type_hints, isclassvar, cast, TypeGuard, Union, TypeVar, + get_annotations, update_annotations, ) @@ -359,12 +360,12 @@ def __init_subclass__( } orig_base = base_to_origbase.get(base, None) if orig_base: - annotation_updates, classdict_updates = base._specialize( + annotations, clsdict = base._get_specialization_hints( *get_args(orig_base) ) - cls.__annotations__.update(annotation_updates) + update_annotations(cls, annotations) # NOTE: cls.__dict__ is a mappingproxy - classdict = dict(classdict) | classdict_updates + classdict = dict(classdict) | clsdict # Analyze the class typehints = get_type_hints(cls) serializer, attrs = create_serializer(typehints, classdict, byte_order) @@ -373,8 +374,16 @@ def 
__init_subclass__( cls.attrs = attrs cls.byte_order = byte_order + + @classmethod - def _specialize(cls, *args): + def _get_specialization_hints( + cls, + *args + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Get needed updates to __annotations__ and __dict__ if this class were + to be specialized with `args`. + """ supers: dict[type[Structured], Any] = {} tvars = () for base in getattr(cls, '__orig_bases__', ()): @@ -388,8 +397,10 @@ def _specialize(cls, *args): # First handle the direct base class annotations = {} classdict = {} + cls_annotations = get_annotations(cls) + cls_annotations = cls.__dict__.get('__annotations__', {}) for attr, attr_type in get_type_hints(cls).items(): - if attr in cls.__annotations__: + if attr in cls_annotations: # Attribute's final type hint comes from this class if remapped_type := tvar_map.get(attr_type, None): annotations[attr] = remapped_type @@ -404,7 +415,9 @@ def _specialize(cls, *args): for base, alias in supers.items(): args = get_args(alias) args = (tvar_map.get(arg, arg) for arg in args) - super_annotations, super_classdict = base._specialize(*args) + super_annotations, super_classdict = base._get_specialization_hints( + *args + ) all_annotations.append(super_annotations) all_classdict.append(super_classdict) final_annotations = reduce(operator.or_, reversed(all_annotations)) diff --git a/structured/type_checking.py b/structured/type_checking.py index af6a0b1..94ad258 100644 --- a/structured/type_checking.py +++ b/structured/type_checking.py @@ -14,6 +14,25 @@ _T = TypeVar('_T') +def update_annotations(cls: type, annotations: dict[str, Any]) -> None: + """Python <3.10 compatible way to update a class's annotations dict.
See: + + https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older + """ + if '__annotations__' in cls.__dict__: + cls.__annotations__.update(annotations) + else: + setattr(cls, '__annotations__', annotations) + +def get_annotations(cls: type) -> dict[str, Any]: + """Python <3.10 compatible way to get a class's annotations dict. See: + + https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older + """ + return cls.__dict__.get('__annotations__', {}) + + + def isclassvar(annotation: Any) -> bool: """Determine if a type annotations is for a class variable. diff --git a/tests/test_generics.py b/tests/test_generics.py index fbdfb0b..147e08c 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -128,4 +128,4 @@ class NotGeneric(Structured): a: uint8 with pytest.raises(TypeError): - NotGeneric._specialize(uint8) \ No newline at end of file + NotGeneric._get_specialization_hints(uint8) \ No newline at end of file From bba3c8b6f6a9849333406c83ffe48b5327aabfde Mon Sep 17 00:00:00 2001 From: lojack5 <1458329+lojack5@users.noreply.github.com> Date: Mon, 12 Sep 2022 11:25:23 -0600 Subject: [PATCH 8/8] readme updates on generics limitation in arrays --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 94881fa..10be947 100644 --- a/README.md +++ b/README.md @@ -288,3 +288,17 @@ class MyGeneric(Generic[T, U], Structured): class ConcreteClass(MyGeneric[uint8, uint32]): pass ``` + +One **limitation** here however, you cannot use a generic Structured class as an array object type. It will act as the base class without specialization (See #8). 
So for example, the following code will not work as you expect: +```python +class Item(Generic[T], Structured): + a: T + +class MyStruct(Generic[T], Structured): + items: array[Header[10], Item[T]] + +class Concrete(MyStruct[uint32]): pass + +assert Concrete.attrs == ('items', ) +> AssertionError +```