return of compute_output_shape #19259

Closed
innat opened this issue Mar 6, 2024 · 3 comments
Labels
stat:awaiting keras-eng Awaiting response from Keras engineer type:Bug

Comments

@innat
innat commented Mar 6, 2024

The accepted return value of compute_output_shape has changed. The following sample code works in Keras 2 but breaks in Keras 3. What is the difference between the two return values marked in the code below?

import keras
import tensorflow as tf

class ReshapeLayer(keras.layers.Layer):
    def __init__(self, new_shape, **kwargs):
        super().__init__(**kwargs)
        self.new_shape = new_shape

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        return tf.reshape(inputs, self.new_shape)

    def compute_output_shape(self, input_shape):
        # Option 1: works in Keras 2 and Keras 3 (one flat shape tuple).
        # return (input_shape[0],) + self.new_shape

        # Option 2: works in Keras 2 but not in Keras 3 (a tuple nested inside a tuple).
        return (
            input_shape[0],
            self.new_shape
        )
    
input_shape = (None, 10)  
new_shape = (5, 2)
model = keras.Sequential([
    keras.layers.Dense(20, input_shape=input_shape),
    ReshapeLayer(new_shape)
])
model.summary()

Also, in some official layers (for example, the multi-head attention layer or the Conv3D layer) and in the docstring, tf.TensorShape is used in compute_output_shape. Doesn't that make TensorFlow a requirement for the other backends? Is there an equivalent in keras.ops?

@sachinprasadhs sachinprasadhs added the keras-team-review-pending Pending review by a Keras team member. label Mar 6, 2024
@sampathweb
Collaborator

Option 2 seems to return a tuple of tuples instead of a single value, so I'm not sure they are equivalent.
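To make the difference concrete, here is a minimal plain-Python sketch of what the two options evaluate to. The `input_shape` value is an assumed stand-in for illustration; only `new_shape = (5, 2)` comes from the report.

```python
# Assumed stand-in values for illustration; only new_shape is from the report.
input_shape = (None, 20)
new_shape = (5, 2)

# Option 1: tuple concatenation produces one flat shape.
option_1 = (input_shape[0],) + new_shape

# Option 2: the second element is itself a tuple, so the result is nested.
option_2 = (input_shape[0], new_shape)

print(option_1)  # (None, 5, 2)
print(option_2)  # (None, (5, 2))
```

Both values are tuples, but only option 1 describes a single rank-3 shape; option 2 has two elements, the second of which is another tuple.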

Also, the code references you pointed to for tf.TensorShape don't exist in the main Keras repo. It looks like they point to a forked repo that may not be up to date. Can you check again?

@sampathweb sampathweb added stat:awaiting response from contributor and removed keras-team-review-pending Pending review by a Keras team member. labels Mar 11, 2024
@sampathweb sampathweb self-assigned this Mar 11, 2024
@innat
Author

innat commented Mar 11, 2024

@sampathweb

Option 2 seems to return a tuple of tuples instead of a single value, so I'm not sure they are equivalent.

What I meant is that both have the same type (tuple). Either way, option 2 used to work with Keras 2. Now, in Keras 3, this approach fails on the TensorFlow and JAX backends but succeeds on the torch backend. Please check this workflow.

# with tensorflow

        if dtype not in ALLOWED_DTYPES:
>           raise ValueError(f"Invalid dtype: {dtype}")
E           ValueError: Exception encountered when calling VideoSwinBasicLayer.call().
E           
E           Invalid dtype: <property object at 0x7fe799453e00>
E           
E           Arguments received by VideoSwinBasicLayer.call():
E           args=('<KerasTensor shape=(None, 4, 56, 56, 96), dtype=float32, sparse=False, name=keras_tensor_14086>',)
E           kwargs=<class 'inspect._empty'>

# with jax

        if dtype not in ALLOWED_DTYPES:
>           raise ValueError(f"Invalid dtype: {dtype}")
E           ValueError: Exception encountered when calling VideoSwinBasicLayer.call().
E           
E           Invalid dtype: ArrayImpl
E           
E           Arguments received by VideoSwinBasicLayer.call():
E           args=('<KerasTensor shape=(None, 4, 56, 56, 96), dtype=float32, sparse=False, name=keras_tensor_14075>',)
E           kwargs=<class 'inspect._empty'>

To address it, wrapping the returned output shape in tf.TensorShape worked for all backends, like this. I'm not sure whether it's the ideal approach.
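A possible TensorFlow-free alternative, sketched here under assumptions, is to flatten the nested shape into a single tuple before returning it, which matches what option 1 produces. `flatten_shape` is a hypothetical helper, not a Keras API, and it has not been checked against the VideoSwinBasicLayer failure above.

```python
def flatten_shape(shape):
    """Flatten a possibly nested shape into one flat tuple.

    Hypothetical helper for illustration; not part of Keras.
    """
    flat = []
    for dim in shape:
        if isinstance(dim, (tuple, list)):
            # Nested sub-shape: splice its dimensions into the result.
            flat.extend(flatten_shape(dim))
        else:
            # A plain dimension (an int, or None for an unknown size).
            flat.append(dim)
    return tuple(flat)


nested = (None, (5, 2))       # what option 2 returns
print(flatten_shape(nested))  # (None, 5, 2)
```

The result uses only built-in Python types, so it does not tie the layer to any one backend.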

Also, the code references you pointed to for tf.TensorShape don't exist in the main Keras repo. It looks like they point to a forked repo that may not be up to date.

You're right, my bad. I should have double-checked.
