Skip to content

Commit

Permalink
Update UUID code to deal with different pyarrow versions; add documen…
Browse files Browse the repository at this point in the history
…tation for UUID conversion
  • Loading branch information
simw committed Jan 24, 2025
1 parent 404f918 commit 154c80b
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 28 deletions.
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,39 @@ List[...] | pa.list_(...) |
Dict[..., ...] | pa.map_(pa key_type, pa value_type) |
Enum of str | pa.dictionary(pa.int32(), pa.string()) |
Enum of int | pa.int64() |
UUID (uuid.UUID or pydantic.types.UUID*) | pa.uuid() | SEE NOTE BELOW!

Note on UUIDs: the UUID type is only supported in pyarrow 18.0 and above. However,
as of pyarrow 19.0, when pyarrow creates a table in eg `pa.Table.from_pylist(objs, schema=schema)`,
it expects bytes not a uuid.UUID type. Hence, if you are using .model_dump() to create
the data for pyarrow, you need to add a serializer on your pydantic model to convert to bytes.
This may be fixed in later versions (see [https://github.com/apache/arrow/issues/43855]).

eg (with pyarrow >= 18.0):
```py
import uuid
from typing import Annotated

import pyarrow as pa
from pydantic import BaseModel, PlainSerializer
from pydantic_to_pyarrow import get_pyarrow_schema

class ModelWithUuid(BaseModel):
uuid: Annotated[uuid.UUID, PlainSerializer(lambda x: x.bytes, return_type=bytes)]


schema = get_pyarrow_schema(ModelWithUuid)

model1 = ModelWithUuid(uuid=uuid.uuid1())
model2 = ModelWithUuid(uuid=uuid.uuid4())
data = [model1.model_dump(), model2.model_dump()]
table = pa.Table.from_pylist(data)
print(table)
#> pyarrow.Table
#> uuid: binary
#> ----
#> uuid: [[BF206AC0DA4711EF8271EF4F4B7A3587,211C4C5D94C74876AE5E32DBCCDC16C7]]
```

## Settings

Expand Down
20 changes: 18 additions & 2 deletions src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class Settings(NamedTuple):
datetime.date: pa.date32(),
NaiveDatetime: pa.timestamp("ms", tz=None),
datetime.time: pa.time64("us"),
uuid.UUID: pa.uuid(),
}

# Timezone aware datetimes will lose their timezone information
Expand Down Expand Up @@ -158,6 +157,18 @@ def _get_enum_type(field_type: Type[Any]) -> pa.DataType:
raise SchemaCreationError(msg)


def _get_uuid_type() -> pa.DataType:
# Different branches will execute depending on the pyarrow version
# This is tested through nox and python versions, but each one
# won't cover both branches. Hence, excluding from coverage.
if hasattr(pa, "uuid"): # pragma: no cover
return pa.uuid()
else: # pragma: no cover
msg = f"pyarrow version {pa.__version__} does not support pa.uuid() type, "
msg += "needs version 18.0 or higher"
raise SchemaCreationError(msg)


def _is_optional(field_type: Type[Any]) -> bool:
origin = get_origin(field_type)
is_python_39_union = origin is Union
Expand All @@ -169,14 +180,19 @@ def _is_optional(field_type: Type[Any]) -> bool:
return type(None) in get_args(field_type)


def _get_pyarrow_type(
# noqa: PLR0911 - ignore until a refactoring can reduce the number of
# return statements.
def _get_pyarrow_type( # noqa: PLR0911
field_type: Type[Any],
metadata: List[Any],
settings: Settings,
) -> pa.DataType:
if field_type in FIELD_MAP:
return FIELD_MAP[field_type]

if field_type is uuid.UUID:
return _get_uuid_type()

if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
return LOSING_TZ_TYPES[field_type]

Expand Down
65 changes: 39 additions & 26 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@

from pydantic_to_pyarrow import SchemaCreationError, get_pyarrow_schema

_uuid_bytes_serializer = PlainSerializer(lambda x: x.bytes, return_type=bytes)
_ByteSerializedUUID1Field = Annotated[UUID1, _uuid_bytes_serializer]
_ByteSerializedUUID3Field = Annotated[UUID3, _uuid_bytes_serializer]
_ByteSerializedUUID4Field = Annotated[UUID4, _uuid_bytes_serializer]
_ByteSerializedUUID5Field = Annotated[UUID5, _uuid_bytes_serializer]


def _write_pq_and_read(
objs: List[Dict[str, Any]], schema: pa.Schema
Expand Down Expand Up @@ -604,35 +598,54 @@ class DictModel(BaseModel):


def test_uuid() -> None:
# pyarrow 18.0.0+ is required for UUID support
# Even then, pyarrow doesn't automatically convert UUIDs to bytes
# for the serialization, so we need to do that manually
# (https://github.com/apache/arrow/issues/43855)
as_bytes = PlainSerializer(lambda x: x.bytes, return_type=bytes)

class ModelWithUUID(BaseModel):
foo_1: _ByteSerializedUUID1Field = Field(default_factory=uuid.uuid1)
foo_3: _ByteSerializedUUID3Field = Field(
foo_0: Annotated[uuid.UUID, as_bytes] = Field(default_factory=uuid.uuid1)
foo_1: Annotated[UUID1, as_bytes] = Field(default_factory=uuid.uuid1)
foo_3: Annotated[UUID3, as_bytes] = Field(
default_factory=lambda: uuid.uuid3(uuid.NAMESPACE_DNS, "pydantic.org")
)
foo_4: _ByteSerializedUUID4Field = Field(default_factory=uuid.uuid4)
foo_5: _ByteSerializedUUID5Field = Field(
foo_4: Annotated[UUID4, as_bytes] = Field(default_factory=uuid.uuid4)
foo_5: Annotated[UUID5, as_bytes] = Field(
default_factory=lambda: uuid.uuid5(uuid.NAMESPACE_DNS, "pydantic.org")
)

expected = pa.schema(
[
pa.field("foo_1", pa.uuid(), nullable=False),
pa.field("foo_3", pa.uuid(), nullable=False),
pa.field("foo_4", pa.uuid(), nullable=False),
pa.field("foo_5", pa.uuid(), nullable=False),
]
)
if version.Version(pa.__version__) < version.Version("18.0.0"):
with pytest.raises(SchemaCreationError) as err:
get_pyarrow_schema(ModelWithUUID)
assert "needs version 18.0 or higher" in str(err)
else:
expected = pa.schema(
[
pa.field("foo_0", pa.uuid(), nullable=False),
pa.field("foo_1", pa.uuid(), nullable=False),
pa.field("foo_3", pa.uuid(), nullable=False),
pa.field("foo_4", pa.uuid(), nullable=False),
pa.field("foo_5", pa.uuid(), nullable=False),
]
)

objs = [
ModelWithUUID().model_dump(),
ModelWithUUID().model_dump(),
]
actual = get_pyarrow_schema(ModelWithUUID)
assert actual == expected

actual = get_pyarrow_schema(ModelWithUUID)
assert actual == expected
objs = [
ModelWithUUID().model_dump(),
ModelWithUUID().model_dump(),
]

new_schema, new_objs = _write_pq_and_read(objs, expected)
assert new_schema == expected
new_schema, new_objs = _write_pq_and_read(objs, expected)
assert new_schema == expected
# objs was created with the uuid serializer to bytes,
# but pyarrow will read the uuids into UUID objects directly
for obj in objs:
for key in obj:
obj[key] = uuid.UUID(bytes=obj[key])
assert new_objs == objs


def test_alias() -> None:
Expand Down

0 comments on commit 154c80b

Please sign in to comment.