Skip to content

Commit

Permalink
Validate that accessors have same length as table (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Oct 10, 2023
1 parent e653254 commit 7a9710f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 7 deletions.
80 changes: 76 additions & 4 deletions lonboard/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

import geopandas as gpd
import pyarrow as pa
import traitlets
from anywidget import AnyWidget

Expand Down Expand Up @@ -70,10 +71,45 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> ScatterplotLayer:
table = geopandas_to_geoarrow(gdf)
return cls(table=table, **kwargs)

# TODO: validate that arrays have alignable-dimensions (e.g. length) with main table
# @validate("get_radius")
# def _validate_get_radius_length(self, proposal):
# if proposal["value"]
@traitlets.validate("get_radius")
def _validate_get_radius_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_radius` must have same length as table"
)

return proposal["value"]

@traitlets.validate("get_fill_color")
def _validate_get_fill_color_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_fill_color` must have same length as table"
)

return proposal["value"]

@traitlets.validate("get_line_color")
def _validate_get_line_color_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_line_color` must have same length as table"
)

return proposal["value"]

@traitlets.validate("get_line_width")
def _validate_get_line_width_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_line_width` must have same length as table"
)

return proposal["value"]


class PathLayer(BaseLayer):
Expand All @@ -98,6 +134,18 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> PathLayer:
table = geopandas_to_geoarrow(gdf)
return cls(table=table, **kwargs)

@traitlets.validate("get_color")
def _validate_get_color_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError("`get_color` must have same length as table")

@traitlets.validate("get_width")
def _validate_get_width_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError("`get_width` must have same length as table")


class SolidPolygonLayer(BaseLayer):
_esm = bundler_output_dir / "solid-polygon-layer.js"
Expand All @@ -117,3 +165,27 @@ class SolidPolygonLayer(BaseLayer):
def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> SolidPolygonLayer:
table = geopandas_to_geoarrow(gdf)
return cls(table=table, **kwargs)

@traitlets.validate("get_elevation")
def _validate_get_elevation_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_elevation` must have same length as table"
)

@traitlets.validate("get_fill_color")
def _validate_get_fill_color_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_fill_color` must have same length as table"
)

@traitlets.validate("get_line_color")
def _validate_get_line_color_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
if len(proposal["value"]) != len(self.table):
raise traitlets.TraitError(
"`get_line_color` must have same length as table"
)
5 changes: 2 additions & 3 deletions lonboard/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import pyarrow as pa
import traitlets
from numpy.typing import NDArray
from traitlets.traitlets import TraitType
from typing_extensions import Self

Expand Down Expand Up @@ -92,7 +91,7 @@ def __init__(
# TODO: subclass self.error so that `info` is actually printed?
def validate(
self, obj, value
) -> Union[Tuple[int, ...], List[int], pa.FixedSizeListArray]:
) -> Union[Tuple[int, ...], List[int], pa.ChunkedArray, pa.FixedSizeListArray]:
if isinstance(value, (tuple, list)):
if len(value) < 3 or len(value) > 4:
self.error(
Expand Down Expand Up @@ -180,7 +179,7 @@ def __init__(
self.tag(sync=True, **FLOAT_SERIALIZATION)

# TODO: subclass self.error so that `info` is actually printed?
def validate(self, obj, value) -> Union[float, NDArray[np.float32]]:
def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]:
if isinstance(value, (int, float)):
return float(value)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import geopandas as gpd
import numpy as np
import pytest
import shapely
from traitlets import TraitError

from lonboard.layer import ScatterplotLayer


def test_accessor_length_validation():
"""Accessor length must match table length"""
points = shapely.points([1, 2], [3, 4])
gdf = gpd.GeoDataFrame(geometry=points)

with pytest.raises(TraitError):
_layer = ScatterplotLayer.from_geopandas(gdf, get_radius=np.array([1]))

with pytest.raises(TraitError):
_layer = ScatterplotLayer.from_geopandas(gdf, get_radius=np.array([1, 2, 3]))

_layer = ScatterplotLayer.from_geopandas(gdf, get_radius=np.array([1, 2]))

0 comments on commit 7a9710f

Please sign in to comment.