Skip to content

Commit

Permalink
nn can load predefined B
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Nov 5, 2024
1 parent e840fa0 commit 901efa1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pinnicle/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def __init__(self, parameters=NNParameter()):
if self.parameters.is_input_scaling():
if self.parameters.fft :
print(f"add Fourier feature transform to input transform")
self.B = bkd.as_tensor(np.random.normal(0.0, self.parameters.sigma, [len(self.parameters.input_variables), self.parameters.num_fourier_feature]))
if self.parameters.B is not None:
self.B = bkd.as_tensor(self.parameters.B)
else:
self.B = bkd.as_tensor(np.random.normal(0.0, self.parameters.sigma, [len(self.parameters.input_variables), self.parameters.num_fourier_feature]))
def wrapper(x):
"""a wrapper function to add fourier feature transform to the input
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def test_input_fft_nn():
z = y**2
assert np.all(abs(z[:,1:10]+z[:,11:20]) <= 1.0+np.finfo(float).eps)

hp['B'] = np.array([[1,2,3]])
hp['num_fourier_feature'] = 3
d = NNParameter(hp)
p = pinn.nn.FNN(d)
assert np.all(hp['B'] == bkd.to_numpy(p.B))

def test_input_scale_nn():
hp={}
hp['input_variables'] = ['x']
Expand Down

0 comments on commit 901efa1

Please sign in to comment.