Skip to content

Commit

Permalink
Merge pull request #11 from lojack5/generic-structured
Browse files Browse the repository at this point in the history
Generics and Structured:

Allow for Generics in Structured classes in almost all locations.  The only limitation remaining is Generic Structured types as an array object type.
  • Loading branch information
lojack5 authored Sep 12, 2022
2 parents 478b006 + bba3c8b commit e4b7982
Show file tree
Hide file tree
Showing 9 changed files with 430 additions and 172 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,30 @@ class MyStruct(Structured):
```

No solution is perfect, and any type checker set to a strict level will complain about a lot of code.


## Generic `Structured` classes
You can also create your `Structured` class as a `Generic`. Due to details of how `typing.Generic` works, to get a working specialized version, you must subclass the specialization:

```python
class MyGeneric(Generic[T, U], Structured):
a: T
b: list[U] = serializerd(array[Header[10], U])


class ConcreteClass(MyGeneric[uint8, uint32]): pass
```

One **limitation** here however, you cannot use a generic Structured class as an array object type. It will act as the base class without specialization (See #8). So for example, the following code will not work as you expect:
```python
class Item(Generic[T], Structured):
a: T

class MyStruct(Generic[T], Structured):
items: array[Header[10], Item[T]]

class Concrete(MyStruct[uint32]): pass

assert Concrete.args == ('items', )
> AssertionError
```
262 changes: 100 additions & 162 deletions structured/complex_types/array_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

from functools import cache

from ..utils import specialized
from ..utils import StructuredAlias, specialized
from ..structured import Structured
from ..basic_types import uint8, uint16, uint32, uint64
from ..type_checking import ClassVar, Union
from ..type_checking import ClassVar, Union, Generic, Optional, TypeVar


_SizeTypes = (uint8, uint16, uint32, uint64)
SizeTypes = Union[uint8, uint16, uint32, uint64]
TSize = TypeVar('TSize', bound=SizeTypes)
TCount = TypeVar('TCount', bound=SizeTypes)


class HeaderBase:
Expand Down Expand Up @@ -53,68 +56,45 @@ def count(self, new_count: int) -> None:
f'expected an array of length {self.count}, but got {new_count}'
)

def __class_getitem__(
cls: type[StaticHeader],
count: int,
) -> type[StaticHeader]:
@classmethod
def specialize(cls, count: int) -> type[StaticHeader]:
"""Specialize for a specific static size."""
if count <= 0:
raise ValueError('count must be positive')
class _StaticHeader(cls):
class _StaticHeader(StaticHeader):
_count: ClassVar[int] = count
return _StaticHeader


class DynamicHeader(Structured, HeaderBase):
class DynamicHeader(Generic[TCount], Structured, HeaderBase):
"""Base for dynamically sized arrays, where the array length is just prior
to the array data.
"""
count: int
count: TCount
data_size: ClassVar[int] = 0
two_pass: ClassVar[bool] = False

_headers = {}

def __init__(self, count: int, data_size: int) -> None:
"""Only `count` is packed/unpacked."""
self.count = count
self.count = count # type: ignore

def __class_getitem__(
cls,
count_type: type[SizeTypes],
) -> type[DynamicHeader]:
"""Specialize based on the uint* type used to store the array length."""
return cls._headers[count_type]

class _DynamicHeader8(DynamicHeader):
count: uint8
class _DynamicHeader16(DynamicHeader):
count: uint16
class _DynamicHeader32(DynamicHeader):
count: uint32
class _DynamicHeader64(DynamicHeader):
count: uint64
DynamicHeader._headers = {
uint8: _DynamicHeader8,
uint16: _DynamicHeader16,
uint32: _DynamicHeader32,
uint64: _DynamicHeader64,
}


class StaticCheckedHeader(Structured, HeaderBase):
@classmethod
def specialize(cls, count_type: type[SizeTypes]) -> type[DynamicHeader]:
class _DynamicHeader(DynamicHeader[count_type]): pass
return _DynamicHeader


class StaticCheckedHeader(Generic[TSize], Structured, HeaderBase):
"""Statically sized array, with a size check int packed just prior to the
array data.
"""
_count: ClassVar[int] = 0
data_size: int
data_size: TSize
two_pass: ClassVar[bool] = True

_headers = {}

def __init__(self, count: int, data_size: int) -> None:
"""Only `data_size` is packed/unpacked."""
self.data_size = data_size
self.data_size = data_size # type: ignore

def validate_data_size(self, data_size: int) -> None:
"""Verify correct amount of bytes were read."""
Expand All @@ -138,39 +118,29 @@ def count(self, new_count: int) -> None:
f'{new_count}'
)

def __class_getitem__(
cls: type[StaticCheckedHeader],
key: tuple[int, type[SizeTypes]],
@classmethod
def specialize(
cls,
count: int,
size_type: type[SizeTypes],
) -> type[StaticCheckedHeader]:
"""Specialize for the specific static size and check type."""
count, size_type = key
"""Specialize for the specific static size and check type.
:param count: Static length for the array.
:param size_type: Type of integer to unpack for the array data size.
:return: The specialized Header class.
"""
if count <= 0:
raise ValueError('count must be positive')
base = cls._headers[size_type]
class _StaticCheckedHeader(base):
class _StaticCheckedHeader(StaticCheckedHeader[size_type]):
_count: ClassVar[int] = count
return _StaticCheckedHeader

class _StaticCheckedHeader8(StaticCheckedHeader):
data_size: uint8
class _StaticCheckedHeader16(StaticCheckedHeader):
data_size: uint16
class _StaticChechedHeader32(StaticCheckedHeader):
data_size: uint32
class _StaticCheckedHeader64(StaticCheckedHeader):
data_size: uint64
StaticCheckedHeader._headers = {
uint8: _StaticCheckedHeader8,
uint16: _StaticCheckedHeader16,
uint32: _StaticChechedHeader32,
uint64: _StaticCheckedHeader64,
}


class DynamicCheckedHeader(Structured, HeaderBase):

class DynamicCheckedHeader(Generic[TCount, TSize], Structured, HeaderBase):
"""Dynamically sized array with a size check."""
count: int
data_size: int
count: TCount
data_size: TSize
two_pass: ClassVar[bool] = True

_headers = {}
Expand All @@ -183,79 +153,20 @@ def validate_data_size(self, data_size: int) -> None:
f'got {data_size}'
)

def __class_getitem__(
cls: type[DynamicCheckedHeader],
key: tuple[type[SizeTypes], type[SizeTypes]],
@classmethod
def specialize(
cls,
count_type: type[SizeTypes],
size_type: type[SizeTypes]
) -> type[DynamicCheckedHeader]:
"""Specialize based on size type and check type."""
return cls._headers[key]

class _DynamicCheckedHeader8_8(DynamicCheckedHeader):
count: uint8
data_size: uint8
class _DynamicCheckedHeader8_16(DynamicCheckedHeader):
count: uint8
data_size: uint16
class _DynamicCheckedHeader8_32(DynamicCheckedHeader):
count: uint8
data_size: uint32
class _DynamicCheckedHeader8_64(DynamicCheckedHeader):
count: uint8
data_size: uint64
class _DynamicCheckedHeader16_8(DynamicCheckedHeader):
count: uint16
data_size: uint8
class _DynamicCheckedHeader16_16(DynamicCheckedHeader):
count: uint16
data_size: uint16
class _DynamicCheckedHeader16_32(DynamicCheckedHeader):
count: uint16
data_size: uint32
class _DynamicCheckedHeader16_64(DynamicCheckedHeader):
count: uint16
data_size: uint64
class _DynamicCheckedHeader32_8(DynamicCheckedHeader):
count: uint32
data_size: uint8
class _DynamicCheckedHeader32_16(DynamicCheckedHeader):
count: uint32
data_size: uint16
class _DynamicCheckedHeader32_32(DynamicCheckedHeader):
count: uint32
data_size: uint32
class _DynamicCheckedHeader32_64(DynamicCheckedHeader):
count: uint32
data_size: uint64
class _DynamicCheckedHeader64_8(DynamicCheckedHeader):
count: uint64
data_size: uint8
class _DynamicCheckedHeader64_16(DynamicCheckedHeader):
count: uint64
data_size: uint16
class _DynamicCheckedHeader64_32(DynamicCheckedHeader):
count: uint64
data_size: uint32
class _DynamicCheckedHeader64_64(DynamicCheckedHeader):
count: uint64
data_size: uint64
DynamicCheckedHeader._headers = {
(uint8, uint8): _DynamicCheckedHeader8_8,
(uint8, uint16): _DynamicCheckedHeader8_16,
(uint8, uint32): _DynamicCheckedHeader8_32,
(uint8, uint64): _DynamicCheckedHeader8_64,
(uint16, uint8): _DynamicCheckedHeader16_8,
(uint16, uint16): _DynamicCheckedHeader16_16,
(uint16, uint32): _DynamicCheckedHeader16_32,
(uint16, uint64): _DynamicCheckedHeader16_64,
(uint32, uint8): _DynamicCheckedHeader32_8,
(uint32, uint16): _DynamicCheckedHeader32_16,
(uint32, uint32): _DynamicCheckedHeader32_32,
(uint32, uint64): _DynamicCheckedHeader32_64,
(uint64, uint8): _DynamicCheckedHeader64_8,
(uint64, uint16): _DynamicCheckedHeader64_16,
(uint64, uint32): _DynamicCheckedHeader64_32,
(uint64, uint64): _DynamicCheckedHeader64_64,
}
"""Specialize for the specific count type and check type.
:param count_type: Type of integer to unpack for the array length.
:param size_type: Type of integer to unpack for the array data size.
:return: The specialized Header class.
"""
class _DynamicCheckedHeader(DynamicCheckedHeader[count_type, size_type]): pass
return _DynamicCheckedHeader


class Header(Structured, HeaderBase):
Expand All @@ -266,34 +177,61 @@ class Header(Structured, HeaderBase):
data_size: int
two_pass: ClassVar[bool]

@classmethod
@cache
def __class_getitem__(cls, key) -> type[Header]:
"""Main entry point for making Headers. Do type checks and create the
appropriate Header type.
"""
if not isinstance(key, tuple):
count, size_check = key, None
elif len(key) != 2:
raise TypeError(f'{cls.__name__}[] expected two arguments')
key = (key, )
return cls.create(*key)

@classmethod
def create(cls, count, size_check=None):
"""Intermediate method to pass through default args to the real cached
creation method.
"""
return cls._create(count, size_check)

@classmethod
@cache
def _create(
cls,
count: Union[int, type[SizeTypes]],
size_check: Optional[type[SizeTypes]]
) -> type[Header]:
"""Check header arguments and dispatch to the correct Header
specialization.
:param count: Static length or integer type to unpack for array length.
:param size_check: Integer type to unpack for array data size, or None
for no integer to unpack.
:return: The applicable Header specialization
"""
# TypeVar quick out.
if isinstance(count, TypeVar) or isinstance(size_check, TypeVar):
return StructuredAlias(cls, (count, size_check)) # type: ignore
# Final type checking
if size_check is not None:
if not (isinstance(size_check, type) and
issubclass(size_check, _SizeTypes)):
raise TypeError('size check must be a uint* type.')
elif not isinstance(count, int):
if not (isinstance(count, type) and
issubclass(count, _SizeTypes)):
raise TypeError(
'array length must be an integer or uint* type.'
)
# Dispatch
if size_check is None:
args = (count, )
if isinstance(count, int):
header = StaticHeader.specialize(count)
else:
header = DynamicHeader.specialize(count)
else:
count, size_check = key
try:
if size_check is None:
args = (count,)
if isinstance(count, int):
header = StaticHeader[count]
else:
header = DynamicHeader[count]
args = (count, size_check)
if isinstance(count, int):
header = StaticCheckedHeader.specialize(count, size_check)
else:
args = (count, size_check)
if isinstance(count, int):
header = StaticCheckedHeader[count, size_check]
else:
header = DynamicCheckedHeader[count, size_check]
return specialized(cls, *args)(header) # type: ignore
except KeyError:
raise TypeError(
f'{cls.__name__}[] expected first argument integer or uint* '
'type, second argument uint* type or None'
) from None
header = DynamicCheckedHeader.specialize(count, size_check)
return specialized(cls, *args)(header) # type: ignore
5 changes: 5 additions & 0 deletions structured/complex_types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def __class_getitem__(
else:
header = Header[args[:-1]]
array_type = args[-1]
# TypeVar checks:
if (isinstance(header, StructuredAlias) or
isinstance(array_type, TypeVar)):
return StructuredAlias(cls, (header, array_type)) # type: ignore
# Type checking
if (not isinstance(header, type) or
not issubclass(header, HeaderBase) or
header is HeaderBase or
Expand Down
Loading

0 comments on commit e4b7982

Please sign in to comment.