Skip to content

Commit

Permalink
use mat73 if .mat is v7.3, otherwise load the data by scipy.io
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Apr 1, 2024
1 parent 46d4698 commit 31397e3
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 22 deletions.
5 changes: 2 additions & 3 deletions PINNICLE/modeldata/general_mat_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from . import DataBase
from ..parameter import SingleDataParameter
from ..physics import Constants
from ..utils import plot_dict_data
import mat73
from ..utils import plot_dict_data, load_mat
import numpy as np


Expand All @@ -28,7 +27,7 @@ def load_data(self):
""" load scatter data from a `.mat` file, return a dict with the required data
"""
# Reading matlab data
data = mat73.loadmat(self.parameters.data_path)
data = load_mat(self.parameters.data_path)

# x,y coordinates
self.X_dict['x'] = data['x']
Expand Down
5 changes: 2 additions & 3 deletions PINNICLE/modeldata/issm_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from . import DataBase
from ..parameter import SingleDataParameter
from ..physics import Constants
from ..utils import plot_dict_data
import mat73
from ..utils import plot_dict_data, load_mat
import numpy as np


Expand Down Expand Up @@ -41,7 +40,7 @@ def load_data(self):
""" load ISSM model from a `.mat` file
"""
# Reading matlab data
data = mat73.loadmat(self.parameters.data_path)
data = load_mat(self.parameters.data_path)
# get the model
md = data['md']
# create the output dict
Expand Down
12 changes: 11 additions & 1 deletion PINNICLE/utils/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os

import mat73
import scipy.io

def is_file_ext(path, ext):
""" check if a given path is ended by ext
Expand Down Expand Up @@ -39,3 +40,12 @@ def load_dict_from_json(path, filename):
else:
data = {}
return data

def load_mat(file):
""" load .mat file, if the file is in MATLAB 7.3 format use mat73.loadmat, otherwise use scipy.io.loadmat()
"""
try:
data = mat73.loadmat(file)
except TypeError:
data = scipy.io.loadmat(file)
return data
Binary file modified examples/dataset/flightTracks.mat
Binary file not shown.
Binary file added examples/dataset/flightTracks73.mat
Binary file not shown.
44 changes: 29 additions & 15 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import PINNICLE as pinn
import pytest
import tensorflow as tf
import os
from PINNICLE.utils import save_dict_to_json, load_dict_from_json, data_misfit, load_mat

data = {"s":1, "v":[1, 2, 3]}


def test_save_and_load_dict(tmp_path):
pinn.utils.save_dict_to_json(data, tmp_path, "temp.json")
pinn.utils.save_dict_to_json(data, tmp_path, "noextension")
assert data == pinn.utils.load_dict_from_json(tmp_path, "temp.json")
assert data == pinn.utils.load_dict_from_json(tmp_path, "temp")
assert data == pinn.utils.load_dict_from_json(tmp_path, "noextension.json")
save_dict_to_json(data, tmp_path, "temp.json")
save_dict_to_json(data, tmp_path, "noextension")
assert data == load_dict_from_json(tmp_path, "temp.json")
assert data == load_dict_from_json(tmp_path, "temp")
assert data == load_dict_from_json(tmp_path, "noextension.json")

def test_data_misfit():
with pytest.raises(Exception):
pinn.utils.data_misfit.get("not defined")
assert pinn.utils.data_misfit.get("MSE") != None
assert pinn.utils.data_misfit.get("VEL_LOG") != None
assert pinn.utils.data_misfit.get("MEAN_SQUARE_LOG") != None
assert pinn.utils.data_misfit.get("MAPE") != None
data_misfit.get("not defined")
assert data_misfit.get("MSE") != None
assert data_misfit.get("VEL_LOG") != None
assert data_misfit.get("MEAN_SQUARE_LOG") != None
assert data_misfit.get("MAPE") != None

def test_data_misfit_functions():
assert pinn.utils.data_misfit.get("MSE")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert pinn.utils.data_misfit.get("VEL_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert pinn.utils.data_misfit.get("MEAN_SQUARE_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert pinn.utils.data_misfit.get("MAPE")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert data_misfit.get("MSE")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert data_misfit.get("VEL_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert data_misfit.get("MEAN_SQUARE_LOG")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0
assert data_misfit.get("MAPE")(tf.convert_to_tensor([1.0]),tf.convert_to_tensor([1.0])) == 0.0

def test_loadmat():
filename = "flightTracks.mat"
repoPath = os.path.dirname(__file__) + "/../examples/"
appDataPath = os.path.join(repoPath, "dataset")
path = os.path.join(appDataPath, filename)
assert load_mat(path)

filename = "flightTracks73.mat"
repoPath = os.path.dirname(__file__) + "/../examples/"
appDataPath = os.path.join(repoPath, "dataset")
path = os.path.join(appDataPath, filename)
assert load_mat(path)

0 comments on commit 31397e3

Please sign in to comment.