Skip to content

Commit

Permalink
Use serialization alias for creating pyarrow schema when called with …
Browse files Browse the repository at this point in the history
…by_alias=True, mirroring functionality in pydantic
  • Loading branch information
simw committed Nov 5, 2024
1 parent dd71ea9 commit 19bfcbb
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 8 deletions.
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mypy = "^1.6.0"
[tool.poetry.group.test.dependencies]
pytest = "^7.4.2"
coverage = "^7.3.2"
packaging = "^24.1"


[tool.mypy]
Expand Down
5 changes: 4 additions & 1 deletion src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ def _get_pyarrow_schema(
f"Error processing field {name}: {field_type}, {err}"
) from err

fields.append(pa.field(name, pa_field, nullable=nullable))
serialized_name = name
if settings.by_alias and field_info.serialization_alias is not None:
serialized_name = field_info.serialization_alias
fields.append(pa.field(serialized_name, pa_field, nullable=nullable))

if as_schema:
return pa.schema(fields)
Expand Down
97 changes: 96 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import pyarrow as pa # type: ignore
import pyarrow.parquet as pq # type: ignore
import pydantic
import pytest
from annotated_types import Gt
from pydantic import BaseModel, Field
from packaging import version
from pydantic import BaseModel, ConfigDict, Field
from pydantic.types import (
AwareDatetime,
NaiveDatetime,
Expand Down Expand Up @@ -588,3 +590,96 @@ class DictModel(BaseModel):

# pyarrow converts to tuples, need to convert back to dicts
assert objs == [{"foo": dict(t["foo"])} for t in new_objs]


def test_alias() -> None:
    """Check schema column names for fields that declare aliases explicitly.

    Without ``by_alias`` every column is expected to carry the attribute
    name. With ``by_alias=True`` the expected column is the serialization
    name: a plain ``alias`` is used (field2, field6), an explicit
    ``serialization_alias`` takes precedence over ``alias`` (field5), and a
    ``validation_alias`` on its own leaves the name untouched (field3).
    """

    class AliasModel(BaseModel):
        field1: str
        field2: str = Field(alias="b2")
        field3: str = Field(validation_alias="b3")
        field4: str = Field(serialization_alias="b4")
        field5: str = Field(alias="b5", serialization_alias="c5")
        field6: Annotated[str, Field(alias="b6")]

    def _string_schema(*column_names: str) -> pa.Schema:
        # Every field in this model is a required ``str``, so only the
        # column names differ between the two expectations.
        columns = [
            pa.field(column, pa.string(), nullable=False)
            for column in column_names
        ]
        return pa.schema(columns)

    # Default call: attribute names are used verbatim.
    schema_without_alias = get_pyarrow_schema(AliasModel)
    assert schema_without_alias == _string_schema(
        "field1", "field2", "field3", "field4", "field5", "field6"
    )

    # by_alias=True: serialization aliases replace the attribute names.
    schema_with_alias = get_pyarrow_schema(AliasModel, by_alias=True)
    assert schema_with_alias == _string_schema(
        "field1", "b2", "field3", "b4", "c5", "b6"
    )


def test_alias_generator() -> None:
    """Check schema column names when the model configures an alias generator.

    The generator upper-cases attribute names. Without ``by_alias`` the
    columns are expected to keep the attribute names. With ``by_alias=True``
    explicit aliases are expected to win over the generator (field2, field4,
    field5), while the generated name is expected for un-aliased fields and
    for a field whose alias has ``alias_priority=1`` (field1, field6).
    The expectation for field3 depends on the installed pydantic version.
    """

    class AliasModel(BaseModel):
        model_config = ConfigDict(alias_generator=lambda field_name: field_name.upper())
        field1: str
        field2: str = Field(alias="b2")
        field3: str = Field(validation_alias="b3")
        field4: str = Field(serialization_alias="b4")
        field5: str = Field(alias="b5", serialization_alias="c5")
        field6: str = Field(alias="b6", alias_priority=1)

    def _string_schema(*column_names: str) -> pa.Schema:
        # All fields are required ``str`` values; only the names vary.
        columns = [
            pa.field(column, pa.string(), nullable=False)
            for column in column_names
        ]
        return pa.schema(columns)

    # Default call: attribute names are used verbatim.
    schema_without_alias = get_pyarrow_schema(AliasModel)
    assert schema_without_alias == _string_schema(
        "field1", "field2", "field3", "field4", "field5", "field6"
    )

    # Before pydantic 2.5.0, setting validation_alias stopped the alias
    # generator from being applied to the serialization alias; 2.5.0 fixed
    # that. The library mirrors whichever pydantic version is installed, so
    # the expected name for field3 is chosen from the running version.
    installed = version.parse(pydantic.__version__)
    generator_reaches_field3 = not installed < version.parse("2.5.0")
    field3_column = "FIELD3" if generator_reaches_field3 else "field3"

    schema_with_alias = get_pyarrow_schema(AliasModel, by_alias=True)
    assert schema_with_alias == _string_schema(
        "FIELD1", "b2", field3_column, "b4", "c5", "FIELD6"
    )

0 comments on commit 19bfcbb

Please sign in to comment.