Skip to content

Commit

Permalink
Merge pull request #14 from mansakrishna23/main
Browse files Browse the repository at this point in the history
continuity.py update + plotting function
  • Loading branch information
Cheng Gong authored Apr 10, 2024
2 parents a94fb8a + 00ce0be commit 30f2748
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 3 deletions.
2 changes: 1 addition & 1 deletion PINNICLE/physics/continuity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def set_default(self):
self.output = ['u', 'v', 'a', 'H']
self.output_lb = [-1.0e4/self.yts, -1.0e4/self.yts, -5.0/self.yts, 10.0]
self.output_ub = [ 1.0e4/self.yts, 1.0e4/self.yts, 5/self.yts, 2500.0]
self.data_weights = [1.0e-3*self.yts, 1.0e-3*self.yts, 1.0e4*self.yts, 1.0e-3]
self.data_weights = [1.0e-3*self.yts, 1.0e-3*self.yts, 1.0e4*self.yts, 1.0e-6]
self.residuals = ["fMC"]
self.pde_weights = [1.0e6]

Expand Down
2 changes: 1 addition & 1 deletion PINNICLE/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .helper import *
from .history import History
from .data_misfit import get
from .plotting import plot_solutions, plot_dict_data, plot_data, plot_nn
from .plotting import plot_solutions, plot_dict_data, plot_data, plot_nn, plot_similarity
74 changes: 74 additions & 0 deletions PINNICLE/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib as mpl
from matplotlib.colors import ListedColormap
from scipy.interpolate import griddata
Expand Down Expand Up @@ -214,3 +215,76 @@ def plot_data(X, Y, im_data, axs=None, vranges={}, **kwargs):
plt.colorbar(im, ax=axs[i], shrink=0.8)

return axs

def plot_similarity(pinn, feature_name, savepath, sim='MAE', cmap='jet', scale=1, cols=[0, 1, 2]):
"""
plotting the similarity between reference and predicted
solutions, mae default
"""
# initialize figure, default all 3 columns
fig, axs = plt.subplots(1, len(cols), figsize=(5*len(cols), 4))

# inputs and outputs of NN
input_names = pinn.nn.parameters.input_variables
output_names = pinn.nn.parameters.output_variables

# inputs
X_ref = pinn.model_data.data['ISSM'].X_dict
xref = X_ref[input_names[0]].flatten()[:,None]
for i in range(1, len(input_names)):
xref = np.hstack((xref, X_ref[input_names[i]].flatten()[:,None]))
meshx = np.squeeze(xref[:, 0])
meshy = np.squeeze(xref[:, 1])

# predictions
pred = pinn.model.predict(xref)

# reference solution
X_sol = pinn.model_data.data['ISSM'].data_dict
sol = X_sol[output_names[0]].flatten()[:,None] # initializing array
for i in range(1, len(output_names)):
sol = np.hstack((sol, X_sol[output_names[i]].flatten()[:,None]))

# grab feature
fid = output_names.index(feature_name)
ref_sol = np.squeeze(sol[:, fid:fid+1]*scale)
pred_sol = np.squeeze(pred[:, fid:fid+1]*scale)
[cmin, cmax] = [np.min(np.append(ref_sol, pred_sol)), np.max(np.append(ref_sol, pred_sol))]
levels = np.linspace(cmin*0.9, cmax*1.1, 500)

# plotting
# reference solution
c = 0 # column number initialize
if 0 in cols:
ax = axs[c].tricontourf(meshx, meshy, ref_sol, levels=levels, cmap=cmap)
cb = plt.colorbar(ax, ax=axs[c])
cb.ax.tick_params(labelsize=14)
axs[c].set_title(feature_name+r"$_{ref}$", fontsize=14)
axs[c].axis('off')
c += 1

# predicted solution
if 1 in cols:
ax = axs[c].tricontourf(meshx, meshy, pred_sol, levels=levels, cmap=cmap)
cb = plt.colorbar(ax, ax=axs[c])
cb.ax.tick_params(labelsize=14)
axs[c].set_title(feature_name+r"$_{pred}$", fontsize=14)
axs[c].axis('off')
c += 1

# difference / similarity
if 2 in cols:
if sim == 'MAE':
diff = np.abs(ref_sol-pred_sol)
diff_val = np.round(np.mean(diff), 2)
title = r"|"+feature_name+r"$_{ref} - $"+feature_name+r"$_{pred}$|, MAE="+str(diff_val)
dmin, dmax = np.min(diff), np.max(diff)
levels = np.linspace(dmin*0.9, dmax*1.1, 500)
ax = axs[c].tricontourf(meshx, meshy, np.squeeze(diff), levels=levels, cmap='RdBu', norm=colors.CenteredNorm())
cb = plt.colorbar(ax, ax=axs[c])
cb.ax.tick_params(labelsize=14)
axs[c].set_title(title, fontsize=14)
axs[c].axis('off')

# save figure to path as defined
plt.savefig(savepath)
11 changes: 10 additions & 1 deletion tests/test_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import PINNICLE as pinn
import numpy as np
import deepxde as dde
from PINNICLE.utils import data_misfit, plot_nn
from PINNICLE.utils import data_misfit, plot_nn, plot_similarity

dde.config.set_default_float('float64')
dde.config.disable_xla_jit()
Expand Down Expand Up @@ -164,3 +164,12 @@ def test_plot(tmp_path):
assert Y.shape == (10,10)
assert len(im_data) == 5
assert im_data['u'].shape == (10,10)

def test_similarity(tmp_path):
hp["save_path"] = str(tmp_path)
hp["is_save"] = True
issm["data_size"] = {"u":4000, "v":4000, "s":4000, "H":4000, "C":None}
hp["data"] = {"ISSM": issm}
experiment = pinn.PINN(params=hp)
experiment.compile()
assert plot_similarity(experiment, feature_name="u", savepath=tmp_path) is None

0 comments on commit 30f2748

Please sign in to comment.