diff --git a/pinnicle/utils/plotting.py b/pinnicle/utils/plotting.py index 11d8d78..adaaac4 100644 --- a/pinnicle/utils/plotting.py +++ b/pinnicle/utils/plotting.py @@ -19,7 +19,7 @@ def cmap_Rignot(): cmap = ListedColormap(cmap) return cmap -def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, **kwargs): +def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, absvariable=[], **kwargs): """ plot model predictions Args: @@ -28,6 +28,7 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio u_ref (dict): Reference solutions, if None, then just plot the predicted solutions cols (int): Number of columns of subplot resolution (int): Number of grid points per row/column for plotting + absvariable (list): Names of variables in the predictions that will need to take abs() before comparison """ # generate Cartisian grid of X, Y # currently only work on 2D @@ -44,6 +45,9 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio sol_pred = pinn.model.predict(X_nn) plot_data = {k+"_pred":np.reshape(sol_pred[:,i:i+1], X.shape) for i,k in enumerate(pinn.params.nn.output_variables)} vranges = {k+"_pred":[pinn.params.nn.output_lb[i], pinn.params.nn.output_ub[i]] for i,k in enumerate(pinn.params.nn.output_variables)} + # take abs + for k in absvariable: + plot_data[k+"_pred"] = np.abs( plot_data[k+"_pred"]) # if ref solution is provided if (sol_ref is not None) and (X_ref is not None): diff --git a/tests/test_pinn.py b/tests/test_pinn.py index 7e9a35c..6c95fd1 100644 --- a/tests/test_pinn.py +++ b/tests/test_pinn.py @@ -210,7 +210,7 @@ def test_plot(tmp_path): experiment.model_data.data["ISSM"].X_dict['y'].flatten()[:,None])) assert experiment.plot_predictions(X_ref=X_ref, sol_ref=experiment.model_data.data["ISSM"].data_dict, - resolution=10) is None + resolution=10, absvariable=['C']) is None X, Y, im_data, axs = plot_nn(experiment, experiment.model_data.data["ISSM"].data_dict, resolution=10); assert X.shape == (10,10) assert Y.shape == (10,10)