From e37ee792a342356fc426d33236f6b7e46892dcd0 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Tue, 14 Jan 2025 17:09:39 -0800 Subject: [PATCH] Implemented more generic asset tracking mechanism in saved model export. (#20758) This new implementation is in line with what was done in Keras 2. It tracks all `TrackableResource`s, and lookup tables and hashmaps are subclasses of `TrackableResource`. This allows users to attach preprocessing functions that are not solely based on Keras preprocessing layers. --- keras/src/export/saved_model.py | 9 ++---- keras/src/export/saved_model_test.py | 41 +++++++++++++++------------- requirements.txt | 2 +- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/keras/src/export/saved_model.py b/keras/src/export/saved_model.py index 1546e91aadf..f52c73e5461 100644 --- a/keras/src/export/saved_model.py +++ b/keras/src/export/saved_model.py @@ -550,18 +550,13 @@ def _filter_and_track_resources(self): # Next, track lookup tables. # Hopefully, one day this will be automated at the tf.function level. self._tf_trackable._misc_assets = [] - from keras.src.layers import IntegerLookup - from keras.src.layers import StringLookup - from keras.src.layers import TextVectorization + from tensorflow.saved_model.experimental import TrackableResource if hasattr(self, "_tracked"): for root in self._tracked: descendants = tf.train.TrackableView(root).descendants() for trackable in descendants: - if isinstance( - trackable, - (IntegerLookup, StringLookup, TextVectorization), - ): + if isinstance(trackable, TrackableResource): self._tf_trackable._misc_assets.append(trackable) diff --git a/keras/src/export/saved_model_test.py b/keras/src/export/saved_model_test.py index c5ad6c58690..8c2d1c16b8c 100644 --- a/keras/src/export/saved_model_test.py +++ b/keras/src/export/saved_model_test.py @@ -782,25 +782,28 @@ def test_multi_input_output_functional_model(self): } ) - # def test_model_with_lookup_table(self): - # tf.debugging.disable_traceback_filtering() - # temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - # text_vectorization = layers.TextVectorization() - # text_vectorization.adapt(["one two", "three four", "five six"]) - # model = models.Sequential( - # [ - # layers.Input(shape=(), dtype="string"), - # text_vectorization, - # layers.Embedding(10, 32), - # layers.Dense(1), - # ] - # ) - # ref_input = tf.convert_to_tensor(["one two three four"]) - # ref_output = model(ref_input) - - # saved_model.export_saved_model(model, temp_filepath) - # revived_model = tf.saved_model.load(temp_filepath) - # self.assertAllClose(ref_output, revived_model.serve(ref_input)) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="String lookup requires TensorFlow backend", + ) + def test_model_with_lookup_table(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + text_vectorization = layers.TextVectorization() + text_vectorization.adapt(["one two", "three four", "five six"]) + model = models.Sequential( + [ + layers.Input(shape=(), dtype="string"), + text_vectorization, + layers.Embedding(10, 32), + layers.Dense(1), + ] + ) + ref_input = tf.convert_to_tensor(["one two three four"]) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) def test_track_multiple_layers(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") diff --git a/requirements.txt b/requirements.txt index 0973be4969a..f67d36f1596 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ tf2onnx --extra-index-url https://download.pytorch.org/whl/cpu torch>=2.1.0 torchvision>=0.16.0 -torch-xla +torch-xla;sys_platform != 'darwin' # Jax. jax[cpu]