Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

passing tags for use when loading TF saved_model #1024

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions coremltools/converters/_converters_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def convert(
compute_units=_ComputeUnit.ALL,
package_dir=None,
debug=False,
**kwargs
):
"""
Convert a TensorFlow or PyTorch model to the Core ML model format as either
Expand Down Expand Up @@ -363,6 +364,9 @@ def skip_real_div_ops(op):
- For TensorFlow conversion, it will display extra logging
and visualizations.

Note that for TensorFlow SavedModel models with more than one tag set,
``tags: List[str]`` can be used to specify which tag set to load.

Returns
-------

Expand Down Expand Up @@ -458,6 +462,7 @@ def skip_real_div_ops(op):
package_dir=package_dir,
debug=debug,
specification_version=specification_version,
**kwargs
)

if exact_target == 'milinternal':
Expand Down
9 changes: 5 additions & 4 deletions coremltools/converters/mil/frontend/tensorflow/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def load(self):
logging.info("Loading TensorFlow model '{}'".format(self.model))
outputs = self.kwargs.get("outputs", None)
output_names = get_output_names(outputs)
self._graph_def = self._graph_def_from_model(output_names)
tags = self.kwargs.get("tags", None)
self._graph_def = self._graph_def_from_model(output_names, tags)

if self._graph_def is not None and len(self._graph_def.node) == 0:
msg = "tf.Graph should have at least 1 node, Got empty graph."
Expand All @@ -88,7 +89,7 @@ def load(self):
return program

# @abstractmethod
def _graph_def_from_model(self, output_names=None):
def _graph_def_from_model(self, output_names=None, tags=None):
"""Load TensorFlow model into GraphDef. Overwrite for different TF versions."""
pass

Expand Down Expand Up @@ -139,7 +140,7 @@ def __init__(self, model, debug=False, **kwargs):
"""
TFLoader.__init__(self, model, debug, **kwargs)

def _graph_def_from_model(self, output_names=None):
def _graph_def_from_model(self, output_names=None, tags=None):
"""Overwrites TFLoader._graph_def_from_model()"""
msg = "Expected model format: [tf.Graph | .pb | SavedModel | tf.keras.Model | .h5], got {}"
if isinstance(self.model, tf.Graph) and hasattr(self.model, "as_graph_def"):
Expand Down Expand Up @@ -170,7 +171,7 @@ def _graph_def_from_model(self, output_names=None):
graph_def = self._from_tf_keras_model(self.model)
return self.extract_sub_graph(graph_def, output_names)
elif os.path.isdir(str(self.model)):
graph_def = self._from_saved_model(self.model)
graph_def = self._from_saved_model(self.model, tags=tags)
return self.extract_sub_graph(graph_def, output_names)
else:
raise NotImplementedError(msg.format(self.model))
Expand Down
8 changes: 4 additions & 4 deletions coremltools/converters/mil/frontend/tensorflow2/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, model, debug=False, **kwargs):
fuse_dilation_conv,
]

def _get_concrete_functions_and_graph_def(self):
def _get_concrete_functions_and_graph_def(self, tags=None):
msg = (
"Expected model format: [SavedModel | [concrete_function] | "
"tf.keras.Model | .h5], got {}"
Expand All @@ -120,7 +120,7 @@ def _get_concrete_functions_and_graph_def(self):
and (self.model.endswith(".h5") or self.model.endswith(".hdf5")):
cfs = self._concrete_fn_from_tf_keras_or_h5(self.model)
elif _os_path.isdir(self.model):
saved_model = _tf.saved_model.load(self.model)
saved_model = _tf.saved_model.load(self.model, tags=tags)
sv = saved_model.signatures.values()
cfs = sv if isinstance(sv, list) else list(sv)
else:
Expand All @@ -132,9 +132,9 @@ def _get_concrete_functions_and_graph_def(self):

return cfs, graph_def

def _graph_def_from_model(self, output_names=None, tags=None):
    """Overwrites TFLoader._graph_def_from_model().

    Builds the GraphDef for the loaded TF2 model — forwarding ``tags``
    so a specific MetaGraph can be selected when loading a SavedModel —
    then prunes the graph down to ``output_names``.
    """
    # Only the GraphDef is needed here; the concrete functions are discarded.
    _cfs, graph_def = self._get_concrete_functions_and_graph_def(tags=tags)
    return self.extract_sub_graph(graph_def, output_names)

def _tf_ssa_from_graph_def(self, fn_name="main"):
Expand Down
36 changes: 36 additions & 0 deletions coremltools/test/api/test_api_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,42 @@ def test_convert_from_saved_model_dir():
mlmodel = ct.convert("./saved_model")
mlmodel.save("./model.mlmodel")

@staticmethod
def test_convert_from_two_tags_saved_model_dir(tmpdir):
    # Regression test: a SavedModel directory containing two MetaGraphs
    # (tag sets ["serve"] and ["serve", "tpu"]) is ambiguous at load
    # time; passing ``tags=["serve"]`` to ``ct.convert`` selects one.
    import tensorflow as tf
    from tensorflow.compat.v1.saved_model import build_tensor_info
    from tensorflow.compat.v1.saved_model import signature_constants
    from tensorflow.compat.v1.saved_model import signature_def_utils

    # Minimal traced function whose concrete graph is exported below.
    @tf.function
    def add(a, b):
        return a + b

    c = add.get_concrete_function(tf.constant(21.0), tf.constant(21.0))

    save_path = str(tmpdir)
    # The TF1-style Builder is used because it can write more than one
    # MetaGraph (tag set) into a single SavedModel directory.
    builder = tf.compat.v1.saved_model.Builder(save_path)

    with tf.compat.v1.Session(graph=c.graph) as sess:
        # Wrap the graph's input/output tensors for the signature def.
        tensor_info_a = build_tensor_info(c.graph.inputs[0])
        tensor_info_b = build_tensor_info(c.graph.inputs[1])
        tensor_info_y = build_tensor_info(c.graph.outputs[0])

        prediction_signature = signature_def_utils.build_signature_def(
            inputs={'a': tensor_info_a, 'b': tensor_info_b},
            outputs={'output': tensor_info_y},
            method_name=signature_constants.PREDICT_METHOD_NAME)

        # First MetaGraph: tag set ["serve"], with variables.
        builder.add_meta_graph_and_variables(sess, ["serve"],
                                             signature_def_map={
                                                 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                                                     prediction_signature,
                                             })

    # Second MetaGraph: tag set ["serve", "tpu"]. With two MetaGraphs
    # present, loading without explicit tags would be ambiguous.
    # NOTE(review): original indentation was lost in this view — assumed
    # these two calls sit outside the session block; confirm against TF
    # Builder usage (add_meta_graph takes no session argument).
    builder.add_meta_graph(["serve", "tpu"])
    builder.save()

    # The feature under test: forwarding ``tags`` through ct.convert.
    ct.convert(save_path, source="tensorflow", tags=["serve"])

@staticmethod
def test_keras_custom_layer_model():
Expand Down