Skip to content

Commit

Permalink
bug fixed for dummy data_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Jun 17, 2024
1 parent 1fb834b commit c518a69
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pinnicle/physics/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def set_default(self):
def update(self):
""" set all the weights to 1, and load all the lb and ub is not given
"""
self.data_weights = [1.0 for ou in self.output]
if not self.data_weights:
self.data_weights = [1.0 for ou in self.output]
if not self.output_lb:
self.output_lb = [self.variable_lb[k] for k in self.output]
if not self.output_ub:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import pinnicle as pinn
from pinnicle.parameter import DataParameter, SingleDataParameter, NNParameter, DomainParameter, PhysicsParameter, Parameters, EquationParameter, TrainingParameter
from pinnicle.physics import SSAEquationParameter
from pinnicle.physics import SSAEquationParameter, DummyEquationParameter

yts = 3600*24*365.0

Expand Down Expand Up @@ -83,6 +83,7 @@ def test_equation_parameters():
p = EquationParameter(SSA)
assert p.input == SSA["input"]
assert p.output == SSA["output"]
assert p.data_weights == SSA["data_weights"]

p = SSAEquationParameter(SSA)
assert p.scalar_variables['n'] == 3.0
Expand Down Expand Up @@ -110,6 +111,22 @@ def test_equation_parameters():
with pytest.raises(Exception):
p = Parameters(hp)

def test_dummy_equation_parameters():
DUMMY = {}
DUMMY["input"] = ["x", "y"]
DUMMY["output"] = ["u", "v", "s", "H", "C"]
DUMMY["output_lb"] = [-1.0e4/yts, -1.0e4/yts, -1.0e3, 10.0, 0.01]
DUMMY["output_ub"] = [ 1.0e4/yts, 1.0e4/yts, 2.5e3, 2.0e3, 1.0e4]

p = DummyEquationParameter(DUMMY)
assert p.input == DUMMY["input"]
assert p.output == DUMMY["output"]
assert p.data_weights == [1.0]*5

DUMMY["data_weights"] = [1.0e-8*yts**2.0, 1.0e-8*yts**2.0, 1.0e-6, 1.0e-6, 1.0e-8]
p = DummyEquationParameter(DUMMY)
assert p.data_weights == DUMMY["data_weights"]

def test_training_parameters():
hp = {}
p = TrainingParameter(hp)
Expand Down

0 comments on commit c518a69

Please sign in to comment.