
Commit

Merge pull request #25 from ISSMteam/add_PFNN
add parallel FNN
Cheng Gong authored May 13, 2024
2 parents 4e695b2 + 7b783a7 commit 9f36caf
Showing 4 changed files with 54 additions and 4 deletions.
12 changes: 11 additions & 1 deletion pinnicle/nn/nn.py
@@ -10,7 +10,10 @@ def __init__(self, parameters=NNParameter()):
         self.parameters = parameters
 
         # create new NN
-        self.net = self.createFNN()
+        if self.parameters.is_parallel:
+            self.net = self.createPFNN()
+        else:
+            self.net = self.createFNN()
 
         # apply transform
         # by default, use min-max scale for the input
@@ -29,6 +32,13 @@ def createFNN(self):
         """
         layer_size = [self.parameters.input_size] + [self.parameters.num_neurons] * self.parameters.num_layers + [self.parameters.output_size]
         return dde.nn.FNN(layer_size, self.parameters.activation, self.parameters.initializer)
 
+    def createPFNN(self):
+        """
+        create a parallel fully connected neural network
+        """
+        layer_size = [self.parameters.input_size] + [[self.parameters.num_neurons]*self.parameters.output_size] * self.parameters.num_layers + [self.parameters.output_size]
+        return dde.nn.PFNN(layer_size, self.parameters.activation, self.parameters.initializer)
+
     def _add_input_transform(self, func):
         """
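For illustration only (not part of the commit): assuming input_size=2, output_size=3, num_neurons=4, and num_layers=5 (the shapes used in the new test_pfnn below), the two constructors build the following layer_size lists. In DeepXDE's PFNN, a hidden entry that is a list gives one width per output, so an independent sub-network is built for each output variable.

# Illustration only: layer_size lists built by createFNN vs. createPFNN,
# assuming input_size=2, output_size=3, num_neurons=4, num_layers=5.
input_size, output_size, num_neurons, num_layers = 2, 3, 4, 5

# createFNN: one shared stack of hidden layers
fnn_layers = [input_size] + [num_neurons] * num_layers + [output_size]
print(fnn_layers)   # [2, 4, 4, 4, 4, 4, 3]

# createPFNN: each hidden entry lists one width per output,
# so dde.nn.PFNN builds an independent sub-network per output variable
pfnn_layers = [input_size] + [[num_neurons] * output_size] * num_layers + [output_size]
print(pfnn_layers)  # [2, [4, 4, 4], [4, 4, 4], [4, 4, 4], [4, 4, 4], [4, 4, 4], 3]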
3 changes: 3 additions & 0 deletions pinnicle/parameter.py
@@ -170,6 +170,9 @@ def set_default(self):
         self.activation = "tanh"
         self.initializer = "Glorot uniform"
 
+        # parallel neural network
+        self.is_parallel = False
+
         # scaling parameters
         self.input_lb = None
         self.input_ub = None
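A minimal usage sketch of the new flag (not from the commit; it assumes NNParameter is importable from pinnicle.parameter and accepts a plain dict of hyperparameters, as the tests below do):

# Sketch only: toggling the new flag through the hyperparameter dict.
# Assumes NNParameter copies recognized keys onto the parameter object.
from pinnicle.parameter import NNParameter

default = NNParameter()                        # is_parallel defaults to False -> FNN
parallel = NNParameter({"is_parallel": True})  # request the parallel network -> PFNN
print(default.is_parallel, parallel.is_parallel)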
16 changes: 16 additions & 0 deletions tests/test_nn.py
@@ -38,3 +38,19 @@ def test_output_scale_nn():
     x = np.linspace(-1.0, 1.0, 100)
     assert np.all(p.net._output_transform(0, x) > d.output_lb - d.output_lb*np.finfo(float).eps)
     assert np.all(p.net._output_transform(0, x) < d.output_ub + d.output_ub*np.finfo(float).eps)
+
+def test_pfnn():
+    hp={}
+    hp['input_variables'] = ['x','y']
+    hp['output_variables'] = ['u', 'v','s']
+    hp['num_neurons'] = 4
+    hp['num_layers'] = 5
+    hp['is_parallel'] = False
+    d = NNParameter(hp)
+    p = pinn.nn.FNN(d)
+    assert len(p.net.layers) == 6
+    hp['is_parallel'] = True
+    d = NNParameter(hp)
+    p = pinn.nn.FNN(d)
+    assert len(p.net.layers) == 18
+
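The asserted layer counts appear to follow directly from the sizes set above; a quick sanity check of that arithmetic (illustration only, not part of the commit):

# With num_layers=5 hidden layers and 3 output variables:
num_layers, n_out = 5, 3

# FNN: one shared stack -> 5 hidden layers + 1 output layer
fnn_count = num_layers + 1             # 6
# PFNN: one independent stack per output -> 3 * (5 hidden + 1 output)
pfnn_count = n_out * (num_layers + 1)  # 18

assert fnn_count == 6 and pfnn_count == 18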
27 changes: 24 additions & 3 deletions tests/test_pinn.py
@@ -33,7 +33,7 @@
 hp["activation"] = "tanh"
 hp["initializer"] = "Glorot uniform"
 hp["num_neurons"] = 10
-hp["num_layers"] = 6
+hp["num_layers"] = 4
 
 # data
 issm = {}
@@ -98,7 +98,29 @@ def test_save_and_load_setting(tmp_path):
     experiment2 = pinn.PINN(loadFrom=tmp_path)
     assert experiment.params.param_dict == experiment2.params.param_dict
 
-#def test_train(tmp_path):
+def test_train(tmp_path):
+    hp["is_save"] = False
+    hp["num_collocation_points"] = 100
+    issm["data_size"] = {"u":100, "v":100, "s":100, "H":100, "C":None, "vel":100}
+    hp["data"] = {"ISSM": issm}
+    experiment = pinn.PINN(params=hp)
+    experiment.compile()
+    experiment.train()
+    assert experiment.loss_names == ['fSSA1', 'fSSA2', 'u', 'v', 's', 'H', 'C', "vel log"]
+
+def test_train_PFNN(tmp_path):
+    hp["is_parallel"] = True
+    hp["is_save"] = False
+    hp["num_collocation_points"] = 100
+    issm["data_size"] = {"u":100, "v":100, "s":100, "H":100, "C":None, "vel":100}
+    hp["data"] = {"ISSM": issm}
+    experiment = pinn.PINN(params=hp)
+    experiment.compile()
+    experiment.train()
+    assert experiment.loss_names == ['fSSA1', 'fSSA2', 'u', 'v', 's', 'H', 'C', "vel log"]
+    assert len(experiment.model.net.trainable_weights) == 50
+
 #def test_save_train(tmp_path):
 #    hp["save_path"] = str(tmp_path)
 #    hp["is_save"] = True
 #    hp["num_collocation_points"] = 100
@@ -143,7 +165,6 @@ def test_only_callbacks(tmp_path):
     assert callbacks is not None
     assert len(callbacks) == 3
 
-
 def test_plot(tmp_path):
     hp["save_path"] = str(tmp_path)
     hp["is_save"] = True
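A plausible accounting of the trainable_weights assertion in test_train_PFNN (illustration only; it assumes the SSA experiment has five output variables, u, v, s, H, and C, matching the data losses checked above):

# Each per-output sub-network has num_layers=4 hidden layers plus 1 output layer,
# and each dense layer contributes a kernel and a bias tensor.
num_layers, n_out = 4, 5

dense_layers_per_output = num_layers + 1                        # 5
tensors_per_dense = 2                                           # kernel + bias
trainable = n_out * dense_layers_per_output * tensors_per_dense
assert trainable == 50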
