Add option to create pyarrow schema from pydantic alias (#19)
* Move settings to Settings object, for ease of adding new options; update readme

* Use serialization alias for creating pyarrow schema when called with by_alias=True, mirroring functionality in pydantic

* Set specific patch versions for old pydantic testing, to avoid having to build pydantic-core from scratch when testing

* Update README for settings
simw authored Nov 5, 2024
1 parent 8c5517b commit 99f1048
Showing 6 changed files with 166 additions and 58 deletions.
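Before the file-by-file diff, a minimal sketch of the new option in use. The `Trade` model and its alias below are invented for illustration; the call signature matches the diff that follows.

```python
from pydantic import BaseModel, Field
from pydantic_to_pyarrow import get_pyarrow_schema

class Trade(BaseModel):
    # serialization_alias is what pydantic itself uses for model_dump(by_alias=True)
    trade_id: int = Field(serialization_alias="tradeId")
    price: float

# Default: the pyarrow schema uses the python field names
print(get_pyarrow_schema(Trade).names)                 # ['trade_id', 'price']

# New in this commit: mirror pydantic and use the serialization alias
print(get_pyarrow_schema(Trade, by_alias=True).names)  # ['tradeId', 'price']
```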
8 changes: 7 additions & 1 deletion Makefile
@@ -33,9 +33,15 @@ test-dep-versions: prepare
 	poetry run pip install pyarrow==14.0.0
 	poetry run python -m pytest

-	poetry run pip install pydantic==2.0
+	poetry run pip install pydantic==2.0.3
 	poetry run python -m pytest

+	# Change in alias functionality in 2.5.0
+	poetry run pip install pydantic==2.4.2
+	poetry run python -m pytest
+
+	poetry run pip install pydantic==2.9.2
+	poetry run python -m pytest
+
 .PHONY: clean
 clean:
20 changes: 17 additions & 3 deletions README.md
@@ -4,10 +4,15 @@
 [![pypi](https://img.shields.io/pypi/v/pydantic-to-pyarrow.svg)](https://pypi.python.org/pypi/pydantic-to-pyarrow)
 [![versions](https://img.shields.io/pypi/pyversions/pydantic-to-pyarrow.svg)](https://github.com/simw/pydantic-to-pyarrow)
 [![license](https://img.shields.io/github/license/simw/pydantic-to-pyarrow.svg)](https://github.com/simw/pydantic-to-pyarrow/blob/main/LICENSE)
+[![Download Stats](https://img.shields.io/pypi/dm/pydantic-to-pyarrow)](https://pypistats.org/packages/pydantic-to-pyarrow)

 pydantic-to-pyarrow is a library for Python to help with conversion
 of pydantic models to pyarrow schemas.

+(Please note that this project is not affiliated in any way with the
+great teams at [pydantic](https://github.com/pydantic/pydantic) or
+[pyarrow](https://github.com/apache/arrow).)
+
 [pydantic](https://github.com/pydantic/pydantic) is a Python library
 for data validation, applying type hints / annotations. It enables
 the creation of easy or complex data validation rules.
@@ -27,7 +32,7 @@ processing pipeline:
 The easiest approach for steps 3 and 4 above is to let pyarrow infer
 the schema from the data. The most involved approach is to
 specify the pyarrow schema separate from the pydantic model. In the middle, many
-application could benefit from converting the pydantic model to a
+applications could benefit from converting the pydantic model to a
 pyarrow schema. This library aims to achieve that.
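For illustration, a minimal sketch of that middle-ground approach; the `MyModel` and `records` names below are invented, not taken from the README.

```python
import pyarrow as pa
from pydantic import BaseModel
from pydantic_to_pyarrow import get_pyarrow_schema

class MyModel(BaseModel):
    name: str
    value: int

records = [MyModel(name="a", value=1).model_dump()]

# Let pyarrow infer the types from the data itself...
inferred = pa.Table.from_pylist(records)

# ...or derive the schema from the pydantic model and apply it explicitly
explicit = pa.Table.from_pylist(records, schema=get_pyarrow_schema(MyModel))
```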

## Installation
@@ -68,8 +73,17 @@ Dict[..., ...] | pa.map_(pa key_type, pa value_type) |
 Enum of str | pa.dictionary(pa.int32(), pa.string()) |
 Enum of int | pa.int64() |

-If a field is marked as exclude, (`Field(exclude=True)`), then it will be excluded
-from the pyarrow schema if exclude_fields is set to True.
+## Settings

+In a model, if a field is marked with `Field(exclude=True)`, then it will be excluded
+from the pyarrow schema if `get_pyarrow_schema` is called with `exclude_fields=True` (defaults to False).
+
+If `get_pyarrow_schema` is called with `allow_losing_tz=True`, then it will allow conversion
+of timezone-aware python datetimes to non-timezone-aware pyarrow timestamps
+(defaults to False, in which case loss of timezone information will raise an exception).
+
+By default, `get_pyarrow_schema` will use the field names for the pyarrow schema fields. If
+`by_alias=True` is supplied, then the serialization_alias is used. More information about aliases is available in the [Pydantic documentation](https://docs.pydantic.dev/latest/concepts/alias/).
+
 ## An Example

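A short sketch of the settings described in the README section above. The `Event` model is invented, and the timezone behaviour assumes plain `datetime` fields fall under the tz-aware rule described there.

```python
from datetime import datetime

from pydantic import BaseModel, Field
from pydantic_to_pyarrow import get_pyarrow_schema

class Event(BaseModel):
    name: str
    created_at: datetime  # pydantic accepts tz-aware values for plain datetime
    internal_note: str = Field(exclude=True)

# exclude_fields=True drops Field(exclude=True) fields; allow_losing_tz=True
# permits mapping a possibly tz-aware datetime to a tz-naive pyarrow timestamp
schema = get_pyarrow_schema(Event, exclude_fields=True, allow_losing_tz=True)
assert "internal_note" not in schema.names

# With the defaults (both False), internal_note would be kept, and the datetime
# field would raise SchemaCreationError rather than silently dropping tz info
# (assuming datetime is among the types the library treats as tz-aware).
```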
12 changes: 6 additions & 6 deletions poetry.lock
1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ mypy = "^1.6.0"
 [tool.poetry.group.test.dependencies]
 pytest = "^7.4.2"
 coverage = "^7.3.2"
+packaging = "^24.1"


 [tool.mypy]
86 changes: 39 additions & 47 deletions src/pydantic_to_pyarrow/schema.py
@@ -2,7 +2,7 @@
 import types
 from decimal import Decimal
 from enum import EnumMeta
-from typing import Any, List, Literal, Optional, Type, TypeVar, Union, cast
+from typing import Any, List, Literal, NamedTuple, Optional, Type, TypeVar, Union, cast

 import pyarrow as pa  # type: ignore
 from annotated_types import Ge, Gt
@@ -17,6 +17,12 @@ class SchemaCreationError(Exception):
"""Error when creating pyarrow schema."""


class Settings(NamedTuple):
allow_losing_tz: bool
by_alias: bool
exclude_fields: bool


FIELD_MAP = {
str: pa.string(),
bytes: pa.binary(),
@@ -76,8 +82,7 @@ def _get_decimal_type(metadata: List[Any]) -> pa.DataType:
 def _get_literal_type(
     field_type: Type[Any],
     _metadata: List[Any],
-    _allow_losing_tz: bool,
-    _exclude_fields: bool,
+    _settings: Settings,
 ) -> pa.DataType:
     values = get_args(field_type)
     if all(isinstance(value, str) for value in values):
@@ -94,23 +99,19 @@
 def _get_list_type(
     field_type: Type[Any],
     metadata: List[Any],
-    allow_losing_tz: bool,
-    _exclude_fields: bool,
+    settings: Settings,
 ) -> pa.DataType:
     sub_type = get_args(field_type)[0]
     if _is_optional(sub_type):
         # pyarrow lists can have null elements in them
         sub_type = list(set(get_args(sub_type)) - {type(None)})[0]
-    return pa.list_(
-        _get_pyarrow_type(sub_type, metadata, allow_losing_tz, _exclude_fields)
-    )
+    return pa.list_(_get_pyarrow_type(sub_type, metadata, settings))


 def _get_annotated_type(
     field_type: Type[Any],
     metadata: List[Any],
-    allow_losing_tz: bool,
-    exclude_fields: bool,
+    settings: Settings,
 ) -> pa.DataType:
     # TODO: fix / clean up / understand why / if this works in all cases
     args = get_args(field_type)[1:]
@@ -119,29 +120,18 @@ def _get_annotated_type(
     ]
     metadata = [item for sublist in metadatas for item in sublist]
     field_type = cast(Type[Any], get_args(field_type)[0])
-    return _get_pyarrow_type(field_type, metadata, allow_losing_tz, exclude_fields)
+    return _get_pyarrow_type(field_type, metadata, settings)


 def _get_dict_type(
     field_type: Type[Any],
     metadata: List[Any],
-    allow_losing_tz: bool,
-    _exclude_fields: bool,
+    settings: Settings,
 ) -> pa.DataType:
     key_type, value_type = get_args(field_type)
     return pa.map_(
-        _get_pyarrow_type(
-            key_type,
-            metadata,
-            allow_losing_tz=allow_losing_tz,
-            exclude_fields=_exclude_fields,
-        ),
-        _get_pyarrow_type(
-            value_type,
-            metadata,
-            allow_losing_tz=allow_losing_tz,
-            exclude_fields=_exclude_fields,
-        ),
+        _get_pyarrow_type(key_type, metadata, settings),
+        _get_pyarrow_type(value_type, metadata, settings),
     )


@@ -180,16 +170,15 @@ def _is_optional(field_type: Type[Any]) -> bool:
 def _get_pyarrow_type(
     field_type: Type[Any],
     metadata: List[Any],
-    allow_losing_tz: bool,
-    exclude_fields: bool,
+    settings: Settings,
 ) -> pa.DataType:
     if field_type in FIELD_MAP:
         return FIELD_MAP[field_type]

-    if allow_losing_tz and field_type in LOSING_TZ_TYPES:
+    if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
         return LOSING_TZ_TYPES[field_type]

-    if not allow_losing_tz and field_type in LOSING_TZ_TYPES:
+    if not settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
         raise SchemaCreationError(
             f"{field_type} only allowed if ok losing timezone information"
         )
@@ -202,28 +191,27 @@ def _get_pyarrow_type(

     if get_origin(field_type) in FIELD_TYPES:
         return FIELD_TYPES[get_origin(field_type)](
-            field_type, metadata, allow_losing_tz, exclude_fields
+            field_type,
+            metadata,
+            settings,
         )

     # isinstance(field_type, type) checks whether it's a class
     # otherwise eg Deque[int] would cause an exception on issubclass
     if isinstance(field_type, type) and issubclass(field_type, BaseModel):
-        return _get_pyarrow_schema(
-            field_type, allow_losing_tz, exclude_fields, as_schema=False
-        )
+        return _get_pyarrow_schema(field_type, settings, as_schema=False)

     raise SchemaCreationError(f"Unknown type: {field_type}")


 def _get_pyarrow_schema(
     pydantic_class: Type[BaseModelType],
-    allow_losing_tz: bool,
-    exclude_fields: bool,
+    settings: Settings,
     as_schema: bool = True,
 ) -> pa.Schema:
     fields = []
     for name, field_info in pydantic_class.model_fields.items():
-        if field_info.exclude and exclude_fields:
+        if field_info.exclude and settings.exclude_fields:
             continue
         field_type = field_info.annotation
         metadata = field_info.metadata
@@ -242,18 +230,16 @@
             # mypy infers field_type as Type[Any] | None here, hence casting
             field_type = cast(Type[Any], types_under_union[0])

-            pa_field = _get_pyarrow_type(
-                field_type,
-                metadata,
-                allow_losing_tz=allow_losing_tz,
-                exclude_fields=exclude_fields,
-            )
+            pa_field = _get_pyarrow_type(field_type, metadata, settings)
         except Exception as err:  # noqa: BLE001 - ignore blind exception
             raise SchemaCreationError(
                 f"Error processing field {name}: {field_type}, {err}"
             ) from err

-        fields.append(pa.field(name, pa_field, nullable=nullable))
+        serialized_name = name
+        if settings.by_alias and field_info.serialization_alias is not None:
+            serialized_name = field_info.serialization_alias
+        fields.append(pa.field(serialized_name, pa_field, nullable=nullable))

     if as_schema:
         return pa.schema(fields)
@@ -264,20 +250,26 @@ def get_pyarrow_schema(
     pydantic_class: Type[BaseModelType],
     allow_losing_tz: bool = False,
     exclude_fields: bool = False,
+    by_alias: bool = False,
 ) -> pa.Schema:
     """
     Converts a Pydantic model into a PyArrow schema.

     Args:
         pydantic_class (Type[BaseModelType]): The Pydantic model class to convert.
         allow_losing_tz (bool, optional): Whether to allow losing timezone information
-        when converting datetime fields. Defaults to False.
+            when converting datetime fields. Defaults to False.
         exclude_fields (bool, optional): If True, will exclude fields in the pydantic
-        model that have `Field(exclude=True)`. Defaults to False.
+            model that have `Field(exclude=True)`. Defaults to False.
+        by_alias (bool, optional): If True, will create the pyarrow schema using the
+            (serialization) alias in the pydantic model. Defaults to False.

     Returns:
         pa.Schema: The PyArrow schema representing the Pydantic model.
     """
-    return _get_pyarrow_schema(
-        pydantic_class, allow_losing_tz, exclude_fields=exclude_fields
+    settings = Settings(
+        allow_losing_tz=allow_losing_tz,
+        by_alias=by_alias,
+        exclude_fields=exclude_fields,
     )
+    return _get_pyarrow_schema(pydantic_class, settings)
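As a closing sketch of the new code path: `by_alias` travels inside `Settings` through the recursive `_get_pyarrow_schema` call, so serialization aliases should apply to nested models as well. The models and aliases below are invented for illustration.

```python
from pydantic import BaseModel, Field
from pydantic_to_pyarrow import get_pyarrow_schema

class Inner(BaseModel):
    value: int = Field(serialization_alias="theValue")

class Outer(BaseModel):
    inner: Inner = Field(serialization_alias="theInner")

schema = get_pyarrow_schema(Outer, by_alias=True)
print(schema)
# Expected shape, assuming int maps to pa.int64():
#   theInner: struct<theValue: int64>
```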
