diff --git a/lonboard/layer.py b/lonboard/layer.py index ccfd9817..5daa849d 100644 --- a/lonboard/layer.py +++ b/lonboard/layer.py @@ -3,6 +3,7 @@ from pathlib import Path import geopandas as gpd +import pyarrow as pa import traitlets from anywidget import AnyWidget @@ -70,10 +71,45 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> ScatterplotLayer: table = geopandas_to_geoarrow(gdf) return cls(table=table, **kwargs) - # TODO: validate that arrays have alignable-dimensions (e.g. length) with main table - # @validate("get_radius") - # def _validate_get_radius_length(self, proposal): - # if proposal["value"] + @traitlets.validate("get_radius") + def _validate_get_radius_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_radius` must have same length as table" + ) + + return proposal["value"] + + @traitlets.validate("get_fill_color") + def _validate_get_fill_color_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_fill_color` must have same length as table" + ) + + return proposal["value"] + + @traitlets.validate("get_line_color") + def _validate_get_line_color_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_line_color` must have same length as table" + ) + + return proposal["value"] + + @traitlets.validate("get_line_width") + def _validate_get_line_width_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_line_width` must have same length as table" + ) + + return proposal["value"] class PathLayer(BaseLayer): @@ -98,6 +134,18 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> PathLayer: table = geopandas_to_geoarrow(gdf) return cls(table=table, **kwargs) + @traitlets.validate("get_color") + def _validate_get_color_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError("`get_color` must have same length as table") + + @traitlets.validate("get_width") + def _validate_get_width_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError("`get_width` must have same length as table") + class SolidPolygonLayer(BaseLayer): _esm = bundler_output_dir / "solid-polygon-layer.js" @@ -117,3 +165,27 @@ class SolidPolygonLayer(BaseLayer): def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> SolidPolygonLayer: table = geopandas_to_geoarrow(gdf) return cls(table=table, **kwargs) + + @traitlets.validate("get_elevation") + def _validate_get_elevation_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_elevation` must have same length as table" + ) + + @traitlets.validate("get_fill_color") + def _validate_get_fill_color_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_fill_color` must have same length as table" + ) + + @traitlets.validate("get_line_color") + def _validate_get_line_color_length(self, proposal): + if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)): + if len(proposal["value"]) != len(self.table): + raise traitlets.TraitError( + "`get_line_color` must have same length as table" + ) diff --git a/lonboard/traits.py b/lonboard/traits.py index f2abd57c..c0bac8e7 100644 --- a/lonboard/traits.py +++ b/lonboard/traits.py @@ -11,7 +11,6 @@ import numpy as np import pyarrow as pa import traitlets -from numpy.typing import NDArray from traitlets.traitlets import TraitType from typing_extensions import Self @@ -92,7 +91,7 @@ def __init__( # TODO: subclass self.error so that `info` is actually printed? def validate( self, obj, value - ) -> Union[Tuple[int, ...], List[int], pa.FixedSizeListArray]: + ) -> Union[Tuple[int, ...], List[int], pa.ChunkedArray, pa.FixedSizeListArray]: if isinstance(value, (tuple, list)): if len(value) < 3 or len(value) > 4: self.error( @@ -180,7 +179,7 @@ def __init__( self.tag(sync=True, **FLOAT_SERIALIZATION) # TODO: subclass self.error so that `info` is actually printed? - def validate(self, obj, value) -> Union[float, NDArray[np.float32]]: + def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]: if isinstance(value, (int, float)): return float(value) diff --git a/tests/test_layer.py b/tests/test_layer.py new file mode 100644 index 00000000..2ee40d33 --- /dev/null +++ b/tests/test_layer.py @@ -0,0 +1,21 @@ +import geopandas as gpd +import numpy as np +import pytest +import shapely +from traitlets import TraitError + +from lonboard.layer import ScatterplotLayer + + +def test_accessor_length_validation(): + """Accessor length must match table length""" + points = shapely.points([1, 2], [3, 4]) + gdf = gpd.GeoDataFrame(geometry=points) + + with pytest.raises(TraitError): + _layer = ScatterplotLayer.from_geopandas(gdf, get_radius=np.array([1])) + + with pytest.raises(TraitError): + _layer = ScatterplotLayer.from_geopandas(gdf, get_radius=np.array([1, 2, 3])) + + _layer = ScatterplotLayer.from_geopandas(gdf, get_radius=np.array([1, 2]))