From f496e5213f9fb0c5a1f7e49553ad0791d5186445 Mon Sep 17 00:00:00 2001 From: Surya2k1 Date: Tue, 21 Jan 2025 23:38:43 +0530 Subject: [PATCH 1/2] Fix issue with Masking layer with Tensor as `mask_value` --- keras/src/layers/core/masking.py | 2 ++ keras/src/layers/core/masking_test.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/keras/src/layers/core/masking.py b/keras/src/layers/core/masking.py index 64483aefb149..20a16055d683 100644 --- a/keras/src/layers/core/masking.py +++ b/keras/src/layers/core/masking.py @@ -45,6 +45,8 @@ class Masking(Layer): def __init__(self, mask_value=0.0, **kwargs): super().__init__(**kwargs) + if isinstance(mask_value, dict) and mask_value.get("config", None): + mask_value = mask_value["config"]["value"] self.mask_value = mask_value self.supports_masking = True self.built = True diff --git a/keras/src/layers/core/masking_test.py b/keras/src/layers/core/masking_test.py index b85bbeae2e7b..8cbfe450d2ea 100644 --- a/keras/src/layers/core/masking_test.py +++ b/keras/src/layers/core/masking_test.py @@ -4,6 +4,8 @@ from keras.src import layers from keras.src import models from keras.src import testing +from keras.src import ops +from keras.src.saving import load_model class MaskingTest(testing.TestCase): @@ -57,3 +59,22 @@ def call(self, inputs, mask=None): ] ) model(x) + + @pytest.mark.requires_trainable_backend + def test_masking_with_tensor(self): + model = models.Sequential( + [ + layers.Masking(mask_value=ops.convert_to_tensor([0.0])), + layers.LSTM(1), + ] + ) + x = np.array( + [ + [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]], + [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]], + ] + ) + model(x) + model.save("model.keras") + reload_model = load_model("model.keras") + reload_model(x) From 3428a6a833cdfe0982ec2ba94c57703b3481cda7 Mon Sep 17 00:00:00 2001 From: Surya2k1 Date: Wed, 22 Jan 2025 15:19:05 +0530 Subject: [PATCH 2/2] fix formatting --- keras/src/layers/core/masking.py | 4 +++- keras/src/layers/core/masking_test.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/src/layers/core/masking.py b/keras/src/layers/core/masking.py index 20a16055d683..5d8a1179527b 100644 --- a/keras/src/layers/core/masking.py +++ b/keras/src/layers/core/masking.py @@ -2,6 +2,7 @@ from keras.src import ops from keras.src.api_export import keras_export from keras.src.layers.layer import Layer +from keras.src.saving.serialization_lib import deserialize_keras_object @keras_export("keras.layers.Masking") @@ -45,8 +46,9 @@ class Masking(Layer): def __init__(self, mask_value=0.0, **kwargs): super().__init__(**kwargs) + # `mask_value` can be a serialized tensor, hence verify it if isinstance(mask_value, dict) and mask_value.get("config", None): - mask_value = mask_value["config"]["value"] + mask_value = deserialize_keras_object(mask_value) self.mask_value = mask_value self.supports_masking = True self.built = True diff --git a/keras/src/layers/core/masking_test.py b/keras/src/layers/core/masking_test.py index 8cbfe450d2ea..0470b682c933 100644 --- a/keras/src/layers/core/masking_test.py +++ b/keras/src/layers/core/masking_test.py @@ -3,8 +3,8 @@ from keras.src import layers from keras.src import models -from keras.src import testing from keras.src import ops +from keras.src import testing from keras.src.saving import load_model