Add support for quantizing models in onnx export
constantinpape committed Feb 5, 2025
1 parent d5d0a0a commit bba4d82
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions micro_sam/bioimageio/bioengine_export.py
@@ -106,13 +106,14 @@ def export_image_encoder(
 def export_onnx_model(
     model_type: str,
     output_root: Union[str, os.PathLike],
-    opset: int,
+    opset: int = 17,
     export_name: Optional[str] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_single_mask: bool = True,
     gelu_approximate: bool = False,
     use_stability_score: bool = False,
     return_extra_metrics: bool = False,
+    quantize_model: bool = False,
 ) -> None:
     """Export SAM prompt encoder and mask decoder to onnx.
@@ -123,14 +124,16 @@ def export_onnx_model(
     Args:
         model_type: The SAM model type.
         output_root: The output root directory where the exported model is saved.
-        opset: The ONNX opset version.
+        opset: The ONNX opset version. The recommended opset version is 17.
         export_name: The name of the exported model.
         checkpoint_path: Optional checkpoint for loading the SAM model.
         return_single_mask: Whether the mask decoder returns a single or multiple masks.
         gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
             does not have an efficient GeLU implementation.
         use_stability_score: Whether to use the stability score instead of the predicted score.
         return_extra_metrics: Whether to return a larger set of metrics.
+        quantize_model: Whether to also export a quantized version of the model.
+            This only works for onnxruntime < 1.17.
     """
     if export_name is None:
         export_name = model_type
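
For reference, a minimal usage sketch of the updated function; the model type and output directory below are illustrative choices, not part of this commit:

from micro_sam.bioimageio.bioengine_export import export_onnx_model

# Export the SAM prompt encoder and mask decoder to ONNX and additionally
# write a dynamically quantized copy (requires onnxruntime < 1.17).
export_onnx_model(
    model_type="vit_b",         # illustrative SAM model type
    output_root="./bioengine",  # illustrative output directory
    opset=17,                   # the recommended opset version
    quantize_model=True,
)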
@@ -155,9 +158,7 @@ def export_onnx_model(
             if isinstance(m, torch.nn.GELU):
                 m.approximate = "tanh"

-    dynamic_axes = {
-        "point_coords": {1: "num_points"}, "point_labels": {1: "num_points"},
-    }
+    dynamic_axes = {"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}}

     embed_dim = sam.prompt_encoder.embed_dim
     embed_size = sam.prompt_encoder.image_embedding_size
@@ -202,6 +203,23 @@ def export_onnx_model(
         _ = ort_session.run(None, ort_inputs)
         print("Model has successfully been run with ONNXRuntime.")

+    # This requires onnxruntime < 1.17.
+    # See https://github.com/facebookresearch/segment-anything/issues/699#issuecomment-1984670808
+    if quantize_model:
+        assert onnxruntime_exists
+        from onnxruntime.quantization import QuantType
+        from onnxruntime.quantization.quantize import quantize_dynamic
+
+        quantized_path = os.path.join(weight_output_folder, "model_quantized.onnx")
+        quantize_dynamic(
+            model_input=weight_path,
+            model_output=quantized_path,
+            # optimize_model=True,
+            per_channel=False,
+            reduce_range=False,
+            weight_type=QuantType.QUInt8,
+        )
+
     config_output_path = os.path.join(output_folder, "config.pbtxt")
     with open(config_output_path, "w") as f:
         f.write(DECODER_CONFIG % name)
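
With quantize_model=True, the quantized weights are written as model_quantized.onnx next to the regular ONNX weights. A minimal sketch for sanity-checking that the quantized model loads, assuming onnxruntime < 1.17 is installed; the path below is illustrative and depends on where the export was written:

import onnxruntime

# Open the quantized decoder and list the input names it expects.
session = onnxruntime.InferenceSession(
    "model_quantized.onnx", providers=["CPUExecutionProvider"]
)
print([inp.name for inp in session.get_inputs()])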