Skip to content

Commit

Permalink
(postponed, requires upstream fix) add support for UUID (#24)
Browse files Browse the repository at this point in the history
* add support for UUID

* Update UUID code to deal with different pyarrow versions; add documentation for UUID conversion

---------

Co-authored-by: simw <[email protected]>
  • Loading branch information
choucavalier and simw authored Jan 31, 2025
1 parent c5d0392 commit 79418af
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 2 deletions.
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,39 @@ List[...] | pa.list_(...) |
Dict[..., ...] | pa.map_(pa key_type, pa value_type) |
Enum of str | pa.dictionary(pa.int32(), pa.string()) |
Enum of int | pa.int64() |
UUID (uuid.UUID or pydantic.types.UUID*) | pa.uuid() | SEE NOTE BELOW!

Note on UUIDs: the UUID type is only supported in pyarrow 18.0 and above. However,
as of pyarrow 19.0, when pyarrow creates a table via e.g. `pa.Table.from_pylist(objs, schema=schema)`,
it expects bytes, not a `uuid.UUID` instance. Hence, if you are using `.model_dump()` to create
the data for pyarrow, you need to add a serializer to your pydantic model to convert the UUID to bytes.
This may be fixed in a later pyarrow version (see [apache/arrow#43855](https://github.com/apache/arrow/issues/43855)).

e.g. (with pyarrow >= 18.0):
```py
import uuid
from typing import Annotated

import pyarrow as pa
from pydantic import BaseModel, PlainSerializer
from pydantic_to_pyarrow import get_pyarrow_schema

class ModelWithUuid(BaseModel):
uuid: Annotated[uuid.UUID, PlainSerializer(lambda x: x.bytes, return_type=bytes)]


schema = get_pyarrow_schema(ModelWithUuid)

model1 = ModelWithUuid(uuid=uuid.uuid1())
model2 = ModelWithUuid(uuid=uuid.uuid4())
data = [model1.model_dump(), model2.model_dump()]
table = pa.Table.from_pylist(data)
print(table)
#> pyarrow.Table
#> uuid: binary
#> ----
#> uuid: [[BF206AC0DA4711EF8271EF4F4B7A3587,211C4C5D94C74876AE5E32DBCCDC16C7]]
```

## Settings

Expand Down
20 changes: 19 additions & 1 deletion src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import types
import uuid
from decimal import Decimal
from enum import EnumMeta
from typing import Any, List, Literal, NamedTuple, Optional, Type, TypeVar, Union, cast
Expand Down Expand Up @@ -156,6 +157,18 @@ def _get_enum_type(field_type: Type[Any]) -> pa.DataType:
raise SchemaCreationError(msg)


def _get_uuid_type() -> pa.DataType:
    """Return pyarrow's UUID extension type, raising if unavailable.

    pa.uuid() first appeared in pyarrow 18.0, so its presence is
    feature-detected via hasattr rather than by parsing the version.
    Which branch runs depends on the installed pyarrow; a single test
    environment can only ever exercise one of them, hence the
    coverage exclusions.
    """
    if not hasattr(pa, "uuid"):  # pragma: no cover
        msg = f"pyarrow version {pa.__version__} does not support pa.uuid() type, "
        msg += "needs version 18.0 or higher"
        raise SchemaCreationError(msg)
    return pa.uuid()  # pragma: no cover


def _is_optional(field_type: Type[Any]) -> bool:
origin = get_origin(field_type)
is_python_39_union = origin is Union
Expand All @@ -167,14 +180,19 @@ def _is_optional(field_type: Type[Any]) -> bool:
return type(None) in get_args(field_type)


def _get_pyarrow_type(
# noqa: PLR0911 - ignore until a refactoring can reduce the number of
# return statements.
def _get_pyarrow_type( # noqa: PLR0911
field_type: Type[Any],
metadata: List[Any],
settings: Settings,
) -> pa.DataType:
if field_type in FIELD_MAP:
return FIELD_MAP[field_type]

if field_type is uuid.UUID:
return _get_uuid_type()

if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
return LOSING_TZ_TYPES[field_type]

Expand Down
58 changes: 57 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import tempfile
import uuid
from decimal import Decimal
from enum import Enum, auto
from pathlib import Path
Expand All @@ -11,8 +12,12 @@
import pytest
from annotated_types import Gt
from packaging import version
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer
from pydantic.types import (
UUID1,
UUID3,
UUID4,
UUID5,
AwareDatetime,
NaiveDatetime,
PositiveInt,
Expand Down Expand Up @@ -592,6 +597,57 @@ class DictModel(BaseModel):
assert objs == [{"foo": dict(t["foo"])} for t in new_objs]


def test_uuid() -> None:
    """UUID fields map to pa.uuid() on pyarrow >= 18.0, else schema creation fails.

    pyarrow does not convert uuid.UUID values to bytes during serialization
    (https://github.com/apache/arrow/issues/43855), so every field carries a
    PlainSerializer that emits the 16-byte representation.
    """
    to_bytes = PlainSerializer(lambda u: u.bytes, return_type=bytes)

    class ModelWithUUID(BaseModel):
        foo_0: Annotated[uuid.UUID, to_bytes] = Field(default_factory=uuid.uuid1)
        foo_1: Annotated[UUID1, to_bytes] = Field(default_factory=uuid.uuid1)
        foo_3: Annotated[UUID3, to_bytes] = Field(
            default_factory=lambda: uuid.uuid3(uuid.NAMESPACE_DNS, "pydantic.org")
        )
        foo_4: Annotated[UUID4, to_bytes] = Field(default_factory=uuid.uuid4)
        foo_5: Annotated[UUID5, to_bytes] = Field(
            default_factory=lambda: uuid.uuid5(uuid.NAMESPACE_DNS, "pydantic.org")
        )

    if version.Version(pa.__version__) < version.Version("18.0.0"):
        # Older pyarrow has no pa.uuid(); schema creation must fail loudly.
        with pytest.raises(SchemaCreationError) as err:
            get_pyarrow_schema(ModelWithUUID)
        assert "needs version 18.0 or higher" in str(err)
        return

    field_names = ["foo_0", "foo_1", "foo_3", "foo_4", "foo_5"]
    expected = pa.schema(
        [pa.field(name, pa.uuid(), nullable=False) for name in field_names]
    )

    assert get_pyarrow_schema(ModelWithUUID) == expected

    objs = [ModelWithUUID().model_dump() for _ in range(2)]

    new_schema, new_objs = _write_pq_and_read(objs, expected)
    assert new_schema == expected
    # objs holds the serializer's bytes output, while pyarrow reads UUID
    # columns back as uuid.UUID objects; normalize before comparing.
    for obj in objs:
        for key in obj:
            obj[key] = uuid.UUID(bytes=obj[key])
    assert new_objs == objs


def test_alias() -> None:
class AliasModel(BaseModel):
field1: str
Expand Down

0 comments on commit 79418af

Please sign in to comment.