Skip to content

Commit

Permalink
add options to change figure names when saving with plot_solutions
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Jul 2, 2024
1 parent d8083d1 commit db1fb27
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions pinnicle/pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ def plot_history(self, path=""):
path = self.check_path(path)
self.history.plot(path)

def plot_predictions(self, path="", **kwargs):
def plot_predictions(self, path="", filename="2Dsolution.png", **kwargs):
""" plot model predictions
Args:
path (Path, str): Path to save the figures
filename (str): name to save the figures, if set to None, then the figure will not be saved
X_ref (dict): Coordinates of the reference solutions, if None, then just plot the predicted solutions
u_ref (dict): Reference solutions, if None, then just plot the predicted solutions
cols (int): Number of columns of subplot
"""
path = self.check_path(path)
plot_solutions(self, path=path, **kwargs)
plot_solutions(self, path=path, filename=filename, **kwargs)

def save_history(self, path=""):
""" save training history
Expand Down
6 changes: 4 additions & 2 deletions pinnicle/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ def cmap_Rignot():
cmap = ListedColormap(cmap)
return cmap

def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, absvariable=[], **kwargs):
def plot_solutions(pinn, path="", filename="2Dsolution.png", X_ref=None, sol_ref=None, cols=None, resolution=200, absvariable=[], **kwargs):
""" plot model predictions
Args:
path (Path, str): Path to save the figures
filename (str): name to save the figures, if set to None, then the figure will not be saved
X_ref (dict): Coordinates of the reference solutions, if None, then just plot the predicted solutions
u_ref (dict): Reference solutions, if None, then just plot the predicted solutions
cols (int): Number of columns of subplot
Expand Down Expand Up @@ -90,7 +91,8 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio

fig.colorbar(im, ax=ax, shrink=0.8)

plt.savefig(path+"2Dsolution.png")
if filename:
plt.savefig(path+filename)

else:
raise ValueError("Plot is only implemented for 2D problem")
Expand Down

0 comments on commit db1fb27

Please sign in to comment.