Skip to content

Commit

Permalink
Allow pandas series as accessor to FloatAccessor (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Nov 6, 2023
1 parent 12a34ad commit a2b6913
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lonboard/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def __init__(
super().__init__(*args, **kwargs)
self.tag(sync=True, **COLOR_SERIALIZATION)

# TODO: subclass self.error so that `info` is actually printed?
def validate(
self, obj, value
) -> Union[Tuple[int, ...], List[int], pa.ChunkedArray, pa.FixedSizeListArray]:
Expand Down Expand Up @@ -279,6 +278,9 @@ class FloatAccessor(FixedErrorTraitType):
- A numpy `ndarray` with a numeric data type. This will be casted to an array of
data type [`np.float32`][numpy.float32]. Each value in the array will be used as
the value for the object at the same row index.
- A pandas `Series` with a numeric data type. This will be casted to an array of
data type [`np.float32`][numpy.float32]. Each value in the array will be used as
the value for the object at the same row index.
- A pyarrow [`FloatArray`][pyarrow.FloatArray], [`DoubleArray`][pyarrow.DoubleArray]
or [`ChunkedArray`][pyarrow.ChunkedArray] containing either a `FloatArray` or
`DoubleArray`. Each value in the array will be used as the value for the object at
Expand All @@ -299,11 +301,18 @@ def __init__(
super().__init__(*args, **kwargs)
self.tag(sync=True, **FLOAT_SERIALIZATION)

# TODO: subclass self.error so that `info` is actually printed?
def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]:
if isinstance(value, (int, float)):
return float(value)

# pandas Series
if (
value.__class__.__module__.startswith("pandas")
and value.__class__.__name__ == "Series"
):
# Cast pandas Series to numpy ndarray
value = np.asarray(value)

if isinstance(value, np.ndarray):
if not np.issubdtype(value.dtype, np.number):
self.error(obj, value, info="numeric dtype")
Expand Down
4 changes: 4 additions & 0 deletions tests/test_traits.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from ipywidgets import Widget
Expand Down Expand Up @@ -133,6 +134,9 @@ def test_float_accessor_validation_type():
FloatAccessorWidget(value=np.array([2, 3, 4]))
FloatAccessorWidget(value=np.array([2, 3, 4], dtype=np.float32))
FloatAccessorWidget(value=np.array([2, 3, 4], dtype=np.float64))
FloatAccessorWidget(value=pd.Series([2, 3, 4]))
FloatAccessorWidget(value=pd.Series([2, 3, 4], dtype=np.float32))
FloatAccessorWidget(value=pd.Series([2, 3, 4], dtype=np.float64))

# Must be floating-point array type
with pytest.raises(TraitError):
Expand Down

0 comments on commit a2b6913

Please sign in to comment.