From 901efa17e2ab867b79ac400138cdd282c8bf6ae8 Mon Sep 17 00:00:00 2001 From: Cheng Gong Date: Tue, 5 Nov 2024 10:55:36 -0500 Subject: [PATCH] nn can load predefined B --- pinnicle/nn/nn.py | 5 ++++- tests/test_nn.py | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pinnicle/nn/nn.py b/pinnicle/nn/nn.py index 3ed01c5..355d705 100644 --- a/pinnicle/nn/nn.py +++ b/pinnicle/nn/nn.py @@ -22,7 +22,10 @@ def __init__(self, parameters=NNParameter()): if self.parameters.is_input_scaling(): if self.parameters.fft : print(f"add Fourier feature transform to input transform") - self.B = bkd.as_tensor(np.random.normal(0.0, self.parameters.sigma, [len(self.parameters.input_variables), self.parameters.num_fourier_feature])) + if self.parameters.B is not None: + self.B = bkd.as_tensor(self.parameters.B) + else: + self.B = bkd.as_tensor(np.random.normal(0.0, self.parameters.sigma, [len(self.parameters.input_variables), self.parameters.num_fourier_feature])) def wrapper(x): """a wrapper function to add fourier feature transform to the input """ diff --git a/tests/test_nn.py b/tests/test_nn.py index 1055526..7399051 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -55,6 +55,12 @@ def test_input_fft_nn(): z = y**2 assert np.all(abs(z[:,1:10]+z[:,11:20]) <= 1.0+np.finfo(float).eps) + hp['B'] = np.array([[1,2,3]]) + hp['num_fourier_feature'] = 3 + d = NNParameter(hp) + p = pinn.nn.FNN(d) + assert np.all(hp['B'] == bkd.to_numpy(p.B)) + def test_input_scale_nn(): hp={} hp['input_variables'] = ['x']