Skip to content

Commit

Permalink
Fix TimeDistributed layer mask validation when using a partial (dynamic) batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
Surya2k1 committed Jan 15, 2025
1 parent e345cbd commit 2308854
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions keras/src/layers/rnn/time_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from keras.src.api_export import keras_export
from keras.src.layers.core.wrapper import Wrapper
from keras.src.layers.layer import Layer

import tensorflow as tf

@keras_export("keras.layers.TimeDistributed")
class TimeDistributed(Wrapper):
Expand Down Expand Up @@ -77,12 +77,20 @@ def call(self, inputs, training=None, mask=None):
batch_size = input_shape[0]
timesteps = input_shape[1]

if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
raise ValueError(
"`TimeDistributed` Layer should be passed a `mask` of shape "
f"({batch_size}, {timesteps}, ...), "
f"received: mask.shape={mask_shape}"
)
if backend.backend() == "tensorflow" and not tf.executing_eagerly():
if mask_shape is not None and mask_shape[1:2] != (timesteps,):
raise ValueError(
"`TimeDistributed` Layer should be passed a `mask` of shape "
f"({batch_size}, {timesteps}, ...), "
f"received: mask.shape={mask_shape}"
)
else:
if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
raise ValueError(
"`TimeDistributed` Layer should be passed a `mask` of shape "
f"({batch_size}, {timesteps}, ...), "
f"received: mask.shape={mask_shape}"
)

def time_distributed_transpose(data):
"""Swaps the timestep and batch dimensions of a tensor."""
Expand Down

0 comments on commit 2308854

Please sign in to comment.