Skip to content

Commit

Permalink
Merge pull request #31 from ISSMteam/add_abs_to_plot_sol
Browse files Browse the repository at this point in the history
add absvariable as an option
  • Loading branch information
Cheng Gong authored May 24, 2024
2 parents 2bba080 + 0b38cc8 commit 73e8120
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion pinnicle/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def cmap_Rignot():
cmap = ListedColormap(cmap)
return cmap

def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, **kwargs):
def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, absvariable=[], **kwargs):
""" plot model predictions
Args:
Expand All @@ -28,6 +28,7 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio
u_ref (dict): Reference solutions, if None, then just plot the predicted solutions
cols (int): Number of columns of subplot
resolution (int): Number of grid points per row/column for plotting
absvariable (list): Names of variables in the predictions that will need to take abs() before comparison
"""
# generate Cartisian grid of X, Y
# currently only work on 2D
Expand All @@ -44,6 +45,9 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio
sol_pred = pinn.model.predict(X_nn)
plot_data = {k+"_pred":np.reshape(sol_pred[:,i:i+1], X.shape) for i,k in enumerate(pinn.params.nn.output_variables)}
vranges = {k+"_pred":[pinn.params.nn.output_lb[i], pinn.params.nn.output_ub[i]] for i,k in enumerate(pinn.params.nn.output_variables)}
# take abs
for k in absvariable:
plot_data[k+"_pred"] = np.abs( plot_data[k+"_pred"])

# if ref solution is provided
if (sol_ref is not None) and (X_ref is not None):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_plot(tmp_path):
experiment.model_data.data["ISSM"].X_dict['y'].flatten()[:,None]))
assert experiment.plot_predictions(X_ref=X_ref,
sol_ref=experiment.model_data.data["ISSM"].data_dict,
resolution=10) is None
resolution=10, absvariable=['C']) is None
X, Y, im_data, axs = plot_nn(experiment, experiment.model_data.data["ISSM"].data_dict, resolution=10);
assert X.shape == (10,10)
assert Y.shape == (10,10)
Expand Down

0 comments on commit 73e8120

Please sign in to comment.