Skip to content

Commit

Permalink
Removed last warning on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tonegas committed Jan 30, 2025
1 parent d707f1d commit 00fc178
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_train_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_recurrent_shuffle(self):
dataset = {'x': [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], 'target': [21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]}
test.loadData(name='dataset', source=dataset)

test.trainModel(train_dataset='dataset', optimizer='SGD', lr=1, num_of_epochs=1, train_batch_size=4, prediction_samples=1, step=1, shuffle_data=True)
test.trainModel(train_dataset='dataset', optimizer='SGD', lr=0.01, num_of_epochs=1, train_batch_size=4, prediction_samples=1, step=1, shuffle_data=True)
self.assertListEqual([[[4.0]], [[1.0]], [[9.0]], [[18.0]]], test.internals['inout_0_0']['XY']['x'])
self.assertListEqual([[[25.0]], [[22.0]], [[30.0]], [[39.0]]], test.internals['inout_0_0']['XY']['target'])
self.assertListEqual([[[26.0]], [[23.0]], [[31.0]], [[40.0]]], test.internals['inout_0_1']['XY']['target'])
Expand All @@ -65,8 +65,8 @@ def test_recurrent_shuffle(self):
self.assertListEqual([[[36.0]], [[23.0]], [[38.0]], [[35.0]]], test.internals['inout_2_0']['XY']['target'])
self.assertListEqual([[[37.0]], [[24.0]], [[39.0]], [[36.0]]], test.internals['inout_2_1']['XY']['target'])

test.trainModel(train_dataset='dataset', optimizer='SGD', lr=1, num_of_epochs=1, train_batch_size=2,
prediction_samples=2, step=0, shuffle_data=True)
test.trainModel(train_dataset='dataset', optimizer='SGD', lr=0.01, num_of_epochs=1, train_batch_size=2,
prediction_samples=2, step=0, shuffle_data=True)
# ( number_samples - window_size - prediction_samples )// (batch_size + step=0) * (predictoin_samples+1)
self.assertEqual((20-1-2)//2*3, len(test.internals.keys()))

Expand Down

0 comments on commit 00fc178

Please sign in to comment.