Skip to content

Commit

Permalink
add B in the nn parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Nov 5, 2024
1 parent 93f833f commit e840fa0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pinnicle/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def set_default(self):
self.fft = False
self.num_fourier_feature = 10
self.sigma = 1.0
self.B = None

# parallel neural network
self.is_parallel = False
Expand All @@ -207,6 +208,9 @@ def check_consistency(self):
if self.fft:
if self.input_size != self.num_fourier_feature*2:
raise ValueError("'input_size' does not match the number of fourier feature")
if self.B is not None:
if self.B.shape[1] != self.num_fourier_feature:
raise ValueError("Number of columns of 'B' matrix does not match the number of fourier feature")
else:
# input size of nn equals to dependent in physics
if self.input_size != len(self.input_variables):
Expand Down
12 changes: 11 additions & 1 deletion tests/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy as np
import pinnicle as pinn
from pinnicle.parameter import DataParameter, SingleDataParameter, NNParameter, DomainParameter, PhysicsParameter, Parameters, EquationParameter, TrainingParameter
from pinnicle.physics import SSAEquationParameter, DummyEquationParameter
Expand Down Expand Up @@ -67,17 +68,26 @@ def test_nn_parameter():
d.input_lb = 1
d.input_ub = 10
assert d.is_input_scaling()

assert not d.is_output_scaling()

d.output_lb = 1
d.output_ub = 10
assert d.is_output_scaling()

d = NNParameter({"num_neurons":[1,2,3]})
assert d.num_layers == 3
assert d.input_size == 0

d = NNParameter({"fft":True})
assert d.input_size == 2*d.num_fourier_feature
assert d.is_input_scaling()
assert d.B is None

d = NNParameter({"fft":True, "num_fourier_feature":4, "B":np.array([[1,2,3,4]])})
assert d.B is not None
with pytest.raises(Exception):
d = NNParameter({"fft":True, "num_fourier_feature":4, "B":np.array([[1,2]])})
d = NNParameter({"fft":True, "num_fourier_feature":4, "B":1})

def test_parameters():
p = Parameters()
Expand Down

0 comments on commit e840fa0

Please sign in to comment.