Skip to content

Commit

Permalink
fix test_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Jul 15, 2024
1 parent d2b89aa commit 64c61fc
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import numpy as np
from deepxde import backend
from deepxde.backend import backend_name
from pinnicle.utils import save_dict_to_json, load_dict_from_json, data_misfit, load_mat, down_sample_core, down_sample, slice_column
import pinnicle
from pinnicle.utils import save_dict_to_json, load_dict_from_json, data_misfit, load_mat, down_sample_core, down_sample

data = {"s":1, "v":[1, 2, 3]}

Expand All @@ -26,8 +27,8 @@ def test_data_misfit():
def test_data_misfit_functions():
assert data_misfit.get("VEL_LOG") != None
assert data_misfit.get("MEAN_SQUARE_LOG") != None
assert data_misfit.get("VEL_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert data_misfit.get("MEAN_SQUARE_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert data_misfit.get("VEL_LOG")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0
assert data_misfit.get("MEAN_SQUARE_LOG")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0

def test_loadmat():
filename = "flightTracks.mat"
Expand Down Expand Up @@ -73,6 +74,9 @@ def test_down_sample():

def test_slice_column():
a = np.array([[1,2],[3,4]])
c = slice_column(a, 0, 1)
assert c.shape == (2,1)
assert c[1] == 3
c = pinnicle.utils.backends_specified.slice_column_tf(a, 1)
assert c.shape == (2, 1)
assert c[1] == 4
c = pinnicle.utils.backends_specified.slice_column_jax(a, 1)
assert c.shape == (1,)
assert c[0] == 2

0 comments on commit 64c61fc

Please sign in to comment.