Skip to content

Commit

Permalink
wip: start supporting slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
jourdain committed Jan 25, 2025
1 parent 891cef4 commit 4edb1ab
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 106 deletions.
56 changes: 39 additions & 17 deletions pan3d/ui/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,19 @@ def update_information(self, xr, available_arrays=None):
"value": f"[{xr[name].values[0]}, {xr[name].values[-1]}]",
}
)
elif len(shape) > 1:
attrs.append(
{
"key": "dims",
"value": f'({", ".join(xr[name].dims)})',
}
)
attrs.append(
{
"key": "shape",
"value": f'({", ".join([str(v) for v in xr[name].shape])})',
}
)
if name in data:
icon = "mdi-database"
order = 2
Expand Down Expand Up @@ -530,10 +543,11 @@ def __init__(self, source, update_rendering):
size="sm",
classes="mx-2",
)
v3.VDivider()

# Slice steps
with v3.VTooltip(text="Level Of Details / Slice stepping"):
with v3.VTooltip(
text="Level Of Details / Slice stepping", v_if="axis_names.length"
):
with html.Template(v_slot_activator="{ props }"):
with v3.VRow(
v_bind="props",
Expand All @@ -544,7 +558,7 @@ def __init__(self, source, update_rendering):
"mdi-stairs",
classes="ml-2 text-medium-emphasis",
)
with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"):
with v3.VCol(classes="pa-0", v_if="axis_names.length > 0"):
v3.VTextField(
v_model_number=("slice_x_step", 1),
hide_details=True,
Expand All @@ -555,7 +569,7 @@ def __init__(self, source, update_rendering):
raw_attrs=['min="1"'],
type="number",
)
with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"):
with v3.VCol(classes="pa-0", v_if="axis_names.length > 1"):
v3.VTextField(
v_model_number=("slice_y_step", 1),
hide_details=True,
Expand All @@ -566,7 +580,7 @@ def __init__(self, source, update_rendering):
raw_attrs=['min="1"'],
type="number",
)
with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"):
with v3.VCol(classes="pa-0", v_if="axis_names.length > 2"):
v3.VTextField(
v_model_number=("slice_z_step", 1),
hide_details=True,
Expand Down Expand Up @@ -730,7 +744,7 @@ def update_from_source(self, source=None):
self.state.data_arrays_available = source.available_arrays
self.state.data_arrays = source.arrays
self.state.color_by = None
self.state.axis_names = [source.x, source.y, source.z]
self.state.axis_names = []
self.state.slice_extents = source.slice_extents
self.state.projection_mode = (
"spherical"
Expand All @@ -740,16 +754,19 @@ def update_from_source(self, source=None):
self.state.spherical_bias = source.vertical_bias
self.state.spherical_scale = source.vertical_scale
slices = source.slices
for axis in XYZ:
for idx, name in enumerate(self.state.slice_extents):
axis = XYZ[idx]
self.state.axis_names.append(name)
self.state.dirty("axis_names")
# default
axis_extent = self.state.slice_extents.get(getattr(source, axis))
axis_extent = self.state.slice_extents.get(name)
self.state[f"slice_{axis}_range"] = axis_extent
self.state[f"slice_{axis}_cut"] = 0
self.state[f"slice_{axis}_step"] = 1
self.state[f"slice_{axis}_type"] = "range"

# use slice info if available
axis_slice = slices.get(getattr(source, axis))
axis_slice = slices.get(name)
if axis_slice is not None:
if isinstance(axis_slice, int):
# cut
Expand Down Expand Up @@ -862,19 +879,24 @@ def _on_array_selection(self, data_arrays, **_):
elif len(data_arrays) == 0:
self.state.color_by = None

self.source.arrays = data_arrays
if set(self.source.arrays) != set(data_arrays):
self.source.arrays = data_arrays
self.update_from_source(self.source)

@change("spherical_bias", "spherical_scale", "projection_mode")
def _on_projection_change(
self, spherical_bias, spherical_scale, projection_mode, **_
):
self.source.projection = (
Projection.SPHERICAL
if projection_mode == "spherical"
else Projection.EUCLIDEAN
)
self.source.vertical_bias = spherical_bias
self.source.vertical_scale = spherical_scale
if projection_mode == "spherical":
self.source.projection = Projection.SPHERICAL
self.source.vertical_bias = spherical_bias
self.source.vertical_scale = spherical_scale
else:
self.source.projection = Projection.EUCLIDEAN
self.source.vertical_bias = 0
self.source.vertical_scale = 1

self.ctrl.view_reset_camera()


class ControlPanel(v3.VCard):
Expand Down
31 changes: 31 additions & 0 deletions pan3d/xarray/cf/coords/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,33 @@
import numpy as np


def to_isel(slices_info, *array_names):
slices = {}
for name in array_names:
if name is None:
continue

info = slices_info.get(name)
if info is None:
continue
if isinstance(info, int):
slices[name] = info
else:
start, stop, step = info
stop -= (stop - start) % step
slices[name] = slice(start, stop, step)

return slices if slices else None


def slice_array(array_name, dataset, slice_info):
if array_name is None:
return np.zeros(1, dtype=np.float32)
array = dataset[array_name]
dims = array.dims
return array.isel(to_isel(slice_info, *dims)).values


def extract_uniform_info(array):
origin = float(array[0])
spacing = (float(array[-1]) - origin) / (array.size - 1)
Expand All @@ -21,6 +48,10 @@ def is_uniform(array):


def cell_center_to_point(in_array):
if in_array.size == 1:
print("size 1")
return [float(in_array)]

uniform_data = extract_uniform_info(in_array)
if uniform_data is not None:
origin, spacing, size = uniform_data
Expand Down
2 changes: 1 addition & 1 deletion pan3d/xarray/cf/coords/index_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, xr_dataset, in_dims, out_name):
name_to_ijk = {in_dims[-(i + 1)]: "ijk"[i] for i in range(len(in_dims))}
out_dims = xr_dataset[out_name].dims
map_method_name = "".join([name_to_ijk[name] for name in out_dims])
print(out_name, "=>", map_method_name)
# print(out_name, "=>", map_method_name)

setattr(self, "fn", getattr(self, map_method_name))

Expand Down
62 changes: 41 additions & 21 deletions pan3d/xarray/cf/coords/meta.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
import numpy as np
from pan3d.xarray.cf import mesh
from pan3d.xarray.cf.coords.convert import is_uniform
from pan3d.xarray.cf.coords.convert import is_uniform, slice_array
from pan3d.xarray.cf.constants import Projection
from vtkmodules.vtkCommonDataModel import (
vtkImageData,
Expand Down Expand Up @@ -412,17 +412,34 @@ def timeless_dimensions(self, field):
dims = self.xr_dataset[field].dims
return dims[1:] if dims[0] == self.time else dims

def field_extent(self, field):
def dims_extent(self, dimensions, slices=None):
extent = [0, 0, 0, 0, 0, 0]
dimensions = self.timeless_dimensions(field)

if slices is None:
slices = {}

for idx in range(len(dimensions)):
array = self.xr_dataset[dimensions[-(1 + idx)]]
name = dimensions[-(1 + idx)]
array = self.xr_dataset[name]
# Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size]
# And extent include both index so (len-1)
extent[idx * 2 + 1] = array.size - 1
if name in slices:
slice_info = slices[name]
if isinstance(slice_info, int):
# size of 1
pass
else:
size = int((slice_info[1] - slice_info[0]) / slice_info[2])
extent[idx * 2 + 1] = size - 1
else:
extent[idx * 2 + 1] = array.size - 1

return extent

def field_extent(self, field, slices=None):
dimensions = self.timeless_dimensions(field)
return self.dims_extent(dimensions, slices)

def get_vtk_mesh_type(self, projection, fields=None):
fields = self.compatible_fields(fields)

Expand Down Expand Up @@ -450,7 +467,7 @@ def get_vtk_mesh_type(self, projection, fields=None):
# imagedata
return vtkImageData()

def get_vtk_whole_extent(self, projection, fields=None):
def get_vtk_whole_extent(self, projection, fields=None, slices=None):
if self.longitude is None or self.latitude is None or not fields:
return [
0,
Expand All @@ -463,11 +480,8 @@ def get_vtk_whole_extent(self, projection, fields=None):

mesh_type = self.get_vtk_mesh_type(projection, fields)
fields = self.compatible_fields(fields)
extent = self.field_extent(fields[0])
dimensions = self.timeless_dimensions(fields[0])

print(f"before {extent=}")
print(f"class {mesh_type.GetClassName()}")
extent = self.dims_extent(dimensions, slices)

if mesh_type.IsA("vtkStructuredGrid") and not (
self.uniform_lat_lon and self.use_coords(dimensions)
Expand All @@ -480,11 +494,14 @@ def get_vtk_whole_extent(self, projection, fields=None):
if extent[i * 2 + 1] > 0:
extent[i * 2 + 1] += 1

print(f"after {extent=}")
print(f"Whole extent: {extent}")

return extent

def get_vtk_mesh(self, time_index=0, projection=None, fields=None):
def get_vtk_mesh(self, time_index=0, projection=None, fields=None, slices=None):
if slices is None:
slices = {}

vtk_mesh, data_location = None, None
if self.xr_dataset is None or not fields:
return vtk_mesh
Expand All @@ -505,15 +522,15 @@ def get_vtk_mesh(self, time_index=0, projection=None, fields=None):
# Unstructured
if len(data_dims_no_time) == 1:
vtk_mesh, data_location = mesh.unstructured.generate_mesh(
self, data_dims_no_time, time_index, spherical_proj
self, data_dims_no_time, time_index, spherical_proj, slices
)

# Structured
if vtk_mesh is None and (
self.coords_has_bounds or spherical_proj or not self.coords_1d
):
vtk_mesh, data_location = mesh.structured.generate_mesh(
self, data_dims_no_time, time_index, spherical_proj
self, data_dims_no_time, time_index, spherical_proj, slices
)

# This should only happen if we don't want spherical_proj
Expand All @@ -523,24 +540,27 @@ def get_vtk_mesh(self, time_index=0, projection=None, fields=None):
# Rectilinear
if vtk_mesh is None and not self.uniform_spacing:
vtk_mesh, data_location = mesh.rectilinear.generate_mesh(
self, data_dims_no_time, time_index
self, data_dims_no_time, time_index, slices
)

# Uniform
if vtk_mesh is None:
vtk_mesh, data_location = mesh.uniform.generate_mesh(
self, data_dims_no_time, time_index
self, data_dims_no_time, time_index, slices
)

# Add fields
if vtk_mesh:
container = getattr(vtk_mesh, data_location)
for field_name in fields:
field = (
self.xr_dataset[field_name][time_index].values
if self.time
else self.xr_dataset[field_name].values
)
field = slice_array(field_name, self.xr_dataset, slices)
# # FIXME to select slices
# print(f"fields: {slices=}")
# field = (
# self.xr_dataset[field_name][time_index].values
# if self.time
# else self.xr_dataset[field_name].values
# )
container[field_name] = field.ravel()
else:
print(" !!! No mesh for data")
Expand Down
1 change: 0 additions & 1 deletion pan3d/xarray/cf/coords/parametric_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(self, formula, bias=0, scale=1):
self._fn = formula
self._bias = bias
self._scale = scale
print(f"{bias=} {scale=}")

def __call__(self, n=0, k=0, j=0, i=0):
return self._bias + self._scale * self._fn(n=n, k=k, j=j, i=i)
Expand Down
2 changes: 1 addition & 1 deletion pan3d/xarray/cf/mesh/rectilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ..coords.convert import cell_center_to_point


def generate_mesh(metadata, dimensions, time_index):
def generate_mesh(metadata, dimensions, time_index, slices):
data_location = "cell_data"
extent = [0, 0, 0, 0, 0, 0]
empty_coords = np.zeros((1,), dtype=np.double)
Expand Down
Loading

0 comments on commit 4edb1ab

Please sign in to comment.