Skip to content

Commit

Permalink
Implemented more generic asset tracking mechanism in saved model export. (#20758)
Browse files Browse the repository at this point in the history

This new implementation is in line with what was done in Keras 2. It tracks all `TrackableResource`s; lookup tables and hashmaps are subclasses of `TrackableResource`, so they are tracked as well.

This allows users to attach preprocessing functions that are not solely based on Keras preprocessing layers.
  • Loading branch information
hertschuh authored Jan 15, 2025
1 parent 57c94f3 commit e37ee79
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 27 deletions.
9 changes: 2 additions & 7 deletions keras/src/export/saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,18 +550,13 @@ def _filter_and_track_resources(self):
# Next, track lookup tables.
# Hopefully, one day this will be automated at the tf.function level.
self._tf_trackable._misc_assets = []
from keras.src.layers import IntegerLookup
from keras.src.layers import StringLookup
from keras.src.layers import TextVectorization
from tensorflow.saved_model.experimental import TrackableResource

if hasattr(self, "_tracked"):
for root in self._tracked:
descendants = tf.train.TrackableView(root).descendants()
for trackable in descendants:
if isinstance(
trackable,
(IntegerLookup, StringLookup, TextVectorization),
):
if isinstance(trackable, TrackableResource):
self._tf_trackable._misc_assets.append(trackable)


Expand Down
41 changes: 22 additions & 19 deletions keras/src/export/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,25 +782,28 @@ def test_multi_input_output_functional_model(self):
}
)

# def test_model_with_lookup_table(self):
# tf.debugging.disable_traceback_filtering()
# temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
# text_vectorization = layers.TextVectorization()
# text_vectorization.adapt(["one two", "three four", "five six"])
# model = models.Sequential(
# [
# layers.Input(shape=(), dtype="string"),
# text_vectorization,
# layers.Embedding(10, 32),
# layers.Dense(1),
# ]
# )
# ref_input = tf.convert_to_tensor(["one two three four"])
# ref_output = model(ref_input)

# saved_model.export_saved_model(model, temp_filepath)
# revived_model = tf.saved_model.load(temp_filepath)
# self.assertAllClose(ref_output, revived_model.serve(ref_input))
@pytest.mark.skipif(
    backend.backend() != "tensorflow",
    reason="String lookup requires TensorFlow backend",
)
def test_model_with_lookup_table(self):
    """Round-trip a model containing a `TextVectorization` layer.

    Builds a `Sequential` model whose first layer is an adapted
    `TextVectorization`, exports it with `export_saved_model`, reloads it
    with `tf.saved_model.load`, and asserts that the reloaded `serve`
    endpoint returns values close to the in-memory model's output for the
    same string input.

    Skipped unless the active backend is TensorFlow.
    """
    export_dir = os.path.join(self.get_temp_dir(), "exported_model")

    # Adapt the vectorizer on a tiny corpus before wiring it into the model.
    vectorizer = layers.TextVectorization()
    vectorizer.adapt(["one two", "three four", "five six"])

    stack = [
        layers.Input(shape=(), dtype="string"),
        vectorizer,
        layers.Embedding(10, 32),
        layers.Dense(1),
    ]
    model = models.Sequential(stack)

    # Capture the reference output before exporting.
    sample = tf.convert_to_tensor(["one two three four"])
    expected = model(sample)

    saved_model.export_saved_model(model, export_dir)
    reloaded = tf.saved_model.load(export_dir)

    # The revived serving function must match the original model.
    self.assertAllClose(expected, reloaded.serve(sample))

def test_track_multiple_layers(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ tf2onnx
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.1.0
torchvision>=0.16.0
torch-xla
torch-xla;sys_platform != 'darwin'

# Jax.
jax[cpu]
Expand Down

0 comments on commit e37ee79

Please sign in to comment.