Skip to content

Commit

Permalink
Fix inferring number of rows per chunk (#327)
Browse files Browse the repository at this point in the history
We use the `_rows_per_chunk` attribute to determine the number of rows
that are included in each Parquet chunk sent to the frontend. This is a
_layer-level_ construct because we need to ensure the main table and all
accessors have exactly the same chunking, because each chunk is rendered
independently as a separate deck.gl layer

We previously had issues where the number of rows per chunk was either
not the same across all data objects within a layer, or alternatively a
few cases (as with the ArcLayer) where we were accidentally initializing
the number of rows per chunk to be `0`, which let to an infinite loop in
`table.to_batches(max_chunksize=0)`.
  • Loading branch information
kylebarron authored Jan 26, 2024
1 parent 16d4c3c commit 7ab2d93
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 22 deletions.
44 changes: 28 additions & 16 deletions lonboard/_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


class BaseLayer(BaseWidget):
# Note: these are **not** serialized to JS
# Note: these class attributes are **not** serialized to JS
_bbox = Bbox()
_weighted_centroid = WeightedCentroid()

Expand Down Expand Up @@ -107,9 +107,6 @@ class BaseLayer(BaseWidget):
for an example.
"""

_rows_per_chunk = traitlets.Int()
"""Number of rows per chunk for serializing table and accessor columns."""


def default_geoarrow_viewport(
table: pa.Table
Expand Down Expand Up @@ -148,19 +145,36 @@ def default_geoarrow_viewport(


class BaseArrowLayer(BaseLayer):
"""Any Arrow-based layer should subclass from BaseArrowLayer"""

# Note: these class attributes are **not** serialized to JS

# Number of rows per chunk for serializing table and accessor columns.
#
# This is a _layer-level_ construct because we need to ensure the main table and all
# accessors have exactly the same chunking, because each chunk is rendered
# independently as a separate deck.gl layer
_rows_per_chunk: int

# The following traitlets **are** serialized to JS

table: traitlets.TraitType

def __init__(self, *, table: pa.Table, **kwargs):
def __init__(
self, *, table: pa.Table, _rows_per_chunk: Optional[int] = None, **kwargs
):
default_viewport = default_geoarrow_viewport(table)
if default_viewport is not None:
self._bbox = default_viewport[0]
self._weighted_centroid = default_viewport[1]

super().__init__(table=table, **kwargs)
rows_per_chunk = _rows_per_chunk or infer_rows_per_chunk(table)
if rows_per_chunk <= 0:
raise ValueError("Cannot serialize table with 0 rows per chunk.")

self._rows_per_chunk = rows_per_chunk

@traitlets.default("_rows_per_chunk")
def _default_rows_per_chunk(self):
return infer_rows_per_chunk(self.table)
super().__init__(table=table, **kwargs)

@classmethod
def from_geopandas(
Expand Down Expand Up @@ -1034,14 +1048,12 @@ class HeatmapLayer(BaseArrowLayer):
"""

_layer_type = traitlets.Unicode("heatmap").tag(sync=True)
def __init__(self, *args, table: pa.Table, **kwargs):
# NOTE: we override the default for _rows_per_chunk because otherwise we render
# one heatmap per _chunk_ not for the entire dataset.
super().__init__(*args, table=table, _rows_per_chunk=len(self.table), **kwargs)

# NOTE: we override the default for _rows_per_chunk because otherwise we render one
# heatmap per _chunk_ not for the entire dataset.
# TODO: on the JS side, rechunk the table into a single contiguous chunk.
@traitlets.default("_rows_per_chunk")
def _default_rows_per_chunk(self):
return len(self.table)
_layer_type = traitlets.Unicode("heatmap").tag(sync=True)

table = PyarrowTableTrait(allowed_geometry_types={EXTENSION_NAME.POINT})
"""A GeoArrow table with a Point column.
Expand Down
6 changes: 3 additions & 3 deletions lonboard/experimental/_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import traitlets

from lonboard._constants import EXTENSION_NAME
from lonboard._layer import BaseLayer
from lonboard._layer import BaseArrowLayer
from lonboard.experimental.traits import PointAccessor
from lonboard.traits import (
ColorAccessor,
Expand All @@ -14,7 +14,7 @@
)


class ArcLayer(BaseLayer):
class ArcLayer(BaseArrowLayer):
"""Render raised arcs joining pairs of source and target coordinates."""

_layer_type = traitlets.Unicode("arc").tag(sync=True)
Expand Down Expand Up @@ -135,7 +135,7 @@ def _validate_accessor_length(self, proposal):
return proposal["value"]


class TextLayer(BaseLayer):
class TextLayer(BaseArrowLayer):
"""Render text labels at given coordinates."""

_layer_type = traitlets.Unicode("text").tag(sync=True)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import pandas as pd
import pyarrow as pa
import pytest
from ipywidgets import Widget
from traitlets import TraitError

from lonboard._layer import BaseLayer
from lonboard.traits import ColorAccessor, FloatAccessor


class ColorAccessorWidget(Widget):
class ColorAccessorWidget(BaseLayer):
_rows_per_chunk = 2

color = ColorAccessor()
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_color_accessor_validation_string():
ColorAccessorWidget(color="#ff")


class FloatAccessorWidget(Widget):
class FloatAccessorWidget(BaseLayer):
_rows_per_chunk = 2

value = FloatAccessor()
Expand Down

0 comments on commit 7ab2d93

Please sign in to comment.