diff --git a/README.md b/README.md index f7af4a7..f23b8cd 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,39 @@ List[...] | pa.list_(...) | Dict[..., ...] | pa.map_(pa key_type, pa value_type) | Enum of str | pa.dictionary(pa.int32(), pa.string()) | Enum of int | pa.int64() | +UUID (uuid.UUID or pydantic.types.UUID*) | pa.uuid() (see note below) | + +Note on UUIDs: the UUID type is only supported in pyarrow 18.0 and above. However, +as of pyarrow 19.0, when pyarrow creates a table, e.g. with `pa.Table.from_pylist(objs, schema=schema)`, +it expects `bytes`, not `uuid.UUID` objects. Hence, if you are using `.model_dump()` to create +the data for pyarrow, you need to add a serializer to your pydantic model that converts UUIDs to bytes. +This may be fixed in later versions (see <https://github.com/apache/arrow/issues/43855>). + +For example (with pyarrow >= 18.0): +```py +import uuid +from typing import Annotated + +import pyarrow as pa +from pydantic import BaseModel, PlainSerializer +from pydantic_to_pyarrow import get_pyarrow_schema + +class ModelWithUuid(BaseModel): + uuid: Annotated[uuid.UUID, PlainSerializer(lambda x: x.bytes, return_type=bytes)] + + +schema = get_pyarrow_schema(ModelWithUuid) + +model1 = ModelWithUuid(uuid=uuid.uuid1()) +model2 = ModelWithUuid(uuid=uuid.uuid4()) +data = [model1.model_dump(), model2.model_dump()] +table = pa.Table.from_pylist(data) +print(table) +#> pyarrow.Table +#> uuid: binary +#> ---- +#> uuid: [[BF206AC0DA4711EF8271EF4F4B7A3587,211C4C5D94C74876AE5E32DBCCDC16C7]] +``` ## Settings diff --git a/src/pydantic_to_pyarrow/schema.py b/src/pydantic_to_pyarrow/schema.py index 0cdda4c..660f03d 100644 --- a/src/pydantic_to_pyarrow/schema.py +++ b/src/pydantic_to_pyarrow/schema.py @@ -32,7 +32,6 @@ class Settings(NamedTuple): datetime.date: pa.date32(), NaiveDatetime: pa.timestamp("ms", tz=None), datetime.time: pa.time64("us"), - uuid.UUID: pa.uuid(), } # Timezone aware datetimes will lose their timezone information @@ -158,6 +157,18 @@ def
_get_enum_type(field_type: Type[Any]) -> pa.DataType: raise SchemaCreationError(msg) +def _get_uuid_type() -> pa.DataType: + # Different branches will execute depending on the pyarrow version + # This is tested through nox and python versions, but each one + # won't cover both branches. Hence, excluding from coverage. + if hasattr(pa, "uuid"): # pragma: no cover + return pa.uuid() + else: # pragma: no cover + msg = f"pyarrow version {pa.__version__} does not support pa.uuid() type, " + msg += "needs version 18.0 or higher" + raise SchemaCreationError(msg) + + def _is_optional(field_type: Type[Any]) -> bool: origin = get_origin(field_type) is_python_39_union = origin is Union @@ -169,7 +180,9 @@ def _is_optional(field_type: Type[Any]) -> bool: return type(None) in get_args(field_type) -def _get_pyarrow_type( +# noqa: PLR0911 - ignore until a refactoring can reduce the number of +# return statements. +def _get_pyarrow_type( # noqa: PLR0911 field_type: Type[Any], metadata: List[Any], settings: Settings, @@ -177,6 +190,9 @@ def _get_pyarrow_type( if field_type in FIELD_MAP: return FIELD_MAP[field_type] + if field_type is uuid.UUID: + return _get_uuid_type() + if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES: return LOSING_TZ_TYPES[field_type] diff --git a/tests/test_schema.py b/tests/test_schema.py index 5b3c513..6974cbf 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -32,12 +32,6 @@ from pydantic_to_pyarrow import SchemaCreationError, get_pyarrow_schema -_uuid_bytes_serializer = PlainSerializer(lambda x: x.bytes, return_type=bytes) -_ByteSerializedUUID1Field = Annotated[UUID1, _uuid_bytes_serializer] -_ByteSerializedUUID3Field = Annotated[UUID3, _uuid_bytes_serializer] -_ByteSerializedUUID4Field = Annotated[UUID4, _uuid_bytes_serializer] -_ByteSerializedUUID5Field = Annotated[UUID5, _uuid_bytes_serializer] - def _write_pq_and_read( objs: List[Dict[str, Any]], schema: pa.Schema @@ -604,35 +598,54 @@ class DictModel(BaseModel): def 
test_uuid() -> None: + # pyarrow 18.0.0+ is required for UUID support + # Even then, pyarrow doesn't automatically convert UUIDs to bytes + # for the serialization, so we need to do that manually + # (https://github.com/apache/arrow/issues/43855) + as_bytes = PlainSerializer(lambda x: x.bytes, return_type=bytes) + class ModelWithUUID(BaseModel): - foo_1: _ByteSerializedUUID1Field = Field(default_factory=uuid.uuid1) - foo_3: _ByteSerializedUUID3Field = Field( + foo_0: Annotated[uuid.UUID, as_bytes] = Field(default_factory=uuid.uuid1) + foo_1: Annotated[UUID1, as_bytes] = Field(default_factory=uuid.uuid1) + foo_3: Annotated[UUID3, as_bytes] = Field( default_factory=lambda: uuid.uuid3(uuid.NAMESPACE_DNS, "pydantic.org") ) - foo_4: _ByteSerializedUUID4Field = Field(default_factory=uuid.uuid4) - foo_5: _ByteSerializedUUID5Field = Field( + foo_4: Annotated[UUID4, as_bytes] = Field(default_factory=uuid.uuid4) + foo_5: Annotated[UUID5, as_bytes] = Field( default_factory=lambda: uuid.uuid5(uuid.NAMESPACE_DNS, "pydantic.org") ) - expected = pa.schema( - [ - pa.field("foo_1", pa.uuid(), nullable=False), - pa.field("foo_3", pa.uuid(), nullable=False), - pa.field("foo_4", pa.uuid(), nullable=False), - pa.field("foo_5", pa.uuid(), nullable=False), - ] - ) + if version.Version(pa.__version__) < version.Version("18.0.0"): + with pytest.raises(SchemaCreationError) as err: + get_pyarrow_schema(ModelWithUUID) + assert "needs version 18.0 or higher" in str(err) + else: + expected = pa.schema( + [ + pa.field("foo_0", pa.uuid(), nullable=False), + pa.field("foo_1", pa.uuid(), nullable=False), + pa.field("foo_3", pa.uuid(), nullable=False), + pa.field("foo_4", pa.uuid(), nullable=False), + pa.field("foo_5", pa.uuid(), nullable=False), + ] + ) - objs = [ - ModelWithUUID().model_dump(), - ModelWithUUID().model_dump(), - ] + actual = get_pyarrow_schema(ModelWithUUID) + assert actual == expected - actual = get_pyarrow_schema(ModelWithUUID) - assert actual == expected + objs = [ + 
ModelWithUUID().model_dump(), + ModelWithUUID().model_dump(), + ] - new_schema, new_objs = _write_pq_and_read(objs, expected) - assert new_schema == expected + new_schema, new_objs = _write_pq_and_read(objs, expected) + assert new_schema == expected + # objs was created with the uuid serializer to bytes, + # but pyarrow will read the uuids into UUID objects directly + for obj in objs: + for key in obj: + obj[key] = uuid.UUID(bytes=obj[key]) + assert new_objs == objs def test_alias() -> None: