Skip to content

Commit

Permalink
update test data
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Nov 26, 2024
1 parent 91e741c commit 469e753
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
Binary file modified examples/dataset/Helheim_fastflow.mat
Binary file not shown.
24 changes: 13 additions & 11 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_ISSMmdData():

hp = {}
hp["data_path"] = path
hp["data_size"] = {"u":4000, "v":4000, "s":4000, "H":4000, "C":None}
hp["data_size"] = {"u":4000, "v":4000, "s":4000, "H":4000, "C":None, "a":500}
p = SingleDataParameter(hp)
data_loader = ISSMmdData(p)
data_loader.load_data()
Expand All @@ -22,12 +22,13 @@ def test_ISSMmdData():
assert(data_loader.X['v'].shape == (4000,2))
assert(data_loader.sol['s'].shape == (4000,1))
assert(data_loader.X['H'].shape == (4000,2))
assert(data_loader.sol['C'].shape == (564,1))
assert(data_loader.sol['C'].shape == (278,1))
assert(data_loader.sol['a'].shape == (500,1))

iice = data_loader.get_ice_indices()
assert iice[0].shape == (23049,)
assert iice[0].shape == (11874,)
icoord = data_loader.get_ice_coordinates()
assert icoord.shape == (23049, 2)
assert icoord.shape == (11874, 2)

def test_ISSMmdData_plot():
filename = "Helheim_fastflow.mat"
Expand Down Expand Up @@ -71,7 +72,7 @@ def test_Data():
assert(data_loader.X['v'].shape == (4000,2))
assert(data_loader.sol['s'].shape == (4000,1))
assert(data_loader.X['H'].shape == (4000,2))
assert(data_loader.sol['C'].shape == (564,1))
assert(data_loader.sol['C'].shape == (278,1))

def test_Data_multiple():
filename = "Helheim_fastflow.mat"
Expand All @@ -84,7 +85,7 @@ def test_Data_multiple():
issm["data_size"] = {"u":4000, "v":4000, "s":4000, "H":4000, "C":None}
issm2 = {}
issm2["data_path"] = path
issm2["data_size"] = {"u":400, "v":None, "s":1000, "C":1000}
issm2["data_size"] = {"u":400, "v":None, "s":1000, "C":1000, "a":1000}
issm2["default_time"] = 1

hp = {}
Expand All @@ -96,24 +97,25 @@ def test_Data_multiple():
data_loader.prepare_training_data()

assert(data_loader.sol['u'].shape == (4400,1))
assert(data_loader.X['v'].shape == (4564,2))
assert(data_loader.X['v'].shape == (4278,2))
assert(data_loader.sol['s'].shape == (5000,1))
assert(data_loader.X['H'].shape == (4000,2))
assert(data_loader.sol['C'].shape == (1564,1))
assert(data_loader.sol['C'].shape == (1278,1))
assert(data_loader.sol['a'].shape == (1000,1))

icoord = data_loader.get_ice_coordinates()
assert icoord.shape == (23049*2, 2)
assert icoord.shape == (11874*2, 2)

p = DataParameter(hp)
data_loader = Data(p)
data_loader.load_data()
data_loader.prepare_training_data(transient=True, default_time=10)
assert(data_loader.sol['v'].shape == (4564,1))
assert(data_loader.sol['v'].shape == (4278,1))
assert(data_loader.X['u'].shape == (4400,3))
assert(data_loader.X['u'][1,2] == 10)
assert(data_loader.X['u'][-1,2] == 1)
icoord = data_loader.get_ice_coordinates()
assert icoord.shape == (23049*2, 2)
assert icoord.shape == (11874*2, 2)

def test_MatData():
filename = "flightTracks.mat"
Expand Down

0 comments on commit 469e753

Please sign in to comment.