diff --git a/lonboard/traits.py b/lonboard/traits.py index a07f3995..a6cc4cb7 100644 --- a/lonboard/traits.py +++ b/lonboard/traits.py @@ -186,7 +186,6 @@ def __init__( super().__init__(*args, **kwargs) self.tag(sync=True, **COLOR_SERIALIZATION) - # TODO: subclass self.error so that `info` is actually printed? def validate( self, obj, value ) -> Union[Tuple[int, ...], List[int], pa.ChunkedArray, pa.FixedSizeListArray]: @@ -279,6 +278,9 @@ class FloatAccessor(FixedErrorTraitType): - A numpy `ndarray` with a numeric data type. This will be casted to an array of data type [`np.float32`][numpy.float32]. Each value in the array will be used as the value for the object at the same row index. + - A pandas `Series` with a numeric data type. This will be casted to an array of + data type [`np.float32`][numpy.float32]. Each value in the array will be used as + the value for the object at the same row index. - A pyarrow [`FloatArray`][pyarrow.FloatArray], [`DoubleArray`][pyarrow.DoubleArray] or [`ChunkedArray`][pyarrow.ChunkedArray] containing either a `FloatArray` or `DoubleArray`. Each value in the array will be used as the value for the object at @@ -299,11 +301,18 @@ def __init__( super().__init__(*args, **kwargs) self.tag(sync=True, **FLOAT_SERIALIZATION) - # TODO: subclass self.error so that `info` is actually printed? def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]: if isinstance(value, (int, float)): return float(value) + # pandas Series + if ( + value.__class__.__module__.startswith("pandas") + and value.__class__.__name__ == "Series" + ): + # Cast pandas Series to numpy ndarray + value = np.asarray(value) + if isinstance(value, np.ndarray): if not np.issubdtype(value.dtype, np.number): self.error(obj, value, info="numeric dtype") diff --git a/tests/test_traits.py b/tests/test_traits.py index 1d4151dd..6845fb9a 100644 --- a/tests/test_traits.py +++ b/tests/test_traits.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pyarrow as pa import pytest from ipywidgets import Widget @@ -133,6 +134,9 @@ def test_float_accessor_validation_type(): FloatAccessorWidget(value=np.array([2, 3, 4])) FloatAccessorWidget(value=np.array([2, 3, 4], dtype=np.float32)) FloatAccessorWidget(value=np.array([2, 3, 4], dtype=np.float64)) + FloatAccessorWidget(value=pd.Series([2, 3, 4])) + FloatAccessorWidget(value=pd.Series([2, 3, 4], dtype=np.float32)) + FloatAccessorWidget(value=pd.Series([2, 3, 4], dtype=np.float64)) # Must be floating-point array type with pytest.raises(TraitError):