Skip to content

Commit

Permalink
use .__init_subclass__ to dynamically create EquationParameters
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Jan 25, 2024
1 parent 2533e03 commit 7eaf924
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 116 deletions.
67 changes: 51 additions & 16 deletions PINN_ICE/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,32 @@ def __init__(self, param_dict):
self.check_consisteny()

def __str__(self):
"""
display all attributes except 'param_dict'
""" display all attributes except 'param_dict'
"""
return "\t" + type(self).__name__ + ": \n" + \
("\n".join(["\t\t" + k + ":\t" + str(self.__dict__[k]) for k in self.__dict__ if k != "param_dict"]))+"\n"

@abstractmethod
def set_default(self):
"""
set default values
""" set default values
"""
pass

@abstractmethod
def check_consisteny(self):
"""
check consistency of the parameter data
""" check consistency of the parameter data
"""
pass

def _add_parameters(self, pdict: dict):
"""
add all the keys from pdict to the class, with their values
""" add all the keys from pdict to the class, with their values
"""
if isinstance(pdict, dict):
for key, value in pdict.items():
setattr(self, key, value)

def set_parameters(self, pdict: dict):
"""
find all the keys from pdict which are avalible in the class, update the values
""" find all the keys from pdict which are avalible in the class, update the values
"""
if isinstance(pdict, dict):
for key, value in pdict.items():
Expand All @@ -57,8 +52,7 @@ def set_parameters(self, pdict: dict):
setattr(self, key, value)

def has_keys(self, keys):
"""
if all the keys are in the class, return true, otherwise return false
""" if all the keys are in the class, return true, otherwise return false
"""
if isinstance(keys, dict) or isinstance(keys, list):
return all([hasattr(self, k) for k in keys])
Expand Down Expand Up @@ -173,9 +167,9 @@ def check_consisteny(self):
pass

def setup_equations(self):
""" translate the input dict to EquationParameter(), and save back to the values in self.equations
""" translate the input dict to subclass of EquationParameter(), and save back to the values in self.equations
"""
self.equations = {k:EquationParameter(self.equations[k]) for k in self.equations}
self.equations = {k:EquationParameter.create(k, param_dict = self.equations[k]) for k in self.equations}

def __str__(self):
"""
Expand All @@ -187,16 +181,34 @@ def __str__(self):
class EquationParameter(ParameterBase):
""" parameter of equations
"""
subclasses = {}
def __init__(self, param_dict={}):
super().__init__(param_dict)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.subclasses[cls._EQUATION_TYPE] = cls

@classmethod
def create(cls, equation_type, **kwargs):
if equation_type not in cls.subclasses:
raise ValueError(f"Equation type {format(message_type)} is not defined")
return cls.subclasses[equation_type](**kwargs)

def set_default(self):
# list of input names
self.input = []
# list of output names
self.output = []
# lower and upper bound of output
self.output_lb = []
self.output_ub = []
# weights of each output
self.data_weights = []
self.pde_weights = None
# names of residuals
self.residuals = []
# pde weights
self.pde_weights = []
# scalar variables: name:value
self.scalar_variables = {}

Expand All @@ -207,7 +219,30 @@ def check_consisteny(self):
raise ValueError("Size of 'output' does not match the size of 'output_ub'")
if any([l>=u for l,u in zip(self.output_lb, self.output_ub)]):
raise ValueError("output_lb is not smaller than output_ub")
pass
if (len(self.output)) != (len(self.data_weights)):
raise ValueError("Size of 'output' does not match the size of 'data_weights'")

# check the pde weights
if isinstance(self.pde_weights, list):
if len(self.pde_weights) != len(self.residuals):
raise ValueError("Length of pde_weights does not match the length of residuals")
else:
raise ValueError("pde_weights is not a list")

def set_parameters(self, pdict: dict):
""" overwrite the default function, so that for 'scalar_parameters', only update the dict
"""
if isinstance(pdict, dict):
for key, value in pdict.items():
# only update attribute the key
if hasattr(self, key):
# only update the dictionary, not overwirte
if isinstance(value, dict):
old_dict = getattr(self, key)
old_dict.update(value)
setattr(self, key, old_dict)
else:
setattr(self, key, value)

def __str__(self):
"""
Expand Down
1 change: 1 addition & 0 deletions PINN_ICE/physics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .constants import Constants
from .equationbase import *
from .physics import *
from .stressbalance import *
10 changes: 10 additions & 0 deletions PINN_ICE/physics/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Constants():
""" Base class of all physical constants, in [SI]
"""
def __init__(self, **kwargs):
# Physical constants in [SI]
self.rhoi = 917.0 # ice density (kg/m^3)
self.rhow = 1023.0 # sea water density (kg/m^3)
self.g = 9.81 # gravitational force (m/s^2)
self.yts = 3600.0*24*365 # year to second (s)

78 changes: 34 additions & 44 deletions PINN_ICE/physics/equationbase.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,22 @@
from abc import ABC, abstractmethod
from ..parameter import EquationParameter
from . import Constants

class EquationBase(ABC):
class EquationBase(ABC, Constants):
""" base class of all the equations
"""
def __init__(self, parameters=EquationParameter()):
self.parameters = parameters
# Physical constants in [SI]
self.rhoi = 917.0 # ice density (kg/m^3)
self.rhow = 1023.0 # sea water density (kg/m^3)
self.g = 9.81 # gravitational force (m/s^2)
self.yts = 3600.0*24*365 # year to second (s)

# Dict of dependent and independent variables of the model, the values are
# the global component id in the Physics, these two dicts are maps from local
# to global
self.local_input_var = {} # x, y, z, t, etc.
self.local_output_var = {} # u, v, s, H, etc.
# load constants first
Constants.__init__(self)

# default lower and upper bounds of the output in [SI] unit
self.output_lb = {}
self.output_ub = {}

# default weights to scale the data misfit
self.data_weights = {}
# get the setting parameters
self.parameters = parameters

# residual name list
self.residuals = []
# update parameters in the equation accordingly
self.update_parameters(self.parameters)

# default pde weights
self.pde_weights = []
# update scalar variables
self.update_scalars(self.parameters.scalar_variables)

def get_input_list(self):
""" get the List of names of input variables
Expand All @@ -55,30 +42,33 @@ def update_id(self, global_input_var=None, global_output_var=None):
if global_output_var is not None:
self.local_output_var = {o:global_output_var.index(o) for o in self.local_output_var}

def update_parameters(self):
def update_parameters(self, parameters):
""" update attributes of the class using EquationParameter
"""
# input
if len(self.parameters.input) > 0:
self.local_input_var = {k:i for i,k in enumerate(self.parameters.input)}
# output
if len(self.parameters.output) > 0:
self.local_output_var = {k:i for i,k in enumerate(self.parameters.output)}
# lower bound
if len(self.parameters.output_lb) > 0:
self.output_lb = {k:self.parameters.output_lb[i] for i,k in enumerate(self.parameters.output)}
# upper bound
if len(self.parameters.output_ub) > 0:
self.output_ub = {k:self.parameters.output_ub[i] for i,k in enumerate(self.parameters.output)}
# data weight
if len(self.parameters.data_weights) > 0:
self.data_weights = {k:self.parameters.data_weights[i] for i,k in enumerate(self.parameters.output)}
# Dict of dependent and independent variables of the model, the values are
# the global component id in the Physics, these two dicts are maps from local
# to global, current indices are temporary, they will be updated after all equations are set
self.local_input_var = {k:i for i,k in enumerate(parameters.input)} # x, y, z, t, etc.
self.local_output_var = {k:i for i,k in enumerate(parameters.output)} # u, v, s, H, etc.

# lower and upper bounds of the output in [SI] unit, with keys of the variable name
self.output_lb = {k: parameters.output_lb[i] for i,k in enumerate(parameters.output)}
self.output_ub = {k: parameters.output_ub[i] for i,k in enumerate(parameters.output)}

# weights to scale the data misfit to 1 in [SI]
self.data_weights = {k: parameters.data_weights[i] for i,k in enumerate(parameters.output)}

# residuals name list
self.residuals = parameters.residuals
# pde weights
if isinstance(self.parameters.pde_weights, list):
if len(self.parameters.pde_weights) == 1:
self.pde_weights = [self.parameters.pde_weights[0] for r in self.residuals]
else:
self.pde_weights = self.parameters.pde_weights
self.pde_weights = parameters.pde_weights

def update_scalars(self, scalar_variables: dict):
""" update scalars in the equations
"""
if isinstance(scalar_variables, dict):
for key, value in scalar_variables.items():
setattr(self, key, value)

@abstractmethod
def pde(self, nn_input_var, nn_output_var):
Expand Down
2 changes: 1 addition & 1 deletion PINN_ICE/physics/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _add_equations(self, eq):
"""
equation = None
if eq == "SSA":
equation = stressbalance.SSA2DUniformB
equation = stressbalance.SSA
elif eq == "MOLHO":
equation = stressbalance.MOLHO
# TODO: add mass conservation
Expand Down
99 changes: 50 additions & 49 deletions PINN_ICE/physics/stressbalance.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
import deepxde as dde
from . import EquationBase
from . import EquationBase, Constants
from ..parameter import EquationParameter


class SSA2DUniformB(EquationBase): #{{{
""" SSA on 2D problem with uniform B
class SSAEquationParameter(EquationParameter, Constants):
""" default parameters for SSA
"""
def __init__(self, parameters=EquationParameter()):
super().__init__(parameters)
# viscosity
self.B = 1.26802073401e+08 # -8 degree C, cuffey
self.n = 3.0

# Dict of input and output used in this model, and their component id
# Note the ids will be reassigned after adding all physics together
self.local_input_var = {"x":0, "y":1}
self.local_output_var = {"u":0, "v":1, "s":2, "H":3, "C":4}

# default lower and upper bounds of the output in [SI] unit
self.output_lb = {"u":-1.0e4/self.yts, "v":-1.0e4/self.yts, "s":-1.0e3, "H":10.0, "C":0.01}
self.output_ub = {"u":1.0e4/self.yts, "v":1.0e4/self.yts, "s":2.5e3, "H":2000.0, "C":1.0e4}

# default scaling: data to 1 in [SI]
self.data_weights = {"u":1.0e-8*self.yts**2.0, "v":1.0e-8*self.yts**2.0, "s":1.0e-6, "H":1.0e-6, "C":1.0e-8}

# names of the residuals
_EQUATION_TYPE = 'SSA'
def __init__(self, param_dict={}):
# load necessary constants
Constants.__init__(self)
super().__init__(param_dict)

def set_default(self):
self.input = ['x', 'y']
self.output = ['u', 'v', 's', 'H', 'C']
self.output_lb = [-1.0e4/self.yts, -1.0e4/self.yts, -1.0e3, 10.0, 0.01]
self.output_ub = [ 1.0e4/self.yts, 1.0e4/self.yts, 2.5e3, 2000.0, 1.0e4]
self.data_weights = [1.0e-8*self.yts**2.0, 1.0e-8*self.yts**2.0, 1.0e-6, 1.0e-6, 1.0e-8]
self.residuals = ["fSSA1", "fSSA2"]

# default weights
self.pde_weights = [1.0e-10, 1.0e-10]

# update from the input parameters
self.update_parameters()
# scalar variables: name:value
self.scalar_variables = {
'n': 3.0, # exponent of Glen's flow law
'B':1.26802073401e+08 # -8 degree C, cuffey
}

class SSA(EquationBase): #{{{
""" SSA on 2D problem with uniform B
"""
def __init__(self, parameters=SSAEquationParameter()):
super().__init__(parameters)

def pde(self, nn_input_var, nn_output_var):
""" residual of SSA 2D PDEs
Expand Down Expand Up @@ -84,38 +83,40 @@ def pde(self, nn_input_var, nn_output_var):

return [f1, f2] #}}}

class MOLHOEquationParameter(EquationParameter, Constants):
""" default parameters for MOLHO
"""
_EQUATION_TYPE = 'MOLHO'
def __init__(self, param_dict={}):
# load necessary constants
Constants.__init__(self)
super().__init__(param_dict)

def set_default(self):
self.input = ['x', 'y']
self.output = ['u', 'v', 'u_base', 'v_base', 's', 'H', 'C']
self.output_lb = [-1.0e4/self.yts, -1.0e4/self.yts, -1.0e4/self.yts, -1.0e4/self.yts, -1.0e3, 10.0, 0.01]
self.output_ub = [ 1.0e4/self.yts, 1.0e4/self.yts, 1.0e4/self.yts, 1.0e4/self.yts, 2.5e3, 2000.0, 1.0e4]
self.data_weights = [1.0e-8*self.yts**2.0, 1.0e-8*self.yts**2.0, 1.0e-8*self.yts**2.0, 1.0e-8*self.yts**2.0, 1.0e-6, 1.0e-6, 1.0e-8]
self.residuals = ["fMOLHO 1", "fMOLHO 2", "fMOLHO base 1", "fMOLHO base 2"]
self.pde_weights = [1.0e-10, 1.0e-10, 1.0e-10, 1.0e-10]

# scalar variables: name:value
self.scalar_variables = {
'n': 3.0, # exponent of Glen's flow law
'B':1.26802073401e+08 # -8 degree C, cuffey
}

class MOLHO(EquationBase): #{{{
""" MOLHO on 2D problem with uniform B
"""
def __init__(self, parameters=EquationParameter()):
super().__init__(parameters)
# viscosity
self.B = 1.26802073401e+08 # -8 degree C, cuffey
self.n = 3.0

# Dict of input and output used in this model, and their component id
self.local_input_var = {"x":0, "y":1}
self.local_output_var = {"u":0, "v":1, "u_base":2, "v_base":3, "s":4, "H":5, "C":6}

# default lower and upper bounds of the output in [SI] unit
self.output_lb = {"u":-1.0e4/self.yts, "v":-1.0e4/self.yts, "u_base":-1.0e4/self.yts, "v_base":-1.0e4/self.yts, "s":-1.0e3, "H":10.0, "C":0.01}
self.output_ub = {"u":1.0e4/self.yts, "v":1.0e4/self.yts, "u_base":1.0e4/self.yts, "v_base":1.0e4/self.yts, "s":2.5e3, "H":2000.0, "C":1.0e4}

# default scaling: data to 1 in [SI]
self.data_weights = {"u":1.0e-8*self.yts**2.0, "v":1.0e-8*self.yts**2.0, "u_base":1.0e-8*self.yts**2.0, "v_base":1.0e-8*self.yts**2.0, "s":1.0e-6, "H":1.0e-6, "C":1.0e-8}

# residual names
self.residuals = ["fMOLHO 1", "fMOLHO 2", "fMOLHO base 1", "fMOLHO base 2"]

self.pde_weights = [1.0e-10, 1.0e-10, 1.0e-10, 1.0e-10]

# gauss points for integration
self.constants = {"gauss_x":[0.5, 0.23076534494715845, 0.7692346550528415, 0.04691007703066802, 0.9530899229693319],
"gauss_weights":[0.5688888888888889,0.4786286704993665,0.4786286704993665,0.2369268850561891,0.2369268850561891]}

# update from the input parameters
self.update_parameters()

def pde(self, nn_input_var, nn_output_var):
""" residual of MOLHO 2D PDEs
Args:
Expand Down
Loading

0 comments on commit 7eaf924

Please sign in to comment.