Skip to content

Commit

Permalink
save B for fft automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Nov 5, 2024
1 parent 901efa1 commit c5649d8
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 4 deletions.
4 changes: 3 additions & 1 deletion pinnicle/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def check_consistency(self):
if self.input_size != self.num_fourier_feature*2:
raise ValueError("'input_size' does not match the number of fourier feature")
if self.B is not None:
if self.B.shape[1] != self.num_fourier_feature:
if not isinstance(self.B, list):
raise TypeError("'B' matrix need to be input in a list")
if len(self.B[0]) != self.num_fourier_feature:
raise ValueError("Number of columns of 'B' matrix does not match the number of fourier feature")

Check warning on line 215 in pinnicle/parameter.py

View check run for this annotation

Codecov / codecov/patch

pinnicle/parameter.py#L215

Added line #L215 was not covered by tests
else:
# input size of nn equals to dependent in physics
Expand Down
3 changes: 3 additions & 0 deletions pinnicle/pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def setup(self):
self._update_ub_lb_in_nn(self.model_data)
# define the neural network in use
self.nn = FNN(self.params.nn)
# save B if it is not defined by the user and generated by FFT
if (self.params.nn.B is None) and (self.params.nn.fft):
self.params.param_dict.update({"B": dde.backend.to_numpy(self.nn.B).tolist()})

# Step 7: setup the deepxde PINN model
self.model = dde.Model(self.dde_data, self.nn.net)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_input_fft_nn():
z = y**2
assert np.all(abs(z[:,1:10]+z[:,11:20]) <= 1.0+np.finfo(float).eps)

hp['B'] = np.array([[1,2,3]])
hp['B'] = [[1,2,3]]
hp['num_fourier_feature'] = 3
d = NNParameter(hp)
p = pinn.nn.FNN(d)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ def test_nn_parameter():
assert d.is_input_scaling()
assert d.B is None

d = NNParameter({"fft":True, "num_fourier_feature":4, "B":np.array([[1,2,3,4]])})
d = NNParameter({"fft":True, "num_fourier_feature":4, "B":[[1,2,3,4]]})
assert d.B is not None
with pytest.raises(Exception):
d = NNParameter({"fft":True, "num_fourier_feature":4, "B":np.array([[1,2]])})
d = NNParameter({"fft":True, "num_fourier_feature":4, "B":1})
d = NNParameter({"fft":True, "num_fourier_feature":4, "B":[[1,2]]})

def test_parameters():
p = Parameters()
Expand Down
18 changes: 18 additions & 0 deletions tests/test_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,24 @@ def test_train(tmp_path):
experiment.train()
assert experiment.loss_names == ['fSSA1', 'fSSA2', 'u', 'v', 's', 'H', 'C']

def test_fft_training(tmp_path):
hp_local = dict(hp)
hp_local['fft'] = True
hp_local["is_save"] = False
hp_local["num_collocation_points"] = 10
issm["data_size"] = {"u":10, "v":10, "s":10, "H":10, "C":None}
hp_local["data"] = {"ISSM": issm}
hp_local["equations"] = {"SSA":SSA}
experiment = pinn.PINN(params=hp_local)
experiment.save_setting(path=tmp_path)
assert experiment.params.param_dict == experiment.load_setting(path=tmp_path)
assert experiment.params.nn.B is None
assert os.path.isdir(f"{tmp_path}/pinn/")
experiment2 = pinn.PINN(loadFrom=tmp_path)
assert experiment.params.param_dict == experiment2.params.param_dict
assert len(experiment2.params.nn.B) == 2
assert len(experiment2.params.nn.B[1]) == 10

@pytest.mark.skipif(backend_name in ["jax"], reason="save model is not implemented in deepxde for jax")
def test_train_PFNN(tmp_path):
hp_local = dict(hp)
Expand Down

0 comments on commit c5649d8

Please sign in to comment.