Skip to content

Commit

Permalink
feat: Added PydanticPintQuantity as an option to enforce unit valid…
Browse files Browse the repository at this point in the history
…ation for fields (#56)
  • Loading branch information
pesap authored Jan 15, 2025
1 parent 420e7d0 commit 7d7fbbf
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "infrasys"
version = "0.2.2"
version = "0.2.3"
description = ''
readme = "README.md"
requires-python = ">=3.10, <3.13"
Expand Down
200 changes: 200 additions & 0 deletions src/infrasys/pint_quantities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""Defines the Pydantic `pint.Quantity`."""

from __future__ import annotations

from numbers import Number
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler

import pint
from pint.facets.plain.quantity import PlainQuantity as Quantity
from pydantic_core import core_schema


class PydanticPintQuantity:
"""Pydantic-compatible annotation for validating and serializing `pint.Quantity` fields.
This class allows Pydantic to handle fields that represent quantities with units,
leveraging the `pint` library for unit conversion and validation.
Parameters
----------
units : str
The base units of the Pydantic field. All input units must be convertible
to these base units.
ureg : pint.UnitRegistry, optional
A custom Pint unit registry. If not provided, the default registry is used.
ureg_contexts : str or list of str, optional
A custom Pint context (or a list of contexts) for the default unit registry.
All contexts are applied during validation and conversion.
ser_mode : {"str", "dict"}, optional
The mode for serializing the field. Can be one of:
- `"str"`: Serialize to a string representation of the quantity (default in JSON mode).
- `"dict"`: Serialize to a dictionary representation.
By default, fields are serialized in Pydantic's `"python"` mode, which preserves
the `pint.Quantity` type. In `"json"` mode, the field is serialized as a string.
strict : bool, optional
If `True` (default), forces users to specify units. If `False`, a value without
units (provided by the user) is treated as having the base units of the field.
Notes
-----
This class integrates with Pydantic's validation and serialization system to ensure
that fields representing physical quantities are handled correctly with respect to units.
"""

def __init__(
self,
units: str,
*,
ureg: pint.UnitRegistry | None = None,
ser_mode: Literal["str", "dict"] | None = None,
strict: bool = True,
):
self.ser_mode = ser_mode.lower() if ser_mode else None
self.strict = strict
self.ureg = ureg if ureg else pint.UnitRegistry()
self.units = self.ureg(units)

def validate(
self,
input_value: Any,
info: core_schema.ValidationInfo | None = None,
) -> Quantity:
"""Validate a `PydanticPintQuantity`.
Parameters
----------
input_value : Any
The quantity to validate. This can be a dictionary containing keys `"magnitude"`
and `"units"`, a string representing the quantity, or a `Number` or `Quantity`
object that can be validated and converted to a `pint.Quantity`.
info : core_schema.ValidationInfo, optional
Additional validation information provided by the Pydantic schema. Default is `None`.
Returns
-------
pint.Quantity
The validated `pint.Quantity` with the correct units.
Raises
------
ValueError
If validation fails due to one of the following reasons:
- The provided `dict` does not contain the required `"magnitude"` and `"units"` keys.
- No units are provided when strict mode is enabled.
- The provided units cannot be converted to the base units.
- An unknown unit is provided.
- An invalid type is provided for the value.
TypeError
If the type is not supported.
"""
# NOTE: `self.ureg` when passed returns the right type
if not isinstance(input_value, Quantity):
input_value = self.ureg(input_value) # This convert string to numbers

if isinstance(input_value, Number | list):
input_value = input_value * self.units

# At this point `input_value` should be a `pint.Quantity`.
if not isinstance(input_value, Quantity):
msg = f"{type(input_value)} not supported"
raise TypeError(msg)
try:
input_value = input_value.to(self.units)
except pint.DimensionalityError:
msg = f"Dimension mismatch from {input_value.units} to {self.units}"
raise ValueError(msg)
return input_value

def serialize(
self,
value: Quantity,
info: core_schema.SerializationInfo | None = None,
) -> dict[str, Any] | str | Quantity:
"""
Serialize a `PydanticPintQuantity`.
Parameters
----------
value : pint.Quantity
The quantity to serialize. This should be a `pint.Quantity` object.
info : core_schema.SerializationInfo, optional
The serialization information provided by the Pydantic schema. Default is `None`.
Returns
-------
dict, str, or pint.Quantity
The serialized representation of the quantity.
- If `ser_mode='dict'` or `info.mode='dict'` a dictionary with magnitude and units.
Notes
-----
This method is useful when working with `PydanticPintQuantity` fields outside
of Pydantic models, as it allows control over the serialization format
(e.g., JSON-compatible representation).
"""
if info is not None:
mode = info.mode
else:
mode = self.ser_mode

if mode == "dict":
return {
"magnitude": value.magnitude,
"units": f"{value.units}",
}
elif mode == "str" or mode == "json":
return str(value)
else:
return value

def __get_pydantic_core_schema__(
self,
source_type: Any,
handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
_from_typedict_schema = {
"magnitude": core_schema.typed_dict_field(
core_schema.str_schema(coerce_numbers_to_str=True)
),
"units": core_schema.typed_dict_field(core_schema.str_schema()),
}

validate_schema = core_schema.chain_schema(
[
core_schema.union_schema(
[
core_schema.is_instance_schema(Quantity),
core_schema.str_schema(coerce_numbers_to_str=True),
core_schema.typed_dict_schema(_from_typedict_schema),
]
),
core_schema.with_info_plain_validator_function(self.validate),
]
)

validate_json_schema = core_schema.chain_schema(
[
core_schema.union_schema(
[
core_schema.str_schema(coerce_numbers_to_str=True),
core_schema.typed_dict_schema(_from_typedict_schema),
]
),
core_schema.no_info_plain_validator_function(self.validate),
]
)

serialize_schema = core_schema.plain_serializer_function_ser_schema(
self.serialize,
info_arg=True,
)

return core_schema.json_or_python_schema(
json_schema=validate_json_schema,
python_schema=validate_schema,
serialization=serialize_schema,
)
92 changes: 92 additions & 0 deletions tests/test_pint_quantities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
from typing import Annotated

from pydantic import ValidationError, Field
from infrasys.base_quantity import ureg
from infrasys.component import Component
from infrasys.pint_quantities import PydanticPintQuantity
from infrasys.quantities import Voltage
from pint import Quantity


class PintQuantityStrict(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts")]


class PintQuantityNoStrict(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts", strict=False)]


class PintQuantityStrictDict(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts", ser_mode="dict")]


class PintQuantityStrictDictPositive(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts", ser_mode="dict"), Field(gt=0)]


@pytest.mark.parametrize(
"input_value",
[10.0 * ureg.volts, Quantity(10.0, "volt"), Voltage(10.0, "volts")],
ids=["float", "Quantity", "BaseQuantity"],
)
def test_pydantic_pint_multiple_input(input_value):
component = PintQuantityStrict(name="TestComponent", voltage=input_value)
assert isinstance(component.voltage, Quantity)
assert component.voltage.magnitude == 10.0
assert component.voltage.units == "volt"


def test_pydantic_pint_validation():
with pytest.raises(ValidationError):
_ = PintQuantityStrict(name="test", voltage=10.0 * ureg.meter)

# Pass wrong type
with pytest.raises(ValidationError):
_ = PintQuantityStrict(name="test", voltage={10: 2})


def test_compatibility_with_base_quantity():
voltage = Voltage(10.0, "volts")
component = PintQuantityStrict(name="TestComponent", voltage=voltage)
assert isinstance(component.voltage, Quantity)
assert isinstance(component.voltage, Voltage)
assert component.voltage.magnitude == 10.0
assert component.voltage.units == "volt"


def test_pydantic_pint_arguments():
# Single float should work
component = PintQuantityNoStrict(name="TestComponent", voltage=10.0)
assert isinstance(component.voltage, Quantity)
assert component.voltage.magnitude == 10.0
assert component.voltage.units == "volt"

with pytest.raises(ValidationError):
_ = PintQuantityStrictDictPositive(name="TestComponent", voltage=-10)


def test_serialization():
component = PintQuantityStrict(name="TestComponent", voltage=10.0 * ureg.volts)
component_serialized = component.model_dump()
assert isinstance(component_serialized["voltage"], Quantity)
assert component_serialized["voltage"].magnitude == 10.0
assert component_serialized["voltage"].units == "volt"

component_json = component.model_dump(mode="json")
assert component_json["voltage"] == "10.0 volt"

component_dict = component.model_dump(mode="dict")
assert isinstance(component_dict["voltage"], dict)
assert component_dict["voltage"].get("magnitude", False)
assert component_dict["voltage"].get("units", False)
assert component_dict["voltage"]["magnitude"] == 10.0
assert str(component_dict["voltage"]["units"]) == "volt"

component = PintQuantityStrict(name="TestComponent", voltage=10.0 * ureg.volts)
component_json = component.model_dump(mode="json")
assert isinstance(component_dict["voltage"], dict)
assert component_dict["voltage"].get("magnitude", False)
assert component_dict["voltage"].get("units", False)
assert component_dict["voltage"]["magnitude"] == 10.0
assert component_dict["voltage"]["units"] == "volt"

0 comments on commit 7d7fbbf

Please sign in to comment.