diff --git a/lonboard/_constants.py b/lonboard/_constants.py
index 4899b193..cff9ed47 100644
--- a/lonboard/_constants.py
+++ b/lonboard/_constants.py
@@ -2,6 +2,14 @@
 
 import pyproj
 
+# Minimum integer representable in a float32
+# https://stackoverflow.com/a/3793950
+MIN_INTEGER_FLOAT32 = -16777216
+
+# Maximum integer representable in a float32
+# https://stackoverflow.com/a/3793950
+MAX_INTEGER_FLOAT32 = 16777216
+
 EPSG_4326 = pyproj.CRS("epsg:4326")
 
 # In pyodide, the pyproj PROJ data directory is much smaller, and it currently
diff --git a/lonboard/experimental/traits.py b/lonboard/experimental/traits.py
index ad3c86e7..79826e58 100644
--- a/lonboard/experimental/traits.py
+++ b/lonboard/experimental/traits.py
@@ -1,10 +1,22 @@
 from __future__ import annotations
 
+import math
+import warnings
 from typing import TYPE_CHECKING, Any
 
-from arro3.core import Array, ChunkedArray, DataType
+import arro3.compute as ac
+from arro3.core import (
+    Array,
+    ChunkedArray,
+    DataType,
+    Scalar,
+    list_array,
+    list_flatten,
+    list_offsets,
+)
 from traitlets.traitlets import TraitType
 
+from lonboard._constants import MAX_INTEGER_FLOAT32, MIN_INTEGER_FLOAT32
 from lonboard._serialization import ACCESSOR_SERIALIZATION
 from lonboard.traits import FixedErrorTraitType
 
@@ -15,10 +27,18 @@ class TimestampAccessor(FixedErrorTraitType):
     """A representation of a deck.gl coordinate-timestamp accessor.
 
-    - A pyarrow [`ListArray`][pyarrow.ListArray] containing either a numeric array. Each
-      value in the array will be used as the value for the object at the same row
-      index.
-    - Any Arrow list array from a library that implements the [Arrow PyCapsule
+    deck.gl handles timestamps on the GPU as float32 values. This class will validate
+    that the input timestamps are representable as float32 integers, and will
+    automatically reduce the precision of input data if necessary to fit inside a
+    float32.
+
+    Accepted input includes:
+
+    - A pyarrow [`ListArray`][pyarrow.ListArray] containing a temporal array such as a
+      TimestampArray. Each value in the array will be used as the value for the object
+      at the same row index.
+    - Any Arrow list array containing a temporal array from a library that implements
+      the [Arrow PyCapsule
       Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
     """
 
@@ -33,6 +53,117 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.tag(sync=True, **ACCESSOR_SERIALIZATION)
 
+    def reduce_precision(
+        self, obj: BaseArrowLayer, value: ChunkedArray
+    ) -> ChunkedArray:
+        # First, find the "spread" of existing values: the range between min and max of
+        # the existing timestamps.
+        min_val = ac.min(list_flatten(value))
+        max_val = ac.max(list_flatten(value))
+
+        # deck.gl stores timestamps as float32. Therefore, we need to validate that the
+        # actual range of our data fits into the float32 integer range.
+        actual_spread = ac.sub(max_val, min_val).cast(DataType.int64())[0].as_py()
+        acceptable_spread = MAX_INTEGER_FLOAT32 - MIN_INTEGER_FLOAT32
+
+        # If it already fits into the float32 integer range, we're done.
+        if actual_spread <= acceptable_spread:
+            return value
+
+        # Otherwise, we need to reduce the precision of our data. We can do this by
+        # changing the time unit of our data. So if our timestamps start with nanosecond
+        # precision, we can downcast our data to microsecond, millisecond, or second
+        # precision.
+        # First, we want to figure out how many digits of precision to remove.
+        digits_to_remove = math.log10(actual_spread / acceptable_spread)
+
+        # Access the existing time unit and timezone
+        value_type = value.type.value_type
+        assert value_type is not None
+        time_unit = value_type.time_unit
+        tz = value_type.tz
+
+        # We can only reduce precision in orders of 1,000.
+        if digits_to_remove <= 3:
+            digits_to_remove = 3
+        elif digits_to_remove <= 6:
+            digits_to_remove = 6
+        elif digits_to_remove <= 9:
+            digits_to_remove = 9
+        else:
+            # Cause the exception to be raised in the next block
+            digits_to_remove = 999999
+
+        # Validate that the precision reduction is possible. E.g. we can't reduce the
+        # precision of second timestamps because there is no coarser time unit than
+        # seconds.
+        if (
+            (time_unit == "ns" and digits_to_remove > 9)
+            or (time_unit == "us" and digits_to_remove > 6)
+            or (time_unit == "ms" and digits_to_remove > 3)
+            or (time_unit == "s")
+        ):
+            self.error(
+                obj,
+                value,
+                info=(
+                    "The range of the timestamp column is larger than what "
+                    "can be represented at second precision in a float32. deck.gl "
+                    "uses float32 for the timestamps in the TripsLayer. "
+                    "Choose a smaller temporal subset of data."
+                ),
+            )
+
+        # Figure out the new data type given the existing data type and the number of
+        # digits of precision we're removing.
+        if time_unit == "ns":
+            if digits_to_remove == 3:
+                new_data_type = DataType.timestamp("us", tz=tz)
+            elif digits_to_remove == 6:
+                new_data_type = DataType.timestamp("ms", tz=tz)
+            elif digits_to_remove == 9:
+                new_data_type = DataType.timestamp("s", tz=tz)
+            else:
+                assert False
+        elif time_unit == "us":
+            if digits_to_remove == 3:
+                new_data_type = DataType.timestamp("ms", tz=tz)
+            elif digits_to_remove == 6:
+                new_data_type = DataType.timestamp("s", tz=tz)
+            else:
+                assert False
+        elif time_unit == "ms":
+            if digits_to_remove == 3:
+                new_data_type = DataType.timestamp("s", tz=tz)
+            else:
+                assert False
+        else:
+            assert False
+
+        new_time_unit = new_data_type.time_unit
+        warnings.warn(
+            f"Reducing precision of input timestamp data to '{new_time_unit}'"
+            " to fit into available GPU precision."
+        )
+
+        # Actually reduce the precision of each chunk of the input data, assigning the
+        # new data type
+        offsets_reader = list_offsets(value)
+        inner_values_reader = list_flatten(value)
+
+        divisor = Scalar(int(math.pow(10, digits_to_remove)), type=DataType.int64())
+
+        reduced_precision_chunks = []
+        for offsets, inner_values in zip(offsets_reader, inner_values_reader):
+            reduced_precision_values = ac.div(
+                inner_values.cast(DataType.int64()), divisor
+            )
+            reduced_precision_chunks.append(
+                list_array(offsets, reduced_precision_values.cast(new_data_type))
+            )
+
+        return ChunkedArray(reduced_precision_chunks)
+
     def validate(self, obj: BaseArrowLayer, value) -> ChunkedArray:
         if hasattr(value, "__arrow_c_array__"):
             value = ChunkedArray([Array.from_arrow(value)])
@@ -46,4 +177,10 @@ def validate(self, obj: BaseArrowLayer, value) -> ChunkedArray:
         if not DataType.is_list(value.type):
             self.error(obj, value, info="timestamp array to be a list-type array")
 
+        value_type = value.type.value_type
+        assert value_type is not None
+        if not DataType.is_temporal(value_type):
+            self.error(obj, value, info="timestamp array to have a temporal child.")
+
+        value = self.reduce_precision(obj, value)
         return value.rechunk(max_chunksize=obj._rows_per_chunk)
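
The following is a minimal, standalone sketch of the arithmetic that reduce_precision performs, useful for sanity-checking the unit-downcast thresholds. It is illustrative only: pick_new_unit and _COARSER_UNIT are hypothetical names that are not part of lonboard, and the real method operates on Arrow ChunkedArrays via arro3 rather than on plain Python integers.

# Sketch only: mirrors the float32 spread check and time-unit downcast above
# using plain integers. pick_new_unit and _COARSER_UNIT are illustrative names,
# not part of lonboard.
import math

# Integers with magnitude up to 2**24 are exactly representable in a float32,
# because float32 has a 24-bit significand.
MIN_INTEGER_FLOAT32 = -(2**24)  # -16_777_216
MAX_INTEGER_FLOAT32 = 2**24  # 16_777_216
ACCEPTABLE_SPREAD = MAX_INTEGER_FLOAT32 - MIN_INTEGER_FLOAT32  # 33_554_432

# Each step to a coarser unit divides the raw values by 1_000.
_COARSER_UNIT = {"ns": "us", "us": "ms", "ms": "s"}


def pick_new_unit(min_ts: int, max_ts: int, time_unit: str) -> tuple[str, int]:
    """Return (new_unit, divisor) so that (max_ts - min_ts) / divisor fits a float32."""
    spread = max_ts - min_ts
    if spread <= ACCEPTABLE_SPREAD:
        return time_unit, 1

    # Number of factor-of-1_000 steps needed, i.e. the digits to drop rounded
    # up to a multiple of 3, because time units only step ns -> us -> ms -> s.
    steps = math.ceil(math.log10(spread / ACCEPTABLE_SPREAD) / 3)

    unit, divisor = time_unit, 1
    for _ in range(steps):
        if unit not in _COARSER_UNIT:
            raise ValueError(
                "Timestamp range too large for float32 even at second precision"
            )
        unit = _COARSER_UNIT[unit]
        divisor *= 1_000
    return unit, divisor


# One day of nanosecond timestamps: the spread only fits after dropping to seconds.
print(pick_new_unit(0, 86_400 * 10**9, "ns"))  # ('s', 1000000000)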