From 79418af16cd5cad0d7c19744930b281c1c7c28c4 Mon Sep 17 00:00:00 2001 From: Valentin Iovene Date: Fri, 31 Jan 2025 12:41:00 +0100 Subject: [PATCH] (postponed, requires upstream fix) add support for UUID (#24) * add support for UUID * Update UUID code to deal with different pyarrow versions; add documentation for UUID conversion --------- Co-authored-by: simw --- README.md | 33 ++++++++++++++++++ src/pydantic_to_pyarrow/schema.py | 20 ++++++++++- tests/test_schema.py | 58 ++++++++++++++++++++++++++++++- 3 files changed, 109 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f7af4a7..f23b8cd 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,39 @@ List[...] | pa.list_(...) | Dict[..., ...] | pa.map_(pa key_type, pa value_type) | Enum of str | pa.dictionary(pa.int32(), pa.string()) | Enum of int | pa.int64() | +UUID (uuid.UUID or pydantic.types.UUID*) | pa.uuid() | SEE NOTE BELOW! + +Note on UUIDs: the UUID type is only supported in pyarrow 18.0 and above. However, +as of pyarrow 19.0, when pyarrow creates a table with e.g. `pa.Table.from_pylist(objs, schema=schema)`, +it expects bytes, not a uuid.UUID object. Hence, if you are using .model_dump() to create +the data for pyarrow, you need to add a serializer on your pydantic model to convert to bytes. +This may be fixed in later versions (see <https://github.com/apache/arrow/issues/43855>). 
+ +eg (with pyarrow >= 18.0): +```py +import uuid +from typing import Annotated + +import pyarrow as pa +from pydantic import BaseModel, PlainSerializer +from pydantic_to_pyarrow import get_pyarrow_schema + +class ModelWithUuid(BaseModel): + uuid: Annotated[uuid.UUID, PlainSerializer(lambda x: x.bytes, return_type=bytes)] + + +schema = get_pyarrow_schema(ModelWithUuid) + +model1 = ModelWithUuid(uuid=uuid.uuid1()) +model2 = ModelWithUuid(uuid=uuid.uuid4()) +data = [model1.model_dump(), model2.model_dump()] +table = pa.Table.from_pylist(data) +print(table) +#> pyarrow.Table +#> uuid: binary +#> ---- +#> uuid: [[BF206AC0DA4711EF8271EF4F4B7A3587,211C4C5D94C74876AE5E32DBCCDC16C7]] +``` ## Settings diff --git a/src/pydantic_to_pyarrow/schema.py b/src/pydantic_to_pyarrow/schema.py index 8515538..660f03d 100644 --- a/src/pydantic_to_pyarrow/schema.py +++ b/src/pydantic_to_pyarrow/schema.py @@ -1,5 +1,6 @@ import datetime import types +import uuid from decimal import Decimal from enum import EnumMeta from typing import Any, List, Literal, NamedTuple, Optional, Type, TypeVar, Union, cast @@ -156,6 +157,18 @@ def _get_enum_type(field_type: Type[Any]) -> pa.DataType: raise SchemaCreationError(msg) +def _get_uuid_type() -> pa.DataType: + # Different branches will execute depending on the pyarrow version + # This is tested through nox and python versions, but each one + # won't cover both branches. Hence, excluding from coverage. 
+ if hasattr(pa, "uuid"): # pragma: no cover + return pa.uuid() + else: # pragma: no cover + msg = f"pyarrow version {pa.__version__} does not support pa.uuid() type, " + msg += "needs version 18.0 or higher" + raise SchemaCreationError(msg) + + def _is_optional(field_type: Type[Any]) -> bool: origin = get_origin(field_type) is_python_39_union = origin is Union @@ -167,7 +180,9 @@ def _is_optional(field_type: Type[Any]) -> bool: return type(None) in get_args(field_type) -def _get_pyarrow_type( +# noqa: PLR0911 - ignore until a refactoring can reduce the number of +# return statements. +def _get_pyarrow_type( # noqa: PLR0911 field_type: Type[Any], metadata: List[Any], settings: Settings, @@ -175,6 +190,9 @@ def _get_pyarrow_type( if field_type in FIELD_MAP: return FIELD_MAP[field_type] + if field_type is uuid.UUID: + return _get_uuid_type() + if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES: return LOSING_TZ_TYPES[field_type] diff --git a/tests/test_schema.py b/tests/test_schema.py index 3729eab..6974cbf 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,5 +1,6 @@ import datetime import tempfile +import uuid from decimal import Decimal from enum import Enum, auto from pathlib import Path @@ -11,8 +12,12 @@ import pytest from annotated_types import Gt from packaging import version -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, PlainSerializer from pydantic.types import ( + UUID1, + UUID3, + UUID4, + UUID5, AwareDatetime, NaiveDatetime, PositiveInt, @@ -592,6 +597,57 @@ class DictModel(BaseModel): assert objs == [{"foo": dict(t["foo"])} for t in new_objs] +def test_uuid() -> None: + # pyarrow 18.0.0+ is required for UUID support + # Even then, pyarrow doesn't automatically convert UUIDs to bytes + # for the serialization, so we need to do that manually + # (https://github.com/apache/arrow/issues/43855) + as_bytes = PlainSerializer(lambda x: x.bytes, return_type=bytes) + + class 
ModelWithUUID(BaseModel): + foo_0: Annotated[uuid.UUID, as_bytes] = Field(default_factory=uuid.uuid1) + foo_1: Annotated[UUID1, as_bytes] = Field(default_factory=uuid.uuid1) + foo_3: Annotated[UUID3, as_bytes] = Field( + default_factory=lambda: uuid.uuid3(uuid.NAMESPACE_DNS, "pydantic.org") + ) + foo_4: Annotated[UUID4, as_bytes] = Field(default_factory=uuid.uuid4) + foo_5: Annotated[UUID5, as_bytes] = Field( + default_factory=lambda: uuid.uuid5(uuid.NAMESPACE_DNS, "pydantic.org") + ) + + if version.Version(pa.__version__) < version.Version("18.0.0"): + with pytest.raises(SchemaCreationError) as err: + get_pyarrow_schema(ModelWithUUID) + assert "needs version 18.0 or higher" in str(err) + else: + expected = pa.schema( + [ + pa.field("foo_0", pa.uuid(), nullable=False), + pa.field("foo_1", pa.uuid(), nullable=False), + pa.field("foo_3", pa.uuid(), nullable=False), + pa.field("foo_4", pa.uuid(), nullable=False), + pa.field("foo_5", pa.uuid(), nullable=False), + ] + ) + + actual = get_pyarrow_schema(ModelWithUUID) + assert actual == expected + + objs = [ + ModelWithUUID().model_dump(), + ModelWithUUID().model_dump(), + ] + + new_schema, new_objs = _write_pq_and_read(objs, expected) + assert new_schema == expected + # objs was created with the uuid serializer to bytes, + # but pyarrow will read the uuids into UUID objects directly + for obj in objs: + for key in obj: + obj[key] = uuid.UUID(bytes=obj[key]) + assert new_objs == objs + + def test_alias() -> None: class AliasModel(BaseModel): field1: str