Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support more schism variables #37

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support more schism variables
Zhengui Wang committed Mar 3, 2022
commit 974e875760e9ff85f6661456b6bbf997c9cf291f
1 change: 0 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
@@ -39,7 +39,6 @@


ui = thalassa.ThalassaUI(
display_variables=True,
display_stations=True,
)

8 changes: 0 additions & 8 deletions thalassa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from __future__ import annotations

from .api import get_elevation_dmap
from .api import get_tiles
from .api import get_trimesh
from .api import get_wireframe
from .api import get_timeseries
from .ui import ThalassaUI
from .utils import open_dataset
from .utils import reload

__all__: list[str] = [
"open_dataset",
"reload",
"get_trimesh",
"get_tiles",
"get_wireframe",
"get_elevation_dmap",
"ThalassaUI",
]
140 changes: 77 additions & 63 deletions thalassa/api.py
Original file line number Diff line number Diff line change
@@ -10,64 +10,73 @@
from holoviews.operation.datashader import rasterize
from holoviews.streams import PointerXY,DoubleTap
import numpy as np

logger = logging.getLogger(__name__)

# Load bokeh backend
from . import utils
hv.extension("bokeh")


def get_trimesh(
dataset: xr.Dataset,
longitude_var: str,
latitude_var: str,
elevation_var: str,
simplices_var: str,
time_var: str,
timestamp: str | pd.Timestamp,
) -> gv.TriMesh:
simplices = dataset[simplices_var].values
columns = [longitude_var, latitude_var, elevation_var]
if timestamp == "MAXIMUM":
points_df = dataset.max(time_var)[columns].to_dataframe()
elif timestamp == "MINIMUM":
points_df = dataset.min(time_var)[columns].to_dataframe()
else:
points_df = dataset.sel({time_var: timestamp})[columns].to_dataframe().drop(columns=time_var)
points_df = points_df.reset_index(drop=True)
points_gv = gv.Points(points_df, kdims=[longitude_var, latitude_var], vdims=elevation_var)
trimesh = gv.TriMesh((simplices, points_gv))
return trimesh

logger = logging.getLogger(__name__)

def get_tiles() -> gv.Tiles:
tiles = gv.WMTS("http://c.tile.openstreetmap.org/{Z}/{X}/{Y}.png")
return tiles


def get_wireframe(trimesh: gv.TriMesh) -> hv.Layout:
wireframe = dynspread(rasterize(trimesh.edgepaths, precompute=True))
return wireframe


def get_elevation_dmap(trimesh: gv.TriMesh, show_grid: bool = False) -> hv.Overlay:
tiles = get_tiles()
elevation = rasterize(trimesh, precompute=True).opts( # pylint: disable=no-member
title="Elevation Forecast",
colorbar=True,
clabel="meters",
show_legend=True,
)
logger.debug("show grid: %s", show_grid)
if show_grid:
overlay = tiles * elevation * get_wireframe(trimesh=trimesh)
else:
overlay = tiles * elevation
return overlay

#----------------------------------------------------------------------------------------
#time series
#----------------------------------------------------------------------------------------
class MapData:
'''
define a class to store data related to dynamic map
'''
def __init__(self):
#dataset info
self.name = None
self.format = None
self.prj = None
#header info
self.dataset = None #file handle -> xr.Dataset
self.times = None
self.variables = None
#connectivity
self.x = None
self.y = None
self.elnode = None
#dataset snapshot
self.time = None
self.variable = None
self.data = None
self.grid = None
self.trimesh = None
self.trimap = None
self.tiles = get_tiles()

def get_data(self,time,variable,layer):
'''
extract a snapshot from dataset
'''
self.time = time
self.variable = variable
self.layer = layer
tid=int(np.nonzero(np.array(self.times)==time)[0][0])
self.data=utils.read_dataset(self.dataset,2,self.format,time=tid,variable=variable,layer=layer)

def get_plot_map(self):
'''
plot a snapshot: only SCHISM method is defined so far
'''
if self.format=="SCHISM":
if self.x.min()<-360 or self.x.max()>360 or self.y.min()<-90 or self.y.max()>90:
raise ValueError(f"check dataset projection: abs(lat)>360 or abs(lon)>90")
df=pd.DataFrame({'longitude':self.x, 'latitude':self.y, 'data':self.data})
pdf=gv.Points(df,kdims=['longitude','latitude'],vdims='data')
self.trimesh=gv.TriMesh((self.elnode,pdf))
if self.grid is None:
self.grid=dynspread(rasterize(self.trimesh.edgepaths, precompute=True))
self.trimap=rasterize(self.trimesh, precompute=True).opts(
title=f"SCHISM Forecast: {self.variable}",
colorbar=True,
clabel="meters",
cmap="jet",
Copy link
Collaborator

@pmav99 pmav99 Apr 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using this colormap throws an exception:

# ...
  File "/home/panos/.conda/envs/thalassa/lib/python3.8/site-packages/holoviews/plotting/util.py", line 912, in process_cmap
    raise ValueError("Supplied cmap %s not found among %s colormaps." %
ValueError: Supplied cmap jet not found among matplotlib, bokeh, or colorcet colormaps.

We should either add matplotlib to the dependencies or use a different colormap. I wouldn't add an extra dependency just for a colormap.

Copy link
Collaborator Author

@wzhengui wzhengui Apr 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @pmav99 : Thank you for reviewing the PR. I think these are good suggestions. I will try to address them.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @pmav99 and @brey : Thank Panos' useful suggestions again. I am sorry that I didn't have a chance talking to George this morning, as I was working on Thalassa. I just finished the revision except point 4. I tried to dynamically update the interface according to different variable picked by users, but it turns out to be very tricky. I will keep this on my mind, and see whether this is a good solution.

show_legend=True,
)
else:
raise ValueError(f"please define plot method for dataset format: {dataset_format}")

class TimeseriesData:
'''
define a class to store data related to time series points
@@ -77,32 +86,32 @@ def __init__(self):
def clear(self):
self.init=False

def extract_timeseries(x,y,sx,sy,data):
def extract_timeseries(x,y,sx,sy,dataset,variable):
'''
function for extracting time series@(x,y) from data
'''
dist=abs(sx+1j*sy-x-1j*y)
mdist=dist.min()
nid=np.nonzero(dist==mdist)[0][0]
mdata=data['elev'].data[:,nid].copy()
mdata=dataset[variable].data[:,nid].copy()
return mdist,mdata

def add_remove_pts(x,y,data,dataset,fmt):
def add_remove_pts(x,y,data,dataset,fmt,variable):
'''
function to dynamically add or remove pts by double clicking on the map
'''
if fmt=='add pts':
if len(data.xys)==0:
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset)
hcurve=hv.Curve((data.time,mdata),'time','elevation').opts(tools=["hover"])
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable)
hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"])
if mdist<=data.mdist:
data.xys.append((x,y))
data.elev.append(mdata)
data.curve.append(hcurve)
else:
if data.xys[-1][0]!=x and data.xys[-1][1]!=y:
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset)
hcurve=hv.Curve((data.time,mdata),'time','elevation').opts(tools=["hover"])
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable)
hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"])
if mdist<=data.mdist:
data.xys.append((x,y))
data.elev.append(mdata)
@@ -120,15 +129,20 @@ def add_remove_pts(x,y,data,dataset,fmt):
else:
pass

def get_timeseries(source,data,dataset,ymin,ymax,fmt):
def get_timeseries(MData,data,ymin,ymax,fmt):
'''
get time series plots
'''

source, dataset = MData.trimesh, MData.dataset
variable='elev' #todo: add an input for time series variable

#initialize timeseries_data
if data.init is False:
#find the maximum side length
x,y=dataset['SCHISM_hgrid_node_x'].data,dataset['SCHISM_hgrid_node_y'].data
e1,e2,e3=dataset['SCHISM_hgrid_face_nodes'].data.T
x,y=MData.x,MData.y #tmp fix, improve: todo
e1,e2,e3=MData.elnode.T

s1=abs((x[e1]-x[e2])+1j*(y[e1]-y[e2])).max()
s2=abs((x[e2]-x[e3])+1j*(y[e2]-y[e3])).max()
s3=abs((x[e3]-x[e1])+1j*(y[e3]-y[e1])).max()
@@ -144,7 +158,7 @@ def get_timeseries(source,data,dataset,ymin,ymax,fmt):

def get_plot_point(x,y):
if None not in [x,y]:
add_remove_pts(x,y,data,dataset,fmt)
add_remove_pts(x,y,data,dataset,fmt,variable)

if ((x is None) or (y is None)) and len(data.xys)==0:
xys=[(data.x0,data.y0)]
@@ -159,7 +173,7 @@ def get_plot_point(x,y):
return hpoint*htext

def get_plot_curve(x,y):
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset)
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable)
if mdist>data.mdist:
mdata=mdata*np.nan
hdynamic=hv.Curve((data.time,mdata)).opts(color='k',line_width=2,line_dash='dotted')
Loading