
Commit

hopefully fixed the issue
kthyng committed Nov 27, 2023
1 parent ef2bd8e commit 9450c58
Showing 7 changed files with 124 additions and 68 deletions.
140 changes: 83 additions & 57 deletions ocean_model_skill_assessor/main.py
@@ -15,6 +15,7 @@
import extract_model as em
import extract_model.accessor
import intake
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
@@ -1353,11 +1354,15 @@ def _return_mask(
if mask is None:
if paths.MASK_PATH(key_variable_data).is_file():
if logger is not None:
logger.info("Using cached mask.")
logger.info(
f"Using cached mask from {paths.MASK_PATH(key_variable_data)}."
)
mask = xr.open_dataarray(paths.MASK_PATH(key_variable_data))
else:
if logger is not None:
logger.info("Finding and saving mask to cache.")
logger.info(
f"Finding and saving mask to cache to {paths.MASK_PATH(key_variable_data)}."
)
# # dam variable might not be in Dataset itself, but its coordinates probably are.
# mask = get_mask(dsm, dam.name)
mask = get_mask(dsm, lon_name, wetdry=wetdry)
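For reference, a minimal sketch of reusing the cached mask outside of run(), assuming paths and key_variable_data are set up as in run(), that paths.MASK_PATH(...) returns a pathlib.Path, and that the mask was previously saved there (the save step itself is not shown in this hunk):

import xarray as xr

mask_path = paths.MASK_PATH(key_variable_data)  # assumed to return a pathlib.Path
if mask_path.is_file():
    mask = xr.open_dataarray(mask_path)  # reuse the cached mask
else:
    mask = None  # would need to be recomputed with get_mask() and written to mask_path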
@@ -1854,6 +1859,8 @@ def run(
source_names = list(cat)
for i, source_name in enumerate(source_names[:ndatasets]):

skip_dataset = False

if ndatasets is None:
msg = (
f"\nsource name: {source_name} ({i+1} of {ndata} for catalog {cat}."
@@ -1881,6 +1888,7 @@
logger.info(
f"no `key_variables` key found in source metadata or at least not {key_variable}"
)
skip_dataset = True
continue

min_lon = cat[source_name].metadata["minLongitude"]
@@ -1952,59 +1960,21 @@
msg = f"Data cannot be loaded for dataset {source_name}. Skipping dataset.\n"
logger.warning(msg)
maps.pop(-1)
skip_dataset = True
continue

except Exception as e:
logger.warning(str(e))
msg = f"Data cannot be loaded for dataset {source_name}. Skipping dataset.\n"
logger.warning(msg)
maps.pop(-1)
skip_dataset = True
continue

# Need to have this here because if model file has previously been read in but
# aligned file doesn't exist yet, this needs to run to update the sign of the
# data depths in certain cases.
zkeym = dsm.cf.axes["Z"][0]
dfd, Z, vertical_interp = _choose_depths(
dfd,
dsm[zkeym].attrs["positive"],
no_Z,
want_vertical_interp,
logger,
)

# take out relevant variable and identify mask if available (otherwise None)
# this mask has to match dam for em.select()
if not skip_mask:
mask = _return_mask(
mask,
dsm,
dsm.cf.coordinates["longitude"][
0
], # using the first longitude key is adequate
wetdry,
key_variable_data,
paths,
logger,
)

# I think these should always be true together
if skip_mask:
assert mask is None

# Calculate boundary of model domain to compare with data locations and for map
# don't need p1 if check_in_boundary False and plot_map False
if (check_in_boundary or plot_map) and p1 is None:
p1 = _return_p1(paths, dsm, mask, alpha, dd, logger)

# see if data location is inside alphashape-calculated polygon of model domain
if check_in_boundary:
if _is_outside_boundary(p1, min_lon, min_lat, source_name, logger):
maps.pop(-1)
continue

# check for already-aligned model-data file
fname_processed_orig = f"{cat.name}_{source_name}_{key_variable_data}"
fname_processed_orig = (
f"{cat.name}_{source_name.replace('.','_')}_{key_variable_data}"
)
(
fname_processed,
fname_processed_data,
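The second assignment above swaps periods in the source name for underscores so that a dotted source name does not read like a file extension in the processed-file names. A toy example of the effect (the names are made up):

source_name = "noaa.8771013"  # hypothetical dotted source name
fname_processed_orig = f"example_cat_{source_name.replace('.','_')}_ssh"
# -> "example_cat_noaa_8771013_ssh"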
@@ -2057,6 +2027,38 @@
source_name,
)

# take out relevant variable and identify mask if available (otherwise None)
# this mask has to match dam for em.select()
if not skip_mask:
mask = _return_mask(
mask,
dsm,
dsm.cf.coordinates["longitude"][
0
], # using the first longitude key is adequate
wetdry,
key_variable_data,
paths,
logger,
)

# I think these should always be true together
if skip_mask:
assert mask is None

# Calculate boundary of model domain to compare with data locations and for map
# don't need p1 if check_in_boundary False and plot_map False
if (check_in_boundary or plot_map) and p1 is None:
p1 = _return_p1(paths, dsm, mask, alpha, dd, logger)

# see if data location is inside alphashape-calculated polygon of model domain
if check_in_boundary:
if _is_outside_boundary(
p1, min_lon, min_lat, source_name, logger
):
maps.pop(-1)
continue

# Check, prep, and possibly narrow data time range
dfd, maps = _check_prep_narrow_data(
dfd,
@@ -2073,6 +2075,7 @@
# if there were any issues in the last function, dfd should be None and we should
# skip this dataset
if dfd is None:
skip_dataset = True
continue

# Read in model output from cache if possible.
@@ -2154,6 +2157,18 @@
for date in np.unique(dfd.cf["T"].values)
]

# Need to have this here because if model file has previously been read in but
# aligned file doesn't exist yet, this needs to run to update the sign of the
# data depths in certain cases.
zkeym = dsm.cf.axes["Z"][0]
dfd, Z, vertical_interp = _choose_depths(
dfd,
dsm[zkeym].attrs["positive"],
no_Z,
want_vertical_interp,
logger,
)

select_kwargs = dict(
dam=dam,
longitude=lons,
@@ -2209,7 +2224,6 @@
ts_mods_copy = deepcopy(ts_mods)
# ts_mods_copy = ts_mods.copy() # otherwise you modify ts_mods when adding data
for mod in ts_mods_copy:
# import pdb; pdb.set_trace()
logger.info(
f"Apply a time series modification called {mod['function']}."
)
@@ -2241,6 +2255,13 @@

model_var = mod["function"](model_var, **mod["inputs"])

# check model output for nans
ind_keep = np.arange(0, model_var.cf["T"].size)[
model_var.cf["T"].notnull()
]
if model_var.cf["T"].name in model_var.dims:
model_var = model_var.isel({model_var.cf["T"].name: ind_keep})

# there could be a small mismatch in the length of time if times were pulled
# out separately
if np.unique(model_var.cf["T"]).size != np.unique(dfd.cf["T"]).size:
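The new NaN check above drops model time steps whose time coordinate is NaT before model and data time lengths are compared. The same idea in plain xarray, without the cf-xarray accessor (toy data, not from the package):

import numpy as np
import pandas as pd
import xarray as xr

times = pd.to_datetime(["2023-01-01", "NaT", "2023-01-03"])
da = xr.DataArray([1.0, 2.0, 3.0], dims="time", coords={"time": times})
keep = np.arange(0, da["time"].size)[da["time"].notnull()]
da = da.isel(time=keep)  # the step whose time is NaT is dropped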
@@ -2254,7 +2275,8 @@
etime = pd.Timestamp(
min(dfd.cf["T"].values[-1], model_var.cf["T"].values[-1])
)
model_var = model_var.cf.sel({"T": slice(stime, etime)})
if stime != etime:
model_var = model_var.cf.sel({"T": slice(stime, etime)})

if isinstance(dfd, pd.DataFrame):
dfd = dfd.set_index(dfd.cf["T"].name)
@@ -2270,15 +2292,17 @@
# otherwise only nan's come through
# accounting for known issue for interpolation after sampling if indices changes
# https://github.com/pandas-dev/pandas/issues/14297
model_index = model_var.cf["T"].to_pandas().index
model_index.name = dfd.index.name
ind = model_index.union(dfd.index)
dfd = (
dfd.reindex(ind)
.interpolate(method="time", limit=3)
.reindex(model_index)
)
dfd = dfd.reset_index()
# this won't run for single ctd profiles
if len(dfd.cf["T"].unique()) > 1:
model_index = model_var.cf["T"].to_pandas().index
model_index.name = dfd.index.name
ind = model_index.union(dfd.index)
dfd = (
dfd.reindex(ind)
.interpolate(method="time", limit=3)
.reindex(model_index)
)
dfd = dfd.reset_index()

elif isinstance(dfd, xr.Dataset):
# interpolate data to model times
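The reindex-interpolate-reindex pattern above works around pandas issue #14297: interpolating a time series directly onto new index values yields only NaNs, so the data are first reindexed onto the union of data and model times. A self-contained sketch with made-up times:

import pandas as pd

obs = pd.DataFrame(
    {"temp": [10.0, 11.0, 12.0, 13.0]},
    index=pd.date_range("2023-01-01 00:00", periods=4, freq="h", name="time"),
)
model_index = pd.date_range("2023-01-01 00:30", periods=3, freq="h", name="time")

ind = model_index.union(obs.index)        # all times from both sources
obs_on_model = (
    obs.reindex(ind)                      # insert NaN rows at the model times
    .interpolate(method="time", limit=3)  # time-weighted interpolation
    .reindex(model_index)                 # keep only the model times
)
# obs_on_model["temp"] -> [10.5, 11.5, 12.5]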
@@ -2395,7 +2419,7 @@ def run(
# title = f"{count}: {source_name}"
# else:
# title = f"{source_name}"
if not figname.is_file() or override_plot:
if not skip_dataset and (not figname.is_file() or override_plot):
fig = plot.selection(
obs,
model,
@@ -2434,3 +2458,5 @@
if len(maps) == 1 and return_fig:
# model output, processed data, processed model, stats, fig
return fig
# else:
# plt.close(fig)
8 changes: 4 additions & 4 deletions ocean_model_skill_assessor/plot/__init__.py
@@ -198,8 +198,8 @@ def selection(
# Assume want along-transect distance if number of unique locations is
# equal to or more than number of times
if (
np.unique(obs.cf["longitude"]).size >= np.unique(obs.cf["T"]).size
or np.unique(obs.cf["latitude"]).size >= np.unique(obs.cf["T"]).size
np.unique(obs.cf["longitude"]).size + 3 >= np.unique(obs.cf["T"]).size
or np.unique(obs.cf["latitude"]).size + 3 >= np.unique(obs.cf["T"]).size
):
assert isinstance(key_variable, str)
xname, yname, zname = "distance", "Z", key_variable
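The relaxed check adds a slack of 3: a dataset is treated as an along-transect (moving-platform) record when the number of distinct positions is within 3 of the number of distinct sample times. A standalone sketch of the heuristic with made-up observation arrays:

import numpy as np

lons = np.array([-150.0, -150.1, -150.2, -150.2])
times = np.array(
    ["2023-01-01T00", "2023-01-01T01", "2023-01-01T02", "2023-01-01T03"],
    dtype="datetime64[h]",
)
is_transect = np.unique(lons).size + 3 >= np.unique(times).size  # True here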
@@ -236,7 +236,7 @@
ylabel=ylabel,
zlabel=zlabel,
nsubplots=3,
figsize=(15, 6),
# figsize=(15, 6),
figname=figname,
along_transect_distance=along_transect_distance,
kind="scatter",
@@ -260,7 +260,7 @@
ylabel=ylabel,
zlabel=zlabel,
kind="pcolormesh",
figsize=(15, 6),
# figsize=(15, 6),
figname=figname,
return_plot=True,
**kwargs,
3 changes: 2 additions & 1 deletion ocean_model_skill_assessor/plot/line.py
@@ -30,6 +30,7 @@ def plot(
title: Optional[str] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
model_label: str = "Model",
figname: Union[str, pathlib.Path] = "figure.png",
dpi: int = 100,
figsize: tuple = (15, 5),
@@ -72,7 +73,7 @@
ax.plot(
np.array(model.cf[xname].squeeze()),
np.array(model.cf[yname].squeeze()),
label="model",
label=model_label,
lw=lw,
color=col_model,
)
13 changes: 11 additions & 2 deletions ocean_model_skill_assessor/plot/map.py
@@ -395,6 +395,7 @@
def plot_cat_on_map(
catalog: Union[Catalog, str],
paths: Paths,
source_names: Optional[list] = None,
figname: Optional[str] = None,
remove_duplicates=None,
**kwargs_map,
@@ -407,6 +408,8 @@ def plot_cat_on_map(
Which catalog of datasets to plot on map.
paths : Paths
Paths object for finding paths to use.
source_names : list, optional
Use these source names instead of list(cat), if provided.
remove_duplicates : bool
If True, take the set of the source in catalog based on the spatial locations so they are not repeated in the map.
remove_duplicates : function, optional
@@ -420,7 +423,13 @@
>>> omsa.plot.map.plot_cat_on_map(catalog=catalog_name, project_name=project_name)
"""

cat = open_catalogs(catalog, paths)[0]
if isinstance(catalog, Catalog):
cat = catalog
else:
cat = open_catalogs(catalog, paths)[0]

if source_names is None:
source_names = list(cat)

figname = figname or f"map_of_{cat.name}"

@@ -437,7 +446,7 @@
s,
cat[s].metadata["maptype"] or "",
]
for s in list(cat)
for s in source_names
if "minLongitude" in cat[s].metadata
]
)
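With these changes plot_cat_on_map accepts an already-opened intake Catalog and an optional subset of source names. A hedged usage sketch, with cat and paths assumed to already exist as an opened Catalog and a Paths object, and placeholder source names:

import ocean_model_skill_assessor as omsa

omsa.plot.map.plot_cat_on_map(
    catalog=cat,                              # an opened Catalog is now accepted directly
    paths=paths,                              # Paths object for the project
    source_names=["station_1", "station_2"],  # plot only these sources
)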
19 changes: 17 additions & 2 deletions ocean_model_skill_assessor/plot/surface.py
@@ -42,8 +42,10 @@ def plot(
nsubplots: int = 3,
figname: Union[str, pathlib.Path] = "figure.png",
dpi: int = 100,
figsize=(15, 4),
figsize=(15, 6),
return_plot: bool = False,
invert_yaxis: bool = False,
make_Z_negative=None,
**kwargs,
):
"""Plot scatter or surface plot.
@@ -86,6 +88,9 @@
If True, return plot. Use for testing.
"""

if "override_plot" in kwargs:
kwargs.pop("override_plot")

# want obs and data as DataFrames
if kind == "scatter":
if isinstance(obs, xr.Dataset):
@@ -150,6 +155,14 @@
else:
subplot_kw = {}

if make_Z_negative is not None:
if make_Z_negative == "obs":
if (obs[obs.cf["Z"].notnull()].cf["Z"] > 0).all():
obs[obs.cf["Z"].name] = -obs.cf["Z"]
elif make_Z_negative == "model":
if (model[model.cf["Z"].notnull()].cf["Z"] > 0).all():
model[model.cf["Z"].name] = -model.cf["Z"]

fig, axes = plt.subplots(
1,
nsubplots,
@@ -166,7 +179,7 @@
)
pandas_kwargs = dict(colorbar=False)

kwargs = {key: cmap_params.get(key) for key in ["vmin", "vmax", "cmap"]}
kwargs.update({key: cmap_params.get(key) for key in ["vmin", "vmax", "cmap"]})

if plot_on_map:
omsa.plot.map.setup_ax(
@@ -193,6 +206,8 @@
axes[0].set_ylabel(ylabel, fontsize=fs)
axes[0].set_xlabel(xlabel, fontsize=fs)
axes[0].tick_params(axis="both", labelsize=fs)
if invert_yaxis:
axes[0].invert_yaxis()

# plot model
if plot_on_map:
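The new make_Z_negative keyword flips depths to negative values (surface at zero, depth increasing downward) when every non-null depth of the chosen dataset is positive, and invert_yaxis flips the vertical axis of the first subplot. The sign check in isolation, with a toy DataFrame instead of the cf-xarray accessor:

import pandas as pd

obs = pd.DataFrame({"depth": [2.0, 5.0, 10.0], "temp": [12.0, 11.5, 11.0]})
# flip only if all non-null depths are positive (i.e. stored positive-down)
if (obs["depth"].dropna() > 0).all():
    obs["depth"] = -obs["depth"]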
3 changes: 3 additions & 0 deletions ocean_model_skill_assessor/utils.py
@@ -107,6 +107,9 @@
"""

if isinstance(dfd, pd.DataFrame):
# # make sure datetimes will be recognized when reread
# # actually seems to work without this
# dfd = dfd.rename(columns={dfd.cf["T"].name: "time"})
dfd.to_csv(fname_processed_data, index=False)
elif isinstance(dfd, xr.Dataset):
dfd.to_netcdf(fname_processed_data)