From 64c61fceb91fff9554fe87b52d2e4552f7682015 Mon Sep 17 00:00:00 2001 From: Cheng Gong Date: Mon, 15 Jul 2024 17:17:54 -0400 Subject: [PATCH] fix test_utils --- tests/test_utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e621a8a..9cb8002 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,8 @@ import numpy as np from deepxde import backend from deepxde.backend import backend_name -from pinnicle.utils import save_dict_to_json, load_dict_from_json, data_misfit, load_mat, down_sample_core, down_sample, slice_column +import pinnicle +from pinnicle.utils import save_dict_to_json, load_dict_from_json, data_misfit, load_mat, down_sample_core, down_sample data = {"s":1, "v":[1, 2, 3]} @@ -26,8 +27,8 @@ def test_data_misfit(): def test_data_misfit_functions(): assert data_misfit.get("VEL_LOG") != None assert data_misfit.get("MEAN_SQUARE_LOG") != None - assert data_misfit.get("VEL_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0 - assert data_misfit.get("MEAN_SQUARE_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0 + assert data_misfit.get("VEL_LOG")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0 + assert data_misfit.get("MEAN_SQUARE_LOG")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0 def test_loadmat(): filename = "flightTracks.mat" @@ -73,6 +74,9 @@ def test_down_sample(): def test_slice_column(): a = np.array([[1,2],[3,4]]) - c = slice_column(a, 0, 1) - assert c.shape == (2,1) - assert c[1] == 3 + c = pinnicle.utils.backends_specified.slice_column_tf(a, 1) + assert c.shape == (2, 1) + assert c[1] == 4 + c = pinnicle.utils.backends_specified.slice_column_jax(a, 1) + assert c.shape == (1,) + assert c[0] == 2