diff --git a/poetry.lock b/poetry.lock index 8b939fe..a283e1b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -307,13 +307,13 @@ files = [ [[package]] name = "packaging" -version = "23.2" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -615,4 +615,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "3f87dc9aecac896fba50c12317c100e946a0a00f350fbcd87532342ec8fd403a" +content-hash = "7c309b88fc2cf6ad279c7b857ad8210d8eed394e2ce719fb4ebff47253da3785" diff --git a/pyproject.toml b/pyproject.toml index 1d948e1..8630b6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ mypy = "^1.6.0" [tool.poetry.group.test.dependencies] pytest = "^7.4.2" coverage = "^7.3.2" +packaging = "^24.1" [tool.mypy] diff --git a/src/pydantic_to_pyarrow/schema.py b/src/pydantic_to_pyarrow/schema.py index 884ebf4..81e91d2 100644 --- a/src/pydantic_to_pyarrow/schema.py +++ b/src/pydantic_to_pyarrow/schema.py @@ -236,7 +236,10 @@ def _get_pyarrow_schema( f"Error processing field {name}: {field_type}, {err}" ) from err - fields.append(pa.field(name, pa_field, nullable=nullable)) + serialized_name = name + if settings.by_alias and field_info.serialization_alias is not None: + serialized_name = field_info.serialization_alias + fields.append(pa.field(serialized_name, pa_field, nullable=nullable)) if as_schema: return pa.schema(fields) diff --git a/tests/test_schema.py b/tests/test_schema.py index b8044ef..3729eab 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -7,9 +7,11 @@ import pyarrow as pa # type: ignore import pyarrow.parquet as pq # type: ignore +import pydantic import pytest from annotated_types import Gt -from pydantic import BaseModel, Field +from packaging import version +from pydantic import BaseModel, ConfigDict, Field from pydantic.types import ( AwareDatetime, NaiveDatetime, @@ -588,3 +590,96 @@ class DictModel(BaseModel): # pyarrow converts to tuples, need to convert back to dicts assert objs == [{"foo": dict(t["foo"])} for t in new_objs] + + +def test_alias() -> None: + class AliasModel(BaseModel): + field1: str + field2: str = Field(alias="b2") + field3: str = Field(validation_alias="b3") + field4: str = Field(serialization_alias="b4") + field5: str = Field(alias="b5", serialization_alias="c5") + field6: Annotated[str, Field(alias="b6")] + + expected_no_alias = pa.schema( + [ + pa.field("field1", pa.string(), nullable=False), + pa.field("field2", pa.string(), nullable=False), + pa.field("field3", pa.string(), nullable=False), + pa.field("field4", pa.string(), nullable=False), + pa.field("field5", pa.string(), nullable=False), + pa.field("field6", pa.string(), nullable=False), + ] + ) + + actual_no_alias = get_pyarrow_schema(AliasModel) + assert actual_no_alias == expected_no_alias + + expected_by_alias = pa.schema( + [ + pa.field("field1", pa.string(), nullable=False), + pa.field("b2", pa.string(), nullable=False), + pa.field("field3", pa.string(), nullable=False), + pa.field("b4", pa.string(), nullable=False), + pa.field("c5", pa.string(), nullable=False), + pa.field("b6", pa.string(), nullable=False), + ] + ) + actual_by_alias = get_pyarrow_schema(AliasModel, by_alias=True) + assert actual_by_alias == expected_by_alias + + +def test_alias_generator() -> None: + class AliasModel(BaseModel): + model_config = ConfigDict(alias_generator=lambda field_name: field_name.upper()) + field1: str + field2: str = Field(alias="b2") + field3: str = Field(validation_alias="b3") + field4: str = Field(serialization_alias="b4") + field5: str = Field(alias="b5", serialization_alias="c5") + field6: str = Field(alias="b6", alias_priority=1) + + expected_no_alias = pa.schema( + [ + pa.field("field1", pa.string(), nullable=False), + pa.field("field2", pa.string(), nullable=False), + pa.field("field3", pa.string(), nullable=False), + pa.field("field4", pa.string(), nullable=False), + pa.field("field5", pa.string(), nullable=False), + pa.field("field6", pa.string(), nullable=False), + ] + ) + + actual_no_alias = get_pyarrow_schema(AliasModel) + assert actual_no_alias == expected_no_alias + + expected_by_alias = pa.schema( + [ + pa.field("FIELD1", pa.string(), nullable=False), + pa.field("b2", pa.string(), nullable=False), + pa.field("FIELD3", pa.string(), nullable=False), + pa.field("b4", pa.string(), nullable=False), + pa.field("c5", pa.string(), nullable=False), + pa.field("FIELD6", pa.string(), nullable=False), + ] + ) + + pydantic_version = version.parse(pydantic.__version__) + if pydantic_version < version.parse("2.5.0"): + # pydantic 2.5.0 fixed an issue / bug, that setting validation_alias + # would then remove the alias_generator from the serialization_alias + # This library follows the functionality in the installed version + # of pydantic. + expected_by_alias = pa.schema( + [ + pa.field("FIELD1", pa.string(), nullable=False), + pa.field("b2", pa.string(), nullable=False), + pa.field("field3", pa.string(), nullable=False), + pa.field("b4", pa.string(), nullable=False), + pa.field("c5", pa.string(), nullable=False), + pa.field("FIELD6", pa.string(), nullable=False), + ] + ) + + actual_by_alias = get_pyarrow_schema(AliasModel, by_alias=True) + assert actual_by_alias == expected_by_alias