Skip to content

Commit

Permalink
TFA deprecated
Browse files Browse the repository at this point in the history
  • Loading branch information
philipperemy committed Aug 13, 2024
1 parent f4409bd commit f8f5750
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions tasks/tcn_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from tensorflow.keras import Input
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential

from tcn import TCN

Expand Down Expand Up @@ -100,27 +99,6 @@ def test_non_causal_time_dim_unknown_return_no_sequences(self):
r = predict_with_tcn(time_steps=None, padding='same', return_sequences=False)
self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]])

def test_norms(self):
Sequential(layers=[TCN(input_shape=(20, 2), use_weight_norm=True)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_weight_norm=False)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_layer_norm=True)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_layer_norm=False)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=False)]).compile(optimizer='adam', loss='mse')
try:
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True, use_weight_norm=True)]).compile(
optimizer='adam', loss='mse')
raise AssertionError('test failed.')
except ValueError:
pass
try:
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True,
use_weight_norm=True, use_layer_norm=True)]).compile(
optimizer='adam', loss='mse')
raise AssertionError('test failed.')
except ValueError:
pass

def test_receptive_field(self):
self.assertEqual(37, TCN(kernel_size=3, dilations=(1, 3, 5), nb_stacks=1).receptive_field)
self.assertEqual(379, TCN(kernel_size=4, dilations=(1, 2, 4, 8, 16, 32), nb_stacks=1).receptive_field)
Expand Down

0 comments on commit f8f5750

Please sign in to comment.