diff --git a/tasks/plot_tcn_model.py b/tasks/plot_tcn_model.py index 27feb16..f4e6c3d 100644 --- a/tasks/plot_tcn_model.py +++ b/tasks/plot_tcn_model.py @@ -10,7 +10,7 @@ inputs = tf.keras.layers.Input(shape=input_shape, name='input') tcn_out = TCN(nb_filters=64, kernel_size=3, nb_stacks=1, activation='relu')(inputs) outputs = tf.keras.layers.Dense(forecast_horizon * num_features, activation='linear')(tcn_out) -outputs = tf.reshape(outputs, shape=(-1, forecast_horizon, num_features), name='ouput') +outputs = tf.keras.layers.Reshape((forecast_horizon, num_features), name='ouput')(outputs) model = tf.keras.Model(inputs=inputs, outputs=outputs) tf.keras.utils.plot_model(