Skip to content

Commit

Permalink
feature: if pydantic field info has exclude=True, exclude it from the…
Browse files Browse the repository at this point in the history
… schema (#14)

* if field is set to "exclude", and exclude_fields flag is true, don't serialize it
* add tests for exclude fields
  • Loading branch information
mae5357 authored May 23, 2024
1 parent fe2cc1a commit 0e84490
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 14 deletions.
80 changes: 66 additions & 14 deletions src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def _get_decimal_type(metadata: List[Any]) -> pa.DataType:


def _get_literal_type(
field_type: Type[Any], _metadata: List[Any], _allow_losing_tz: bool
field_type: Type[Any],
_metadata: List[Any],
_allow_losing_tz: bool,
_exclude_fields: bool,
) -> pa.DataType:
values = get_args(field_type)
if all(isinstance(value, str) for value in values):
Expand All @@ -89,17 +92,25 @@ def _get_literal_type(


def _get_list_type(
field_type: Type[Any], metadata: List[Any], allow_losing_tz: bool
field_type: Type[Any],
metadata: List[Any],
allow_losing_tz: bool,
_exclude_fields: bool,
) -> pa.DataType:
sub_type = get_args(field_type)[0]
if _is_optional(sub_type):
# pyarrow lists can have null elements in them
sub_type = list(set(get_args(sub_type)) - {type(None)})[0]
return pa.list_(_get_pyarrow_type(sub_type, metadata, allow_losing_tz))
return pa.list_(
_get_pyarrow_type(sub_type, metadata, allow_losing_tz, _exclude_fields)
)


def _get_annotated_type(
field_type: Type[Any], metadata: List[Any], allow_losing_tz: bool
field_type: Type[Any],
metadata: List[Any],
allow_losing_tz: bool,
exclude_fields: bool,
) -> pa.DataType:
# TODO: fix / clean up / understand why / if this works in all cases
args = get_args(field_type)[1:]
Expand All @@ -108,16 +119,29 @@ def _get_annotated_type(
]
metadata = [item for sublist in metadatas for item in sublist]
field_type = cast(Type[Any], get_args(field_type)[0])
return _get_pyarrow_type(field_type, metadata, allow_losing_tz)
return _get_pyarrow_type(field_type, metadata, allow_losing_tz, exclude_fields)


def _get_dict_type(
field_type: Type[Any], metadata: List[Any], allow_losing_tz: bool
field_type: Type[Any],
metadata: List[Any],
allow_losing_tz: bool,
_exclude_fields: bool,
) -> pa.DataType:
key_type, value_type = get_args(field_type)
return pa.map_(
_get_pyarrow_type(key_type, metadata, allow_losing_tz=allow_losing_tz),
_get_pyarrow_type(value_type, metadata, allow_losing_tz=allow_losing_tz),
_get_pyarrow_type(
key_type,
metadata,
allow_losing_tz=allow_losing_tz,
exclude_fields=_exclude_fields,
),
_get_pyarrow_type(
value_type,
metadata,
allow_losing_tz=allow_losing_tz,
exclude_fields=_exclude_fields,
),
)


Expand Down Expand Up @@ -154,7 +178,10 @@ def _is_optional(field_type: Type[Any]) -> bool:


def _get_pyarrow_type(
field_type: Type[Any], metadata: List[Any], allow_losing_tz: bool
field_type: Type[Any],
metadata: List[Any],
allow_losing_tz: bool,
exclude_fields: bool,
) -> pa.DataType:
if field_type in FIELD_MAP:
return FIELD_MAP[field_type]
Expand All @@ -175,24 +202,29 @@ def _get_pyarrow_type(

if get_origin(field_type) in FIELD_TYPES:
return FIELD_TYPES[get_origin(field_type)](
field_type, metadata, allow_losing_tz
field_type, metadata, allow_losing_tz, exclude_fields
)

# isinstance(filed_type, type) checks whether it's a class
# otherwise eg Deque[int] would casue an exception on issubclass
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
return _get_pyarrow_schema(field_type, allow_losing_tz, as_schema=False)
return _get_pyarrow_schema(
field_type, allow_losing_tz, exclude_fields, as_schema=False
)

raise SchemaCreationError(f"Unknown type: {field_type}")


def _get_pyarrow_schema(
pydantic_class: Type[BaseModelType],
allow_losing_tz: bool,
exclude_fields: bool,
as_schema: bool = True,
) -> pa.Schema:
fields = []
for name, field_info in pydantic_class.model_fields.items():
if field_info.exclude and exclude_fields:
continue
field_type = field_info.annotation
metadata = field_info.metadata

Expand All @@ -211,7 +243,10 @@ def _get_pyarrow_schema(
field_type = cast(Type[Any], types_under_union[0])

pa_field = _get_pyarrow_type(
field_type, metadata, allow_losing_tz=allow_losing_tz
field_type,
metadata,
allow_losing_tz=allow_losing_tz,
exclude_fields=exclude_fields,
)
except Exception as err: # noqa: BLE001 - ignore blind exception
raise SchemaCreationError(
Expand All @@ -226,6 +261,23 @@ def _get_pyarrow_schema(


def get_pyarrow_schema(
pydantic_class: Type[BaseModelType], allow_losing_tz: bool = False
pydantic_class: Type[BaseModelType],
allow_losing_tz: bool = False,
exclude_fields: bool = False,
) -> pa.Schema:
return _get_pyarrow_schema(pydantic_class, allow_losing_tz)
"""
Converts a Pydantic model into a PyArrow schema.
Args:
pydantic_class (Type[BaseModelType]): The Pydantic model class to convert.
allow_losing_tz (bool, optional): Whether to allow losing timezone information
when converting datetime fields. Defaults to False.
exclude_fields (bool, optional): If True, will exclude fields in the pydantic
model that have `Field(exclude=True)`. Defaults to False.
Returns:
pa.Schema: The PyArrow schema representing the Pydantic model.
"""
return _get_pyarrow_schema(
pydantic_class, allow_losing_tz, exclude_fields=exclude_fields
)
33 changes: 33 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,39 @@ class EnumModel(BaseModel):
get_pyarrow_schema(EnumModel)


def test_exclude_field_true() -> None:
class SimpleModel(BaseModel):
a: str
b: str = Field(exclude=True)

expected = pa.schema(
[
pa.field("a", pa.string(), nullable=False),
]
)

actual = get_pyarrow_schema(SimpleModel, exclude_fields=True)

assert actual == expected


def test_exclude_fields_false() -> None:
class SimpleModel(BaseModel):
a: str
b: str = Field(exclude=True)

expected = pa.schema(
[
pa.field("a", pa.string(), nullable=False),
pa.field("b", pa.string(), nullable=False),
]
)

actual = get_pyarrow_schema(SimpleModel)

assert actual == expected


def test_dict() -> None:
class DictModel(BaseModel):
foo: Dict[str, int]
Expand Down

0 comments on commit 0e84490

Please sign in to comment.