Skip to content

Commit

Permalink
Infer default view state from data (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Oct 11, 2023
1 parent aad3fa8 commit 5d5ec7b
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 7 deletions.
17 changes: 16 additions & 1 deletion lonboard/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lonboard.constants import EPSG_4326, OGC_84
from lonboard.geoarrow.geopandas_interop import geopandas_to_geoarrow
from lonboard.traits import ColorAccessor, FloatAccessor, PyarrowTableTrait
from lonboard.viewport import compute_view

# bundler yields lonboard/static/{index.js,styles.css}
bundler_output_dir = Path(__file__).parent / "static"
Expand Down Expand Up @@ -46,8 +47,8 @@ class BaseLayer(AnyWidget):

class ScatterplotLayer(BaseLayer):
_esm = bundler_output_dir / "scatterplot-layer.js"

_layer_type = traitlets.Unicode("scatterplot").tag(sync=True)
_initial_view_state = traitlets.Dict().tag(sync=True)

table = PyarrowTableTrait(allowed_geometry_types={b"geoarrow.point"})

Expand Down Expand Up @@ -77,6 +78,10 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> ScatterplotLayer:
table = geopandas_to_geoarrow(gdf)
return cls(table=table, **kwargs)

@traitlets.default("_initial_view_state")
def _default_initial_view_state(self):
    """Provide the default value of the ``_initial_view_state`` trait.

    The default is whatever ``compute_view`` returns for this layer's
    ``table``.
    """
    return compute_view(self.table)

@traitlets.validate("get_radius")
def _validate_get_radius_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
Expand Down Expand Up @@ -121,6 +126,7 @@ def _validate_get_line_width_length(self, proposal):
class PathLayer(BaseLayer):
_esm = bundler_output_dir / "path-layer.js"
_layer_type = traitlets.Unicode("path").tag(sync=True)
_initial_view_state = traitlets.Dict().tag(sync=True)

table = PyarrowTableTrait(allowed_geometry_types={b"geoarrow.linestring"})

Expand All @@ -144,6 +150,10 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> PathLayer:
table = geopandas_to_geoarrow(gdf)
return cls(table=table, **kwargs)

@traitlets.default("_initial_view_state")
def _default_initial_view_state(self):
    """Provide the default value of the ``_initial_view_state`` trait.

    The default is whatever ``compute_view`` returns for this layer's
    ``table``.
    """
    return compute_view(self.table)

@traitlets.validate("get_color")
def _validate_get_color_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
Expand All @@ -160,6 +170,7 @@ def _validate_get_width_length(self, proposal):
class SolidPolygonLayer(BaseLayer):
_esm = bundler_output_dir / "solid-polygon-layer.js"
_layer_type = traitlets.Unicode("solid-polygon").tag(sync=True)
_initial_view_state = traitlets.Dict().tag(sync=True)

table = PyarrowTableTrait(allowed_geometry_types={b"geoarrow.polygon"})

Expand All @@ -180,6 +191,10 @@ def from_geopandas(cls, gdf: gpd.GeoDataFrame, **kwargs) -> SolidPolygonLayer:
table = geopandas_to_geoarrow(gdf)
return cls(table=table, **kwargs)

@traitlets.default("_initial_view_state")
def _default_initial_view_state(self):
    """Provide the default value of the ``_initial_view_state`` trait.

    The default is whatever ``compute_view`` returns for this layer's
    ``table``.
    """
    return compute_view(self.table)

@traitlets.validate("get_elevation")
def _validate_get_elevation_length(self, proposal):
if isinstance(proposal["value"], (pa.ChunkedArray, pa.Array)):
Expand Down
24 changes: 24 additions & 0 deletions lonboard/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pyarrow as pa

# Values of the b"ARROW:extension:name" field-metadata entry that
# get_geometry_column_index accepts as identifying a geometry column.
GEOARROW_EXTENSION_TYPE_NAMES = {
    b"geoarrow.point",
    b"geoarrow.linestring",
    b"geoarrow.polygon",
    b"geoarrow.multipoint",
    b"geoarrow.multilinestring",
    b"geoarrow.multipolygon",
}


def get_geometry_column_index(schema: pa.Schema) -> int:
    """Return the position of the first geometry field in a pyarrow Schema.

    A field counts as a geometry field when its metadata maps
    ``b"ARROW:extension:name"`` to one of ``GEOARROW_EXTENSION_TYPE_NAMES``.

    Raises:
        ValueError: If no field in the schema carries such metadata.
    """
    extension_key = b"ARROW:extension:name"

    for position in range(len(schema)):
        metadata = schema.field(position).metadata
        # Fields without (or with empty) metadata cannot be geometry fields.
        if not metadata:
            continue
        if metadata.get(extension_key) in GEOARROW_EXTENSION_TYPE_NAMES:
            return position

    raise ValueError("No geometry column in table schema.")
200 changes: 200 additions & 0 deletions lonboard/viewport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""Helpers for viewport operations
This is partially derived from pydeck at
(https://github.com/visgl/deck.gl/blob/63728ecbdaa2f99811900ec3709e5df0f9f8d228/bindings/pydeck/pydeck/data_utils/viewport_helpers.py)
under the Apache 2 license.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import pyarrow as pa

from lonboard.utils import get_geometry_column_index


@dataclass
class Bbox:
    """Axis-aligned 2D bounding box that can be grown incrementally.

    The defaults describe an "empty" box (minimums at +inf, maximums at
    -inf), so the first ``update`` simply adopts the other box's extent.
    """

    minx: float = math.inf
    miny: float = math.inf
    maxx: float = -math.inf
    maxy: float = -math.inf

    def update(self, other: Bbox):
        """Expand this box in place so that it also covers ``other``."""
        # min/max keep the current value unless ``other`` is strictly
        # beyond it, which is the same rule as an explicit ``<`` / ``>`` test.
        self.minx = min(self.minx, other.minx)
        self.miny = min(self.miny, other.miny)
        self.maxx = max(self.maxx, other.maxx)
        self.maxy = max(self.maxy, other.maxy)

    def to_tuple(self) -> Tuple[float, float, float, float]:
        """Return the box as ``(minx, miny, maxx, maxy)``."""
        return self.minx, self.miny, self.maxx, self.maxy


def geo_mean_overflow(iterable):
    """Return the geometric mean of ``iterable`` computed in log space.

    Averaging the logarithms and exponentiating the result avoids forming
    the full product of the values.
    """
    log_values = np.log(iterable)
    mean_log = log_values.mean()
    return np.exp(mean_log)


def geo_mean(iterable):
    """Return the geometric mean of ``iterable`` as the n-th root of the product."""
    values = np.array(iterable)
    product = values.prod()
    exponent = 1.0 / len(values)
    return product**exponent


@dataclass
class WeightedCentroid:
    """Running mean of x and y coordinates, updated one chunk at a time."""

    # Existing average for x and y (None until the first non-empty chunk)
    x: Optional[float] = None
    y: Optional[float] = None
    # Number of coordinates that have contributed to the averages so far
    num_items: int = 0

    def update(self, coords: pa.FixedSizeListArray):
        """Update the average for x and y based on a new chunk of data

        Note that this does not keep a cumulative sum due to precision
        concerns. Rather it incrementally updates based on a delta, and never
        multiplies to large constant values.

        Note: this currently computes the mean weighted _per coordinate_ and
        not _per geometry_.

        Chunks without any coordinates are ignored; they carry no information
        and their mean would be NaN.

        Args:
            coords: Coordinates for one chunk. The flattened values are
                reshaped to ``(-1, coords.type.list_size)``; column 0 is read
                as x and column 1 as y.
        """
        np_arr = coords.flatten().to_numpy().reshape(-1, coords.type.list_size)
        new_chunk_len = np_arr.shape[0]

        if new_chunk_len == 0:
            # np.mean of an empty slice is NaN and would poison the average.
            return

        new_chunk_avg_x = np.mean(np_arr[:, 0])
        # y lives in column 1; reading column 0 here would blend the x mean
        # into the running y average.
        new_chunk_avg_y = np.mean(np_arr[:, 1])

        if self.x is None or self.y is None:
            assert self.x is None and self.y is None and self.num_items == 0
            self.x = new_chunk_avg_x
            self.y = new_chunk_avg_y
            self.num_items = new_chunk_len
            return

        # Weight each side by its share of the combined coordinate count.
        total_items = self.num_items + new_chunk_len
        existing_modifier = self.num_items / total_items
        new_chunk_modifier = new_chunk_len / total_items

        self.x = self.x * existing_modifier + new_chunk_avg_x * new_chunk_modifier
        self.y = self.y * existing_modifier + new_chunk_avg_y * new_chunk_modifier
        self.num_items = total_items


def get_bbox_center(table: pa.Table) -> Tuple[Bbox, WeightedCentroid]:
    """Get the bounding box and geometric (weighted) center of the geometries in
    the table.

    The geometry column is located with ``get_geometry_column_index`` and its
    ``b"ARROW:extension:name"`` field metadata selects the helper that matches
    the nesting depth of that geometry type.

    Args:
        table: Table whose schema contains a geometry column.

    Returns:
        A ``(Bbox, WeightedCentroid)`` pair accumulated over every chunk of
        the geometry column.

    Raises:
        ValueError: Propagated from ``get_geometry_column_index`` when the
            schema has no geometry column.
        AssertionError: If the extension name is not one of the handled
            geometry types.
    """
    geom_col_idx = get_geometry_column_index(table.schema)
    geom_col = table.column(geom_col_idx)
    extension_type_name = table.schema.field(geom_col_idx).metadata[
        b"ARROW:extension:name"
    ]

    if extension_type_name == b"geoarrow.point":
        return _get_bbox_center_nest_0(geom_col)

    if extension_type_name in (b"geoarrow.linestring", b"geoarrow.multipoint"):
        return _get_bbox_center_nest_1(geom_col)

    if extension_type_name in (b"geoarrow.polygon", b"geoarrow.multilinestring"):
        return _get_bbox_center_nest_2(geom_col)

    if extension_type_name == b"geoarrow.multipolygon":
        return _get_bbox_center_nest_3(geom_col)

    # An explicit raise (rather than ``assert False``) survives ``python -O``,
    # where assert statements are removed and the function would otherwise
    # fall through and return None.
    raise AssertionError(
        f"Unhandled geometry extension type: {extension_type_name!r}"
    )


def _coords_bbox(arr: pa.FixedSizeListArray) -> Bbox:
    """Compute the bounding box of one chunk of coordinates.

    The flattened values are reshaped to ``(-1, arr.type.list_size)``; column
    0 is read as x and column 1 as y.

    Args:
        arr: Coordinates for one chunk.

    Returns:
        The x/y extent of the chunk, or a default (empty) ``Bbox`` when the
        chunk holds no coordinates.
    """
    np_arr = arr.flatten().to_numpy().reshape(-1, arr.type.list_size)

    if np_arr.shape[0] == 0:
        # min/max over a zero-length axis raise ValueError; the default Bbox
        # (+inf minimums, -inf maximums) leaves any box it is merged into
        # unchanged.
        return Bbox()

    min_vals = np.min(np_arr, axis=0)
    max_vals = np.max(np_arr, axis=0)
    return Bbox(minx=min_vals[0], miny=min_vals[1], maxx=max_vals[0], maxy=max_vals[1])


def _get_bbox_center_nest_0(column: pa.ChunkedArray) -> Tuple[Bbox, WeightedCentroid]:
    """Accumulate bbox and centroid for a column whose chunks are coordinates.

    Each chunk is passed as-is (no flattening) to ``_coords_bbox`` and
    ``WeightedCentroid.update``.
    """
    extent, center = Bbox(), WeightedCentroid()
    for points in column.chunks:
        extent.update(_coords_bbox(points))
        center.update(points)
    return extent, center


def _get_bbox_center_nest_1(column: pa.ChunkedArray) -> Tuple[Bbox, WeightedCentroid]:
    """Accumulate bbox and centroid for a column with one level of list nesting.

    Each chunk is flattened once before being handed to ``_coords_bbox`` and
    ``WeightedCentroid.update``.
    """
    extent, center = Bbox(), WeightedCentroid()
    for batch in column.chunks:
        points = batch.flatten()
        extent.update(_coords_bbox(points))
        center.update(points)
    return extent, center


def _get_bbox_center_nest_2(column: pa.ChunkedArray) -> Tuple[Bbox, WeightedCentroid]:
    """Accumulate bbox and centroid for a column with two levels of list nesting.

    Each chunk is flattened twice before being handed to ``_coords_bbox`` and
    ``WeightedCentroid.update``.
    """
    extent, center = Bbox(), WeightedCentroid()
    for batch in column.chunks:
        inner_lists = batch.flatten()
        points = inner_lists.flatten()
        extent.update(_coords_bbox(points))
        center.update(points)
    return extent, center


def _get_bbox_center_nest_3(column: pa.ChunkedArray) -> Tuple[Bbox, WeightedCentroid]:
    """Accumulate bbox and centroid for a column with three levels of list nesting.

    Each chunk is flattened three times before being handed to
    ``_coords_bbox`` and ``WeightedCentroid.update``.
    """
    extent, center = Bbox(), WeightedCentroid()
    for batch in column.chunks:
        points = batch
        # Peel off the three list levels one at a time.
        for _ in range(3):
            points = points.flatten()
        extent.update(_coords_bbox(points))
        center.update(points)
    return extent, center


def bbox_to_zoom_level(bbox: Bbox) -> int:
    """Computes the zoom level of a bounding box

    This is copied from pydeck: https://github.com/visgl/deck.gl/blob/63728ecbdaa2f99811900ec3709e5df0f9f8d228/bindings/pydeck/pydeck/data_utils/viewport_helpers.py#L125C1-L149C22

    Returns:
        Zoom level of map in a WGS84 Mercator projection
    """
    lat_span = max(bbox.miny, bbox.maxy) - min(bbox.miny, bbox.maxy)
    lng_span = max(bbox.minx, bbox.maxx) - min(bbox.minx, bbox.maxx)
    widest_span = max(lng_span, lat_span)

    # Extents narrower than 360 / 2**20 degrees get the fixed maximum zoom.
    if widest_span < (360.0 / math.pow(2, 20)):
        return 21

    # log2(360) - log2(span), truncated toward zero by int().
    raw_zoom = int(
        -1
        * ((math.log(widest_span) / math.log(2.0)) - (math.log(360.0) / math.log(2)))
    )
    # Never report a zoom level below 1.
    return max(raw_zoom, 1)


def compute_view(table: pa.Table):
    """Automatically computes a view state for the data passed in.

    Returns:
        A dict with ``"longitude"`` and ``"latitude"`` taken from the
        centroid returned by ``get_bbox_center`` and ``"zoom"`` from
        ``bbox_to_zoom_level`` applied to the bounding box.
    """
    extent, centroid = get_bbox_center(table)
    view_state = {
        "longitude": centroid.x,
        "latitude": centroid.y,
        "zoom": bbox_to_zoom_level(extent),
    }
    return view_state
11 changes: 9 additions & 2 deletions src/path-layer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { GeoArrowPathLayer } from "@geoarrow/deck.gl-layers";
import { useParquetWasm } from "./parquet";
import { useAccessorState, useTableBufferState } from "./accessor";

const INITIAL_VIEW_STATE = {
const DEFAULT_INITIAL_VIEW_STATE = {
latitude: 10,
longitude: 0,
zoom: 0.5,
Expand All @@ -20,6 +20,7 @@ const MAP_STYLE =
function App() {
const [wasmReady] = useParquetWasm();

let [viewState] = useModelState<DataView>("_initial_view_state");
let [dataRaw] = useModelState<DataView>("table");
let [widthUnits] = useModelState("width_units");
let [widthScale] = useModelState("width_scale");
Expand Down Expand Up @@ -60,7 +61,13 @@ function App() {
return (
<div style={{ height: 500 }}>
<DeckGL
initialViewState={INITIAL_VIEW_STATE}
initialViewState={
["longitude", "latitude", "zoom"].every((key) =>
Object.keys(viewState).includes(key)
)
? viewState
: DEFAULT_INITIAL_VIEW_STATE
}
controller={true}
layers={layers}
// ContextProvider={MapContext.Provider}
Expand Down
11 changes: 9 additions & 2 deletions src/scatterplot-layer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { GeoArrowScatterplotLayer } from "@geoarrow/deck.gl-layers";
import { useParquetWasm } from "./parquet";
import { useAccessorState, useTableBufferState } from "./accessor";

const INITIAL_VIEW_STATE = {
const DEFAULT_INITIAL_VIEW_STATE = {
latitude: 10,
longitude: 0,
zoom: 0.5,
Expand All @@ -20,6 +20,7 @@ const MAP_STYLE =
function App() {
const [wasmReady] = useParquetWasm();

let [viewState] = useModelState<DataView>("_initial_view_state");
let [dataRaw] = useModelState<DataView>("table");
let [radiusUnits] = useModelState("radius_units");
let [radiusScale] = useModelState("radius_scale");
Expand Down Expand Up @@ -74,7 +75,13 @@ function App() {
return (
<div style={{ height: 500 }}>
<DeckGL
initialViewState={INITIAL_VIEW_STATE}
initialViewState={
["longitude", "latitude", "zoom"].every((key) =>
Object.keys(viewState).includes(key)
)
? viewState
: DEFAULT_INITIAL_VIEW_STATE
}
controller={true}
layers={layers}
// ContextProvider={MapContext.Provider}
Expand Down
11 changes: 9 additions & 2 deletions src/solid-polygon-layer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { GeoArrowSolidPolygonLayer } from "@geoarrow/deck.gl-layers";
import { useParquetWasm } from "./parquet";
import { useAccessorState, useTableBufferState } from "./accessor";

const INITIAL_VIEW_STATE = {
const DEFAULT_INITIAL_VIEW_STATE = {
latitude: 10,
longitude: 0,
zoom: 0.5,
Expand All @@ -20,6 +20,7 @@ const MAP_STYLE =
function App() {
const [wasmReady] = useParquetWasm();

let [viewState] = useModelState<DataView>("_initial_view_state");
let [dataRaw] = useModelState<DataView>("table");
let [filled] = useModelState("filled");
let [extruded] = useModelState("extruded");
Expand Down Expand Up @@ -54,7 +55,13 @@ function App() {
return (
<div style={{ height: 500 }}>
<DeckGL
initialViewState={INITIAL_VIEW_STATE}
initialViewState={
["longitude", "latitude", "zoom"].every((key) =>
Object.keys(viewState).includes(key)
)
? viewState
: DEFAULT_INITIAL_VIEW_STATE
}
controller={true}
layers={layers}
// ContextProvider={MapContext.Provider}
Expand Down

0 comments on commit 5d5ec7b

Please sign in to comment.