From 974e875760e9ff85f6661456b6bbf997c9cf291f Mon Sep 17 00:00:00 2001 From: Zhengui Wang Date: Thu, 3 Mar 2022 16:45:50 -0500 Subject: [PATCH 1/4] support more schism variables --- run.py | 1 - thalassa/__init__.py | 8 -- thalassa/api.py | 140 ++++++++++++++++-------------- thalassa/ui.py | 199 +++++++++++++++---------------------------- thalassa/utils.py | 103 ++++++++++++++++++---- 5 files changed, 234 insertions(+), 217 deletions(-) diff --git a/run.py b/run.py index 8f9dd32..f93d4ab 100644 --- a/run.py +++ b/run.py @@ -39,7 +39,6 @@ ui = thalassa.ThalassaUI( - display_variables=True, display_stations=True, ) diff --git a/thalassa/__init__.py b/thalassa/__init__.py index 591eb6c..4742e1e 100644 --- a/thalassa/__init__.py +++ b/thalassa/__init__.py @@ -1,20 +1,12 @@ from __future__ import annotations -from .api import get_elevation_dmap from .api import get_tiles -from .api import get_trimesh -from .api import get_wireframe from .api import get_timeseries from .ui import ThalassaUI -from .utils import open_dataset from .utils import reload __all__: list[str] = [ - "open_dataset", "reload", - "get_trimesh", "get_tiles", - "get_wireframe", - "get_elevation_dmap", "ThalassaUI", ] diff --git a/thalassa/api.py b/thalassa/api.py index e9338a2..b7b9a60 100644 --- a/thalassa/api.py +++ b/thalassa/api.py @@ -10,64 +10,73 @@ from holoviews.operation.datashader import rasterize from holoviews.streams import PointerXY,DoubleTap import numpy as np - -logger = logging.getLogger(__name__) - -# Load bokeh backend +from . import utils hv.extension("bokeh") - -def get_trimesh( - dataset: xr.Dataset, - longitude_var: str, - latitude_var: str, - elevation_var: str, - simplices_var: str, - time_var: str, - timestamp: str | pd.Timestamp, -) -> gv.TriMesh: - simplices = dataset[simplices_var].values - columns = [longitude_var, latitude_var, elevation_var] - if timestamp == "MAXIMUM": - points_df = dataset.max(time_var)[columns].to_dataframe() - elif timestamp == "MINIMUM": - points_df = dataset.min(time_var)[columns].to_dataframe() - else: - points_df = dataset.sel({time_var: timestamp})[columns].to_dataframe().drop(columns=time_var) - points_df = points_df.reset_index(drop=True) - points_gv = gv.Points(points_df, kdims=[longitude_var, latitude_var], vdims=elevation_var) - trimesh = gv.TriMesh((simplices, points_gv)) - return trimesh - +logger = logging.getLogger(__name__) def get_tiles() -> gv.Tiles: tiles = gv.WMTS("http://c.tile.openstreetmap.org/{Z}/{X}/{Y}.png") return tiles - -def get_wireframe(trimesh: gv.TriMesh) -> hv.Layout: - wireframe = dynspread(rasterize(trimesh.edgepaths, precompute=True)) - return wireframe - - -def get_elevation_dmap(trimesh: gv.TriMesh, show_grid: bool = False) -> hv.Overlay: - tiles = get_tiles() - elevation = rasterize(trimesh, precompute=True).opts( # pylint: disable=no-member - title="Elevation Forecast", - colorbar=True, - clabel="meters", - show_legend=True, - ) - logger.debug("show grid: %s", show_grid) - if show_grid: - overlay = tiles * elevation * get_wireframe(trimesh=trimesh) - else: - overlay = tiles * elevation - return overlay - -#---------------------------------------------------------------------------------------- -#time series -#---------------------------------------------------------------------------------------- +class MapData: + ''' + define a class to store data related to dynamic map + ''' + def __init__(self): + #dataset info + self.name = None + self.format = None + self.prj = None + #header info + self.dataset = None #file handle -> xr.Dataset + self.times = None + self.variables = None + #connectivity + self.x = None + self.y = None + self.elnode = None + #dataset snapshot + self.time = None + self.variable = None + self.data = None + self.grid = None + self.trimesh = None + self.trimap = None + self.tiles = get_tiles() + + def get_data(self,time,variable,layer): + ''' + extract a snapshot from dataset + ''' + self.time = time + self.variable = variable + self.layer = layer + tid=int(np.nonzero(np.array(self.times)==time)[0][0]) + self.data=utils.read_dataset(self.dataset,2,self.format,time=tid,variable=variable,layer=layer) + + def get_plot_map(self): + ''' + plot a snapshot: only SCHISM method is defined so far + ''' + if self.format=="SCHISM": + if self.x.min()<-360 or self.x.max()>360 or self.y.min()<-90 or self.y.max()>90: + raise ValueError(f"check dataset projection: abs(lat)>360 or abs(lon)>90") + df=pd.DataFrame({'longitude':self.x, 'latitude':self.y, 'data':self.data}) + pdf=gv.Points(df,kdims=['longitude','latitude'],vdims='data') + self.trimesh=gv.TriMesh((self.elnode,pdf)) + if self.grid is None: + self.grid=dynspread(rasterize(self.trimesh.edgepaths, precompute=True)) + self.trimap=rasterize(self.trimesh, precompute=True).opts( + title=f"SCHISM Forecast: {self.variable}", + colorbar=True, + clabel="meters", + cmap="jet", + show_legend=True, + ) + else: + raise ValueError(f"please define plot method for dataset format: {dataset_format}") + class TimeseriesData: ''' define a class to store data related to time series points @@ -77,32 +86,32 @@ def __init__(self): def clear(self): self.init=False -def extract_timeseries(x,y,sx,sy,data): +def extract_timeseries(x,y,sx,sy,dataset,variable): ''' function for extracting time series@(x,y) from data ''' dist=abs(sx+1j*sy-x-1j*y) mdist=dist.min() nid=np.nonzero(dist==mdist)[0][0] - mdata=data['elev'].data[:,nid].copy() + mdata=dataset[variable].data[:,nid].copy() return mdist,mdata -def add_remove_pts(x,y,data,dataset,fmt): +def add_remove_pts(x,y,data,dataset,fmt,variable): ''' function to dynamically add or remove pts by double clicking on the map ''' if fmt=='add pts': if len(data.xys)==0: - mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset) - hcurve=hv.Curve((data.time,mdata),'time','elevation').opts(tools=["hover"]) + mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable) + hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"]) if mdist<=data.mdist: data.xys.append((x,y)) data.elev.append(mdata) data.curve.append(hcurve) else: if data.xys[-1][0]!=x and data.xys[-1][1]!=y: - mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset) - hcurve=hv.Curve((data.time,mdata),'time','elevation').opts(tools=["hover"]) + mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable) + hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"]) if mdist<=data.mdist: data.xys.append((x,y)) data.elev.append(mdata) @@ -120,15 +129,20 @@ def add_remove_pts(x,y,data,dataset,fmt): else: pass -def get_timeseries(source,data,dataset,ymin,ymax,fmt): +def get_timeseries(MData,data,ymin,ymax,fmt): ''' get time series plots ''' + + source, dataset = MData.trimesh, MData.dataset + variable='elev' #todo: add an input for time series variable + #initialize timeseries_data if data.init is False: #find the maximum side length - x,y=dataset['SCHISM_hgrid_node_x'].data,dataset['SCHISM_hgrid_node_y'].data - e1,e2,e3=dataset['SCHISM_hgrid_face_nodes'].data.T + x,y=MData.x,MData.y #tmp fix, improve: todo + e1,e2,e3=MData.elnode.T + s1=abs((x[e1]-x[e2])+1j*(y[e1]-y[e2])).max() s2=abs((x[e2]-x[e3])+1j*(y[e2]-y[e3])).max() s3=abs((x[e3]-x[e1])+1j*(y[e3]-y[e1])).max() @@ -144,7 +158,7 @@ def get_timeseries(source,data,dataset,ymin,ymax,fmt): def get_plot_point(x,y): if None not in [x,y]: - add_remove_pts(x,y,data,dataset,fmt) + add_remove_pts(x,y,data,dataset,fmt,variable) if ((x is None) or (y is None)) and len(data.xys)==0: xys=[(data.x0,data.y0)] @@ -159,7 +173,7 @@ def get_plot_point(x,y): return hpoint*htext def get_plot_curve(x,y): - mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset) + mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable) if mdist>data.mdist: mdata=mdata*np.nan hdynamic=hv.Curve((data.time,mdata)).opts(color='k',line_width=2,line_dash='dotted') diff --git a/thalassa/ui.py b/thalassa/ui.py index 67e3c50..0a624d3 100644 --- a/thalassa/ui.py +++ b/thalassa/ui.py @@ -4,35 +4,28 @@ import glob import logging import os.path - import panel as pn import xarray as xr - +from pyproj import Transformer from . import api from . import utils - logger = logging.getLogger(__name__) - -DATA_DIR = "./data/" -DATA_GLOB = DATA_DIR + os.path.sep + "*" +DATA_DIR = "data" + os.path.sep + "*" # CSS Styles ERROR = {"border": "3px solid red"} INFO = {"border": "2px solid blue"} - # Help functions that log messages on stdout AND render them on the browser def info(msg: str) -> pn.Column: logger.info(msg) return pn.pane.Markdown(msg, style=INFO) - def error(msg: str) -> pn.Column: logger.error(msg) return pn.pane.Markdown(msg, style=ERROR) - class ThalassaUI: # pylint: disable=too-many-instance-attributes """ This UI is supposed to be used with a Bootstrap-like template supporting @@ -51,81 +44,54 @@ class ThalassaUI: # pylint: disable=too-many-instance-attributes These objects should be of `pn.Column` type. You can append """ - def __init__( - self, - display_variables: bool = True, - display_stations: bool = False, - ) -> None: - self._display_variables = display_variables + def __init__( self, display_stations: bool = False) -> None: self._display_stations = display_stations - # data variables - self._dataset: xr.Dataset - self._variables: list[str] - self._TimeseriesData=api.TimeseriesData() - self._timestamp='None' - - # UI components - self._main = pn.Column(info("## Please select a `dataset_file` and click on the `Render` button.")) + #UI components + self._main = pn.Column(error("## Please select a `dataset_file` and click on the `Render` button.")) self._sidebar = pn.Column() - ## Define widgets # noqa + #Define widgets self.dataset_file = pn.widgets.Select( - name="Dataset file", options=sorted(filter(utils.can_be_opened_by_xarray, glob.glob(DATA_GLOB))) + name="Dataset file", options=sorted(filter(utils.can_be_opened_by_xarray, glob.glob(DATA_DIR))), ) - # variables - self.longitude_var = pn.widgets.Select(name="Longitude") - self.latitude_var = pn.widgets.Select(name="Latitude") - self.elevation_var = pn.widgets.Select(name="Elevation") - self.simplices_var = pn.widgets.Select(name="Simplices") - self.time_var = pn.widgets.Select(name="Time") - # display options - self.timestamp = pn.widgets.Select(name="Timestamp") + self.dataset_format = pn.widgets.Select(name="Format",options=["SCHISM",]) + self.prj = pn.widgets.TextInput(value='epsg:4326',name="Projection") + self.time = pn.widgets.Select(name="Time") + self.variable = pn.widgets.Select(name="Variable") + self.layer = pn.widgets.Select(name="Layer",options=["surface","bottom"],value="surface") self.relative_colorbox = pn.widgets.Checkbox(name="Relative colorbox") - self.show_grid = pn.widgets.Checkbox(name="Show Grid") + self.show_grid = pn.widgets.Checkbox(name="Show Grid") + #time series - self.timeseries = pn.widgets.Checkbox(name="Time Series (double click)",width=150) - self.timeseries_pts=pn.widgets.RadioButtonGroup(options=['add pts','remove pts','clear'],width=300) - self.timeseries_ymin = pn.widgets.TextInput(value='-1.0',name="ymin",width=100) - self.timeseries_ymax = pn.widgets.TextInput(value='1.0',name="ymax",width=100) - # stations - self.stations_file = pn.widgets.Select(name="Stations file") - self.stations = pn.widgets.CrossSelector(name="Stations") + self.timeseries = pn.widgets.Checkbox(name="Time Series (double click)") + self.timeseries_pts = pn.widgets.RadioButtonGroup(options=['add pts','remove pts','clear']) + self.timeseries_ymin = pn.widgets.TextInput(value='-1.0',name="ymin") + self.timeseries_ymax = pn.widgets.TextInput(value='1.0',name="ymax") + + #stations + self.stations_file = pn.widgets.Select(name="Stations file") + self.stations = pn.widgets.CrossSelector(name="Stations") + # render button - self.render_button = pn.widgets.Button(name="Render", button_type="primary") + self.render_button = pn.widgets.Button(name="Render", button_type="primary") self._define_widget_callbacks() self._populate_widgets() self._setup_ui() def _setup_ui(self) -> None: - self._sidebar.append(pn.Accordion(("Input Files", pn.WidgetBox(self.dataset_file)), active=[0])) - if self._display_variables: - self._sidebar.append( - pn.Accordion( - ( - "Variables", - pn.WidgetBox( - self.longitude_var, - self.latitude_var, - self.elevation_var, - self.simplices_var, - self.time_var, - ), - ) - ) - ) self._sidebar.append( pn.Accordion( - ("Display Options", pn.WidgetBox(self.timestamp, self.relative_colorbox, - self.show_grid,)), + ("Input Files", pn.WidgetBox(self.dataset_file, pn.Row(self.dataset_format,self.prj), + self.time, pn.Row(self.variable,self.layer), pn.Row(self.relative_colorbox,self.show_grid),)), active=[0], ), ) self._sidebar.append( pn.Accordion( ("Time Series", pn.WidgetBox(self.timeseries, - pn.Row(self.timeseries_ymin,self.timeseries_ymax),self.timeseries_pts,)), + pn.Row(self.timeseries_ymin, self.timeseries_ymax), self.timeseries_pts,)), active=[0], ), ) @@ -137,34 +103,11 @@ def _setup_ui(self) -> None: def _define_widget_callbacks(self) -> None: # Dataset callback - self.dataset_file.param.watch(fn=self._update_dataset_file, parameter_names="value") - # Variable callbacks - self.dataset_file.param.watch( - fn=lambda event: self._set_variable(event, self.longitude_var, 1, "SCHISM_hgrid_node_x"), - parameter_names="value", - ) - self.dataset_file.param.watch( - fn=lambda event: self._set_variable(event, self.latitude_var, 2, "SCHISM_hgrid_node_y"), - parameter_names="value", - ) - self.dataset_file.param.watch( - fn=lambda event: self._set_variable(event, self.elevation_var, 0, "elev"), - parameter_names="value", - ) - self.dataset_file.param.watch( - fn=lambda event: self._set_variable(event, self.simplices_var, 3, "SCHISM_hgrid_face_nodes"), - parameter_names="value", - ) - self.dataset_file.param.watch( - fn=lambda event: self._set_variable(event, self.time_var, 4, "time"), - parameter_names="value", - ) - # Display options callbacks - self.dataset_file.param.watch(fn=self._update_timestamp, parameter_names="value") + self.dataset_file.param.watch(fn=self._read_header_info, parameter_names="value") + self.prj.param.watch(fn=self._read_header_info,parameter_names="value") self.timeseries.param.watch(fn=self._update_main,parameter_names="value") self.timeseries_pts.param.watch(fn=self._update_main,parameter_names="value") # Station callbacks - # # Render button self.render_button.on_click(self._update_main) @@ -179,28 +122,28 @@ def sidebar(self) -> pn.Column: def main(self) -> pn.Column: return self._main - def _update_dataset_file(self, event: pn.Event) -> None: - logger.debug("Using dataset: %s", self.dataset_file.value) - self._dataset = utils.open_dataset(self.dataset_file.value, load=False) - self._variables = list(self._dataset.variables.keys()) # type: ignore[arg-type] - - def _set_variable(self, event: pn.Event, widget: pn.Widget, index: int, schism_name: str) -> None: - # logger.debug("Updating %s", widget.name) - if schism_name in self._variables: - value = schism_name - else: - try: - value = self._variables[index] - except IndexError: - logger.error("Not enough variables: %d, %s", index, self._variables) - raise - widget.param.set_param(options=self._variables, value=value) - - def _update_timestamp(self, event: pn.Event) -> None: - dataset_timestamps = self._dataset[self.time_var.value].to_pandas().dt.to_pydatetime() - dataset_options = ["MAXIMUM"] + [v.strftime("%Y-%m-%d %H-%M-%S") for v in dataset_timestamps] - # self.timestamp.param.set_param(options=dataset_options, value="MAXIMUM") - self.timestamp.param.set_param(options=dataset_options, value=dataset_timestamps[0]) + def _read_header_info(self,event: pn.Event): + self._MData=api.MapData() + self._TSData=api.TimeseriesData() + + hdata=utils.read_dataset(self.dataset_file.value,1,self.dataset_format.value,self.prj.value) + self._MData.name = self.dataset_file.value + self._MData.format = self.dataset_format.value + self._MData.prj = self.prj.value + self._MData.dataset = hdata[0] + self._MData.times = hdata[1] + self._MData.variables = hdata[2] + self._MData.x = hdata[3] + self._MData.y = hdata[4] + self._MData.elnode = hdata[5] + self._MData.grid = None + + #transform projection + if self.prj.value!='epsg:4326': + self._MData.y, self._MData.x = Transformer.from_crs(self.prj.value,'epsg:4326').transform(self._MData.x,self._MData.y) + + self.time.param.set_param(options=[*hdata[1]], value=hdata[1][0]) + self.variable.param.set_param(options=hdata[2], value=hdata[2][0]) def _debug_ui(self) -> None: logger.debug("Widget values:") @@ -213,41 +156,39 @@ def _update_main(self, event: pn.Event) -> None: # Not sure what is going on here, but panel seems to shallow exceptions within callbacks # Having an explicit try/except at least allows to log the error try: - if self._timestamp!=self.timestamp.value: + if self._MData.time!=self.time.value or self._MData.variable!=self.variable.value or self._MData.layer!=self.layer.value: + if self._MData.variable!=self.variable.value or self._MData.layer!=self.layer.value: + self._TSData=api.TimeseriesData() + + #read dataset snapshot values self._debug_ui() - self._dataset = utils.open_dataset(self.dataset_file.value, load=True) - trimesh = api.get_trimesh( - self._dataset, - self.longitude_var.value, - self.latitude_var.value, - self.elevation_var.value, - self.simplices_var.value, - self.time_var.value, - timestamp=self.timestamp.value, - ) - logger.debug("Created trimesh") - dmap = api.get_elevation_dmap(trimesh, show_grid=self.show_grid.value) + self._MData.get_data(self.time.value,self.variable.value,self.layer.value) + + #get plots of trimesh and dmap + self._MData.get_plot_map() logger.debug("Created dynamic map") - #save plot for efficiency - self.trimesh,self.dmap,self._timestamp=trimesh,dmap,self.timestamp.value + #update dynamic map + if self.show_grid.value: + dmap=self._MData.tiles * self._MData.trimap * self._MData.grid + else: + dmap=self._MData.tiles * self._MData.trimap #update time series if self.timeseries.value: if self.timeseries_pts.value=='clear': - self._TimeseriesData.clear() + self._TSData.clear() hpoint,hcurve=api.get_timeseries( - self.trimesh, - self._TimeseriesData, - self._dataset, + self._MData, + self._TSData, self.timeseries_ymin.value, self.timeseries_ymax.value, self.timeseries_pts.value, ) - self._main.objects = [self.dmap*hpoint,hcurve] + self._main.objects = [dmap*hpoint,hcurve] logger.info("update timeseries") else: - self._main.objects = [self.dmap.opts(height=650)] + self._main.objects = [dmap.opts(height=650)] logger.info("check objects: {}".format(len(self._main.objects))) except Exception: diff --git a/thalassa/utils.py b/thalassa/utils.py index d693668..0fcdcd0 100644 --- a/thalassa/utils.py +++ b/thalassa/utils.py @@ -6,26 +6,98 @@ import os import pathlib import sys - +import numpy as np import xarray as xr - logger = logging.getLogger(__name__) +def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326',time=None,variable=None,layer=None): + ''' + function to read information of dataset + Inputs: + fname: path of dataset file, or file handle (xr.Dataset) + method: different ways to read dataset + 0: open dataset; 1: read header information; 2: read dataset snapshot + dataset_format: format of dataset + prj: projection of dataset coordinate + time: timestamp or index of timestamp for dataset snapshot + variable: variable to be read + layer: layer for 3D variables + + note: only SCHISM format is defined so far + ''' -def open_dataset(path: str | pathlib.Path, load: bool = False) -> xr.Dataset: - path = pathlib.Path(path) - if path.suffix == ".nc": - ds = xr.open_dataset(path, mask_and_scale=True) - elif path.suffix in (".zarr", ".zip") or path.is_dir(): - ds = xr.open_dataset(path, mask_and_scale=True, engine="zarr") - # TODO: extend with GeoTiff, Grib etc + #open dataset + if isinstance(fname, xr.Dataset): #already open + ds=fname else: - raise ValueError(f"Don't know how to handle this: {path}") - if load: - # load dataset to memory - ds.load() - return ds + if fname.endswith(".nc"): + ds = xr.open_dataset(fname, mask_and_scale=True) + elif fname.endswith(".zarr") or fname.endswith(".zip"): + ds = xr.open_dataset(fname, mask_and_scale=True, engine="zarr") + else: + raise ValueError(f"unknown format of dataset: {fname}") + if method==0: + return ds + + #read dataset + if dataset_format=="SCHISM": + if method==1: #extract header information of dataset + #variables to be hidden from user + hvars=[ + 'time', 'SCHISM_hgrid', 'SCHISM_hgrid_face_nodes', 'SCHISM_hgrid_edge_nodes', + 'SCHISM_hgrid_node_x', 'SCHISM_hgrid_node_y', 'node_bottom_index', 'SCHISM_hgrid_face_x', + 'SCHISM_hgrid_face_y', 'ele_bottom_index', 'SCHISM_hgrid_edge_x', 'SCHISM_hgrid_edge_y', + 'edge_bottom_index', 'depth', 'sigma', 'dry_value_flag', 'coordinate_system_flag', + 'minimum_depth', 'sigma_h_c', 'sigma_theta_b', 'sigma_theta_f', 'sigma_maxdepth', 'Cs', + 'wetdry_node', 'wetdry_elem', 'wetdry_side', 'zcor'] + + times=ds['time'].to_pandas().dt.to_pydatetime() + variables=[i for i in ds.variables if (i not in hvars)] + x=ds.variables['SCHISM_hgrid_node_x'].values + y=ds.variables['SCHISM_hgrid_node_y'].values + elnode=ds.variables['SCHISM_hgrid_face_nodes'].values + #split quads + if elnode.shape[1]==4: + eid=np.nonzero(~((np.isnan(elnode[:,-1]))|(elnode[:,-1]<0)))[0] + elnode=np.r_[elnode[:,:3],np.c_[elnode[eid,0][:,None], elnode[eid,2:]]] + if elnode.max()>=len(x): + elnode=elnode-1 + elnode=elnode.astype('int') + + return ds,times,variables,x,y,elnode + elif method==2: #extract one snapshot of dataset + #time index + if isinstance(time,int): + tid=time + else: + times=ds['time'].to_pandas().dt.to_pydatetime() + tid=np.nonzero(np.array(times)==timestamp)[0][0] + + #2D and 3D variables + if ds.variables[variable].ndim==1: + mdata=ds.variables[variable].values + elif ds.variables[variable].ndim==2: + mdata=ds.variables[variable][tid].values + elif ds.variables[variable].ndim==3: + #layer index + if layer=='surface': + mdata=ds.variables[variable][tid,:,-1].values + elif layer=='bottom': + if 'node_bottom_index' in [*ds.variables]: + zid=ds.variables['node_bottom_index'][:].values.astype('int') + pid=np.arange(len(zid)) + mdata=ds.variables[variable][tid].values[pid,zid] + else: + mdata=ds.variables[variable][tid,:,-1].values + else: + mdata=ds.variables[variable][tid,:,layer].values + + return mdata + else: + raise ValueError(f"unknown read method for SCHISM model: {method}") + else: + raise ValueError(f"unknown model (read method needs to be defined): {dataset_format}") def reload(module_name: str) -> None: """ @@ -50,10 +122,9 @@ def reload(module_name: str) -> None: # OK, now let's reload! deepreload.reload(module, exclude=to_be_excluded) - def can_be_opened_by_xarray(path): try: - open_dataset(path) + read_dataset(path) except ValueError: logger.debug("path cannot be opened by xarray: %s", path) return False From bf6c2cc2ec9b7f16edfb9a040b317a4d47a641e0 Mon Sep 17 00:00:00 2001 From: Zhengui Wang Date: Fri, 4 Mar 2022 11:50:06 -0500 Subject: [PATCH 2/4] put all functions related to plot timeseries under 'TimeseriesData' class --- thalassa/__init__.py | 1 - thalassa/api.py | 203 +++++++++++++++++++++---------------------- thalassa/ui.py | 44 ++++++---- thalassa/utils.py | 29 ++++++- 4 files changed, 152 insertions(+), 125 deletions(-) diff --git a/thalassa/__init__.py b/thalassa/__init__.py index 4742e1e..3250836 100644 --- a/thalassa/__init__.py +++ b/thalassa/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations from .api import get_tiles -from .api import get_timeseries from .ui import ThalassaUI from .utils import reload diff --git a/thalassa/api.py b/thalassa/api.py index b7b9a60..b0878ca 100644 --- a/thalassa/api.py +++ b/thalassa/api.py @@ -32,6 +32,7 @@ def __init__(self): self.dataset = None #file handle -> xr.Dataset self.times = None self.variables = None + self.layer = None #connectivity self.x = None self.y = None @@ -81,108 +82,104 @@ class TimeseriesData: ''' define a class to store data related to time series points ''' - def __init__(self): - self.init=False - def clear(self): - self.init=False - -def extract_timeseries(x,y,sx,sy,dataset,variable): - ''' - function for extracting time series@(x,y) from data - ''' - dist=abs(sx+1j*sy-x-1j*y) - mdist=dist.min() - nid=np.nonzero(dist==mdist)[0][0] - mdata=dataset[variable].data[:,nid].copy() - return mdist,mdata - -def add_remove_pts(x,y,data,dataset,fmt,variable): - ''' - function to dynamically add or remove pts by double clicking on the map - ''' - if fmt=='add pts': - if len(data.xys)==0: - mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable) - hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"]) - if mdist<=data.mdist: - data.xys.append((x,y)) - data.elev.append(mdata) - data.curve.append(hcurve) - else: - if data.xys[-1][0]!=x and data.xys[-1][1]!=y: - mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable) - hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"]) - if mdist<=data.mdist: - data.xys.append((x,y)) - data.elev.append(mdata) - data.curve.append(hcurve) - elif fmt=='remove pts': - if len(data.xys)>0: - xys=np.array(data.xys) - dist=abs(xys[:,0]+1j*xys[:,1]-x-1j*y) - mdist=dist.min() - if mdist<=data.mdist: - nid=np.nonzero(dist==mdist)[0][0] - data.xys=[k for i,k in enumerate(data.xys) if i!=nid] - data.elev=[k for i,k in enumerate(data.elev) if i!=nid] - data.curve=[k for i,k in enumerate(data.curve) if i!=nid] - else: - pass - -def get_timeseries(MData,data,ymin,ymax,fmt): - ''' - get time series plots - ''' - - source, dataset = MData.trimesh, MData.dataset - variable='elev' #todo: add an input for time series variable - - #initialize timeseries_data - if data.init is False: - #find the maximum side length - x,y=MData.x,MData.y #tmp fix, improve: todo - e1,e2,e3=MData.elnode.T - - s1=abs((x[e1]-x[e2])+1j*(y[e1]-y[e2])).max() - s2=abs((x[e2]-x[e3])+1j*(y[e2]-y[e3])).max() - s3=abs((x[e3]-x[e1])+1j*(y[e3]-y[e1])).max() - - #save data - data.sx, data.sy, data.x0, data.y0 = x, y, x.mean(), y.mean() - data.mdist=np.max([s1,s2,s3]) - data.time=dataset['time'].data - data.xys=[] - data.elev=[] - data.curve=[] - data.init=True - - def get_plot_point(x,y): - if None not in [x,y]: - add_remove_pts(x,y,data,dataset,fmt,variable) - - if ((x is None) or (y is None)) and len(data.xys)==0: - xys=[(data.x0,data.y0)] - hpoint=gv.Points(xys).opts(show_legend=False,visible=False) - htext=gv.HoloMap({i:gv.Text(*xy,'{}'.format(i+1)).opts( - show_legend=False,visible=False) for i,xy in enumerate(xys)}).overlay() - else: - xys=data.xys - hpoint=gv.Points(xys).opts(color='r',size=3,show_legend=False) - htext=gv.HoloMap({i:gv.Text(*xy,'{}'.format(i+1)).opts( - show_legend=False,color='k',fontsize=3) for i,xy in enumerate(xys)}).overlay() - return hpoint*htext - - def get_plot_curve(x,y): - mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable) - if mdist>data.mdist: - mdata=mdata*np.nan - hdynamic=hv.Curve((data.time,mdata)).opts(color='k',line_width=2,line_dash='dotted') - hcurve=hv.HoloMap({'dynamic':hdynamic,**{(i+1):k for i,k in enumerate(data.curve)}}).overlay() - return hcurve + def __init__(self,MData,variable,layer): + #dataset and header info + self.source = MData + self.dataset = MData.dataset + self.format = MData.format + self.times = MData.times + #self.variable = variable + self.variable = 'elev' + self.layer = layer + self.x = MData.x + self.y = MData.y + self.elnode = MData.elnode + self.x0 = self.x.mean() + self.y0 = self.y.mean() + + #init. data related to time series + self.xys = [] + self.data = [] + self.curve = [] + + #compute maximum side length + e1,e2,e3=self.elnode.T + side1=abs((self.x[e1]-self.x[e2])+1j*(self.y[e1]-self.y[e2])).max() + side2=abs((self.x[e2]-self.x[e3])+1j*(self.y[e2]-self.y[e3])).max() + side3=abs((self.x[e3]-self.x[e1])+1j*(self.y[e3]-self.y[e1])).max() + self.mdist=np.max([side1,side2,side3]) + + def get_data(self,x,y): + ''' + function for extracting time series@(x,y) from data + ''' + dist=abs(self.x+1j*self.y-x-1j*y) + node=np.nonzero(dist==dist.min())[0][0] + mdata=utils.read_dataset(self.dataset,3,self.format,variable=self.variable,layer=self.layer,node=node) + return dist.min(),mdata.copy() - hpoint=gv.DynamicMap(get_plot_point,streams=[DoubleTap(source=source,transient=True)]) - hcurve=gv.DynamicMap(get_plot_curve,streams=[PointerXY(x=data.x0,y=data.y0,source=source)]).opts( - height=400,legend_cols=len(data.xys)+1,legend_position='top', - ylim=(float(ymin),float(ymax)),responsive=True,align='end',active_tools=["pan", "wheel_zoom"]) + def get_timeseries(self,ymin,ymax,fmt): + ''' + get time series plots + ''' - return hpoint,hcurve + def add_remove_pts(x,y): + ''' + function to dynamically add or remove pts by double clicking on the map + ''' + if fmt=='add pts': + if len(self.xys)==0: + mdist,mdata=self.get_data(x,y) + hcurve=hv.Curve((self.times,mdata),'time',self.variable).opts(tools=["hover"]) + if mdist<=self.mdist: + self.xys.append((x,y)) + self.data.append(mdata) + self.curve.append(hcurve) + else: + if self.xys[-1][0]!=x and self.xys[-1][1]!=y: + mdist,mdata=self.get_data(x,y) + hcurve=hv.Curve((self.times,mdata),'time',self.variable).opts(tools=["hover"]) + if mdist<=self.mdist: + self.xys.append((x,y)) + self.data.append(mdata) + self.curve.append(hcurve) + elif fmt=='remove pts': + if len(self.xys)>0: + xys=np.array(self.xys) + dist=abs(xys[:,0]+1j*xys[:,1]-x-1j*y) + mdist=dist.min() + if mdist<=self.mdist: + nid=np.nonzero(dist==mdist)[0][0] + self.xys=[k for i,k in enumerate(self.xys) if i!=nid] + self.data=[k for i,k in enumerate(self.data) if i!=nid] + self.curve=[k for i,k in enumerate(self.curve) if i!=nid] + + def get_plot_point(x,y): + if None not in [x,y]: + add_remove_pts(x,y) + + if ((x is None) or (y is None)) and len(self.xys)==0: + xys=[(self.x0,self.y0)] + hpoint=gv.Points(xys).opts(show_legend=False,visible=False) + htext=gv.HoloMap({i:gv.Text(*xy,'{}'.format(i+1)).opts( + show_legend=False,visible=False) for i,xy in enumerate(xys)}).overlay() + else: + xys=self.xys + hpoint=gv.Points(xys).opts(color='r',size=3,show_legend=False) + htext=gv.HoloMap({i:gv.Text(*xy,'{}'.format(i+1)).opts( + show_legend=False,color='k',fontsize=3) for i,xy in enumerate(xys)}).overlay() + return hpoint*htext + + def get_plot_curve(x,y): + mdist,mdata=self.get_data(x,y) + if mdist>self.mdist: + mdata=mdata*np.nan + hdynamic=hv.Curve((self.times,mdata)).opts(color='k',line_width=2,line_dash='dotted') + hcurve=hv.HoloMap({'dynamic':hdynamic,**{(i+1):k for i,k in enumerate(self.curve)}}).overlay() + return hcurve + + hpoint=gv.DynamicMap(get_plot_point,streams=[DoubleTap(source=self.source.trimesh,transient=True)]) + hcurve=gv.DynamicMap(get_plot_curve,streams=[PointerXY(x=self.x0,y=self.y0,source=self.source.trimesh)]).opts( + height=400,legend_cols=len(self.xys)+1,legend_position='top', + ylim=(float(ymin),float(ymax)),responsive=True,align='end',active_tools=["pan", "wheel_zoom"]) + return hpoint,hcurve diff --git a/thalassa/ui.py b/thalassa/ui.py index 0a624d3..c60f8e1 100644 --- a/thalassa/ui.py +++ b/thalassa/ui.py @@ -65,6 +65,8 @@ def __init__( self, display_stations: bool = False) -> None: #time series self.timeseries = pn.widgets.Checkbox(name="Time Series (double click)") + self.timeseries_variable = pn.widgets.Select(name="Variable") + self.timeseries_layer = pn.widgets.Select(name="Layer",options=["surface","bottom"],value="surface") self.timeseries_pts = pn.widgets.RadioButtonGroup(options=['add pts','remove pts','clear']) self.timeseries_ymin = pn.widgets.TextInput(value='-1.0',name="ymin") self.timeseries_ymax = pn.widgets.TextInput(value='1.0',name="ymax") @@ -90,7 +92,7 @@ def _setup_ui(self) -> None: ) self._sidebar.append( pn.Accordion( - ("Time Series", pn.WidgetBox(self.timeseries, + ("Time Series", pn.WidgetBox(self.timeseries, #pn.Row(self.timeseries_variable,self.timeseries_layer), pn.Row(self.timeseries_ymin, self.timeseries_ymax), self.timeseries_pts,)), active=[0], ), @@ -102,13 +104,16 @@ def _setup_ui(self) -> None: self._sidebar.append(self.render_button) def _define_widget_callbacks(self) -> None: - # Dataset callback + #Dataset callback self.dataset_file.param.watch(fn=self._read_header_info, parameter_names="value") self.prj.param.watch(fn=self._read_header_info,parameter_names="value") + #timeseries callback self.timeseries.param.watch(fn=self._update_main,parameter_names="value") self.timeseries_pts.param.watch(fn=self._update_main,parameter_names="value") - # Station callbacks - # Render button + #self.timeseries_variable.param.watch(fn=self._init_timeseries,parameter_names="value") + #self.timeseries_layer.param.watch(fn=self._init_timeseries,parameter_names="value") + #Station callbacks + #Render button self.render_button.on_click(self._update_main) def _populate_widgets(self) -> None: @@ -124,8 +129,6 @@ def main(self) -> pn.Column: def _read_header_info(self,event: pn.Event): self._MData=api.MapData() - self._TSData=api.TimeseriesData() - hdata=utils.read_dataset(self.dataset_file.value,1,self.dataset_format.value,self.prj.value) self._MData.name = self.dataset_file.value self._MData.format = self.dataset_format.value @@ -141,9 +144,18 @@ def _read_header_info(self,event: pn.Event): #transform projection if self.prj.value!='epsg:4326': self._MData.y, self._MData.x = Transformer.from_crs(self.prj.value,'epsg:4326').transform(self._MData.x,self._MData.y) - + self.time.param.set_param(options=[*hdata[1]], value=hdata[1][0]) self.variable.param.set_param(options=hdata[2], value=hdata[2][0]) + self.timeseries_variable.param.set_param(options=hdata[2], value=hdata[2][0]) + + #initilize Timeseries class + self._init_timeseries(event) + #self._TSData=api.TimeseriesData(self._MData) + + def _init_timeseries(self,event: pn.Event): + #self._TSData=api.TimeseriesData(self._MData, self.timeseries_variable.value, self.timeseries_layer.value) + self._TSData=api.TimeseriesData(self._MData, self.timeseries_variable.value, self.layer.value) def _debug_ui(self) -> None: logger.debug("Widget values:") @@ -157,8 +169,8 @@ def _update_main(self, event: pn.Event) -> None: # Having an explicit try/except at least allows to log the error try: if self._MData.time!=self.time.value or self._MData.variable!=self.variable.value or self._MData.layer!=self.layer.value: - if self._MData.variable!=self.variable.value or self._MData.layer!=self.layer.value: - self._TSData=api.TimeseriesData() + #if self._MData.variable!=self.variable.value or self._MData.layer!=self.layer.value: + # self._TSData=api.TimeseriesData(self._MData) #read dataset snapshot values self._debug_ui() @@ -177,14 +189,12 @@ def _update_main(self, event: pn.Event) -> None: #update time series if self.timeseries.value: if self.timeseries_pts.value=='clear': - self._TSData.clear() - hpoint,hcurve=api.get_timeseries( - self._MData, - self._TSData, - self.timeseries_ymin.value, - self.timeseries_ymax.value, - self.timeseries_pts.value, - ) + self._TSData=api.TimeseriesData(self._MData) + hpoint,hcurve=self._TSData.get_timeseries(self.timeseries_ymin.value, + self.timeseries_ymax.value, self.timeseries_pts.value) + + #display map and time series + if self.timeseries.value: self._main.objects = [dmap*hpoint,hcurve] logger.info("update timeseries") else: diff --git a/thalassa/utils.py b/thalassa/utils.py index 0fcdcd0..796ceab 100644 --- a/thalassa/utils.py +++ b/thalassa/utils.py @@ -10,18 +10,21 @@ import xarray as xr logger = logging.getLogger(__name__) -def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326',time=None,variable=None,layer=None): +def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326', + time=None,variable=None,layer=None,node=None): ''' function to read information of dataset Inputs: fname: path of dataset file, or file handle (xr.Dataset) method: different ways to read dataset - 0: open dataset; 1: read header information; 2: read dataset snapshot + 0: open dataset; 1: read header information; + 2: read dataset snapshot; 3: read time series dataset_format: format of dataset prj: projection of dataset coordinate time: timestamp or index of timestamp for dataset snapshot variable: variable to be read layer: layer for 3D variables + node: index of node for reading time series (for method=3) note: only SCHISM format is defined so far ''' @@ -66,6 +69,7 @@ def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326',time=Non elnode=elnode.astype('int') return ds,times,variables,x,y,elnode + elif method==2: #extract one snapshot of dataset #time index if isinstance(time,int): @@ -80,7 +84,6 @@ def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326',time=Non elif ds.variables[variable].ndim==2: mdata=ds.variables[variable][tid].values elif ds.variables[variable].ndim==3: - #layer index if layer=='surface': mdata=ds.variables[variable][tid,:,-1].values elif layer=='bottom': @@ -89,11 +92,29 @@ def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326',time=Non pid=np.arange(len(zid)) mdata=ds.variables[variable][tid].values[pid,zid] else: - mdata=ds.variables[variable][tid,:,-1].values + mdata=ds.variables[variable][tid,:,0].values else: mdata=ds.variables[variable][tid,:,layer].values return mdata + + elif method==3: #extract time series + if ds.variables[variable].ndim==2: + mdata=ds.variables[variable][:,node].values + elif ds.variables[variable].ndim==3: + if layer=='surface': + zid=-1 + elif layer=='bottom': + if 'node_bottom_index' in [*ds.variables]: + zid=ds.variables['node_bottom_index'][node].values.astype('int') + else: + zid=0 + else: + zid=layer + mdata=ds.variables[variable][:,node,zid].values + + return mdata + else: raise ValueError(f"unknown read method for SCHISM model: {method}") else: From cd8ada94d01a1848042d61bc2ce4d53df0bba436 Mon Sep 17 00:00:00 2001 From: Zhengui Wang Date: Fri, 8 Apr 2022 10:44:11 -0400 Subject: [PATCH 3/4] Based on Panos' suggestions: 1)remove dataset_format option, infer it from dataset; 2). remove cmap options --- thalassa/api.py | 2 +- thalassa/ui.py | 9 +++++---- thalassa/utils.py | 15 +++++++++++++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/thalassa/api.py b/thalassa/api.py index b0878ca..b6ac402 100644 --- a/thalassa/api.py +++ b/thalassa/api.py @@ -72,7 +72,7 @@ def get_plot_map(self): title=f"SCHISM Forecast: {self.variable}", colorbar=True, clabel="meters", - cmap="jet", + #cmap="jet", show_legend=True, ) else: diff --git a/thalassa/ui.py b/thalassa/ui.py index c60f8e1..4ae801d 100644 --- a/thalassa/ui.py +++ b/thalassa/ui.py @@ -55,7 +55,7 @@ def __init__( self, display_stations: bool = False) -> None: self.dataset_file = pn.widgets.Select( name="Dataset file", options=sorted(filter(utils.can_be_opened_by_xarray, glob.glob(DATA_DIR))), ) - self.dataset_format = pn.widgets.Select(name="Format",options=["SCHISM",]) + #self.dataset_format = pn.widgets.Select(name="Format",options=["SCHISM",]) self.prj = pn.widgets.TextInput(value='epsg:4326',name="Projection") self.time = pn.widgets.Select(name="Time") self.variable = pn.widgets.Select(name="Variable") @@ -85,7 +85,7 @@ def __init__( self, display_stations: bool = False) -> None: def _setup_ui(self) -> None: self._sidebar.append( pn.Accordion( - ("Input Files", pn.WidgetBox(self.dataset_file, pn.Row(self.dataset_format,self.prj), + ("Input Files", pn.WidgetBox(self.dataset_file, self.prj, #pn.Row(self.dataset_format,self.prj), self.time, pn.Row(self.variable,self.layer), pn.Row(self.relative_colorbox,self.show_grid),)), active=[0], ), @@ -129,9 +129,10 @@ def main(self) -> pn.Column: def _read_header_info(self,event: pn.Event): self._MData=api.MapData() - hdata=utils.read_dataset(self.dataset_file.value,1,self.dataset_format.value,self.prj.value) + self.dataset_format=utils.read_dataset(self.dataset_file.value)[1] + hdata=utils.read_dataset(self.dataset_file.value,1,self.dataset_format,self.prj.value) self._MData.name = self.dataset_file.value - self._MData.format = self.dataset_format.value + self._MData.format = self.dataset_format self._MData.prj = self.prj.value self._MData.dataset = hdata[0] self._MData.times = hdata[1] diff --git a/thalassa/utils.py b/thalassa/utils.py index 796ceab..e9423d8 100644 --- a/thalassa/utils.py +++ b/thalassa/utils.py @@ -39,8 +39,19 @@ def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326', ds = xr.open_dataset(fname, mask_and_scale=True, engine="zarr") else: raise ValueError(f"unknown format of dataset: {fname}") - if method==0: - return ds + + #retrun file handle and data_format + if method==0: + #infer data format + if 'SCHISM_hgrid_face_nodes' in [*ds.variables]: + dataset_format=='SCHISM' + else: + dataset_format=None + + if dataset_format is None: + raise ValueError(f"unknown model of dataset: {fname}; please define its format and read method") + + return ds,dataset_format #read dataset if dataset_format=="SCHISM": From bbeb2a55ccce92bad6c06127255e8bbc15b7e52e Mon Sep 17 00:00:00 2001 From: Zhengui Wang Date: Fri, 8 Apr 2022 12:07:20 -0400 Subject: [PATCH 4/4] add maximum value option for map --- thalassa/api.py | 18 ++++++++++++++++-- thalassa/ui.py | 2 +- thalassa/utils.py | 19 ++++++++++++++++--- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/thalassa/api.py b/thalassa/api.py index b6ac402..fa5a25f 100644 --- a/thalassa/api.py +++ b/thalassa/api.py @@ -45,6 +45,10 @@ def __init__(self): self.trimesh = None self.trimap = None self.tiles = get_tiles() + #max value + self.variable_max = '' + self.layer_max = 0 + self.data_max = None def get_data(self,time,variable,layer): ''' @@ -53,8 +57,18 @@ def get_data(self,time,variable,layer): self.time = time self.variable = variable self.layer = layer - tid=int(np.nonzero(np.array(self.times)==time)[0][0]) - self.data=utils.read_dataset(self.dataset,2,self.format,time=tid,variable=variable,layer=layer) + + if time=='max': + if variable==self.variable_max and layer==self.layer_max: + self.data=self.data_max + else: + self.data=utils.read_dataset(self.dataset,2,self.format,time=time,variable=variable,layer=layer) + self.variable_max=variable + self.layer_max=layer + self.data_max=self.data + else: + tid=int(np.nonzero(np.array(self.times)==time)[0][0]) + self.data=utils.read_dataset(self.dataset,2,self.format,time=tid,variable=variable,layer=layer) def get_plot_map(self): ''' diff --git a/thalassa/ui.py b/thalassa/ui.py index 4ae801d..d3a12be 100644 --- a/thalassa/ui.py +++ b/thalassa/ui.py @@ -146,7 +146,7 @@ def _read_header_info(self,event: pn.Event): if self.prj.value!='epsg:4326': self._MData.y, self._MData.x = Transformer.from_crs(self.prj.value,'epsg:4326').transform(self._MData.x,self._MData.y) - self.time.param.set_param(options=[*hdata[1]], value=hdata[1][0]) + self.time.param.set_param(options=['max',*hdata[1]], value=hdata[1][0]) self.variable.param.set_param(options=hdata[2], value=hdata[2][0]) self.timeseries_variable.param.set_param(options=hdata[2], value=hdata[2][0]) diff --git a/thalassa/utils.py b/thalassa/utils.py index e9423d8..2099320 100644 --- a/thalassa/utils.py +++ b/thalassa/utils.py @@ -83,7 +83,12 @@ def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326', elif method==2: #extract one snapshot of dataset #time index - if isinstance(time,int): + if isinstance(time,str): + if time=='max': + tid=-1 #for maximum value + elif time=='min': + tid=-2 #for minimum value + elif isinstance(time,int): tid=time else: times=ds['time'].to_pandas().dt.to_pydatetime() @@ -93,10 +98,18 @@ def read_dataset(fname,method=0,dataset_format="SCHISM",prj='epsg:4326', if ds.variables[variable].ndim==1: mdata=ds.variables[variable].values elif ds.variables[variable].ndim==2: - mdata=ds.variables[variable][tid].values + if tid==-1: + mdata=ds.variables[variable][:].values.max(axis=0) + #elif tid==-2: + # mdata=ds.variables[variable][:].values.min(axis=0) + else: + mdata=ds.variables[variable][tid].values elif ds.variables[variable].ndim==3: if layer=='surface': - mdata=ds.variables[variable][tid,:,-1].values + if tid==-1: + mdata=ds.variables[variable][:,:,-1].values.max(axis=0) + else: + mdata=ds.variables[variable][tid,:,-1].values elif layer=='bottom': if 'node_bottom_index' in [*ds.variables]: zid=ds.variables['node_bottom_index'][:].values.astype('int')