Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: output format argument #2043

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 129 additions & 121 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from utils.add_nms import RegisterNMS

if __name__ == '__main__':
default_out_formats = ('torchscript', 'coreml', 'torchscript-lite', 'onnx')
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolor-csp-c.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
Expand All @@ -34,10 +35,13 @@
parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
parser.add_argument('--fp16', action='store_true', help='CoreML FP16 half-precision export')
parser.add_argument('--int8', action='store_true', help='CoreML INT8 quantization')
parser.add_argument('--out-format', action='append', choices=default_out_formats, default=[],
dest='out_formats', help='output format. Can be specified multiple times. Default: all')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
opt.dynamic = opt.dynamic and not opt.end2end
opt.dynamic = False if opt.dynamic_batch else opt.dynamic
opt.out_formats = tuple(opt.out_formats) if opt.out_formats else default_out_formats
print(opt)
set_logging()
t = time.time()
Expand Down Expand Up @@ -71,135 +75,139 @@
y = None

# TorchScript export
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
ts.save(f)
print('TorchScript export success, saved as %s' % f)
except Exception as e:
print('TorchScript export failure: %s' % e)
if 'torchscript' in opt.out_formats:
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
ts.save(f)
print('TorchScript export success, saved as %s' % f)
except Exception as e:
print('TorchScript export failure: %s' % e)

# CoreML export
try:
import coremltools as ct

print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
# convert model from torchscript and apply pixel scaling as per detect.py
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
bits, mode = (8, 'kmeans_lut') if opt.int8 else (16, 'linear') if opt.fp16 else (32, None)
if bits < 32:
if sys.platform.lower() == 'darwin': # quantization only supported on macOS
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
else:
print('quantization only supported on macOS, skipping...')

f = opt.weights.replace('.pt', '.mlmodel') # filename
ct_model.save(f)
print('CoreML export success, saved as %s' % f)
except Exception as e:
print('CoreML export failure: %s' % e)
if 'coreml' in opt.out_formats:
try:
import coremltools as ct

print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
# convert model from torchscript and apply pixel scaling as per detect.py
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
bits, mode = (8, 'kmeans_lut') if opt.int8 else (16, 'linear') if opt.fp16 else (32, None)
if bits < 32:
if sys.platform.lower() == 'darwin': # quantization only supported on macOS
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
else:
print('quantization only supported on macOS, skipping...')

f = opt.weights.replace('.pt', '.mlmodel') # filename
ct_model.save(f)
print('CoreML export success, saved as %s' % f)
except Exception as e:
print('CoreML export failure: %s' % e)

# TorchScript-Lite export
try:
print('\nStarting TorchScript-Lite export with torch %s...' % torch.__version__)
f = opt.weights.replace('.pt', '.torchscript.ptl') # filename
tsl = torch.jit.trace(model, img, strict=False)
tsl = optimize_for_mobile(tsl)
tsl._save_for_lite_interpreter(f)
print('TorchScript-Lite export success, saved as %s' % f)
except Exception as e:
print('TorchScript-Lite export failure: %s' % e)
if 'torchscript-lite' in opt.out_formats:
try:
print('\nStarting TorchScript-Lite export with torch %s...' % torch.__version__)
f = opt.weights.replace('.pt', '.torchscript.ptl') # filename
tsl = torch.jit.trace(model, img, strict=False)
tsl = optimize_for_mobile(tsl)
tsl._save_for_lite_interpreter(f)
print('TorchScript-Lite export success, saved as %s' % f)
except Exception as e:
print('TorchScript-Lite export failure: %s' % e)

# ONNX export
try:
import onnx

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename
model.eval()
output_names = ['classes', 'boxes'] if y is None else ['output']
dynamic_axes = None
if opt.dynamic:
dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}}
if opt.dynamic_batch:
opt.batch_size = 'batch'
dynamic_axes = {
'images': {
0: 'batch',
}, }
if opt.end2end and opt.max_wh is None:
output_axes = {
'num_dets': {0: 'batch'},
'det_boxes': {0: 'batch'},
'det_scores': {0: 'batch'},
'det_classes': {0: 'batch'},
}
else:
output_axes = {
'output': {0: 'batch'},
}
dynamic_axes.update(output_axes)
if opt.grid:
if opt.end2end:
print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime')
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device,len(labels))
if 'onnx' in opt.out_formats:
try:
import onnx

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename
model.eval()
output_names = ['classes', 'boxes'] if y is None else ['output']
dynamic_axes = None
if opt.dynamic:
dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}}
if opt.dynamic_batch:
opt.batch_size = 'batch'
dynamic_axes = {
'images': {
0: 'batch',
}, }
if opt.end2end and opt.max_wh is None:
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4,
opt.batch_size, opt.topk_all, opt.batch_size, opt.topk_all]
output_axes = {
'num_dets': {0: 'batch'},
'det_boxes': {0: 'batch'},
'det_scores': {0: 'batch'},
'det_classes': {0: 'batch'},
}
else:
output_axes = {
'output': {0: 'batch'},
}
dynamic_axes.update(output_axes)
if opt.grid:
if opt.end2end:
print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime')
model = End2End(model, opt.topk_all, opt.iou_thres, opt.conf_thres, opt.max_wh, device, len(labels))
if opt.end2end and opt.max_wh is None:
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4,
opt.batch_size, opt.topk_all, opt.batch_size, opt.topk_all]
else:
output_names = ['output']
else:
output_names = ['output']
else:
model.model[-1].concat = True

torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=output_names,
dynamic_axes=dynamic_axes)

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model

if opt.end2end and opt.max_wh is None:
for i in onnx_model.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))

# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model

# # Metadata
# d = {'stride': int(max(model.stride))}
# for k, v in d.items():
# meta = onnx_model.metadata_props.add()
# meta.key, meta.value = k, str(v)
# onnx.save(onnx_model, f)

if opt.simplify:
try:
import onnxsim

print('\nStarting to simplify ONNX...')
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
print(f'Simplifier failure: {e}')

# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
onnx.save(onnx_model,f)
print('ONNX export success, saved as %s' % f)

if opt.include_nms:
print('Registering NMS plugin for ONNX...')
mo = RegisterNMS(f)
mo.register_nms()
mo.save(f)

except Exception as e:
print('ONNX export failure: %s' % e)
model.model[-1].concat = True

torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=output_names,
dynamic_axes=dynamic_axes)

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model

if opt.end2end and opt.max_wh is None:
for i in onnx_model.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))

# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model

# # Metadata
# d = {'stride': int(max(model.stride))}
# for k, v in d.items():
# meta = onnx_model.metadata_props.add()
# meta.key, meta.value = k, str(v)
# onnx.save(onnx_model, f)

if opt.simplify:
try:
import onnxsim

print('\nStarting to simplify ONNX...')
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
print(f'Simplifier failure: {e}')

# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
onnx.save(onnx_model,f)
print('ONNX export success, saved as %s' % f)

if opt.include_nms:
print('Registering NMS plugin for ONNX...')
mo = RegisterNMS(f)
mo.register_nms()
mo.save(f)

except Exception as e:
print('ONNX export failure: %s' % e)

# Finish
print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))