diff --git a/onnx_tensorrt/backend.py b/onnx_tensorrt/backend.py index 6f2f9f6e..4f4fb031 100644 --- a/onnx_tensorrt/backend.py +++ b/onnx_tensorrt/backend.py @@ -32,7 +32,7 @@ def count_trailing_ones(vals): class TensorRTBackendRep(BackendRep): def __init__(self, model, device, - max_workspace_size=None, serialize_engine=False, verbose=False, **kwargs): + max_workspace_size=None, serialize_engine=False, verbose=False, fp16=False, **kwargs): if not isinstance(device, Device): device = Device(device) self._set_device(device) @@ -44,8 +44,12 @@ def __init__(self, model, device, self.shape_tensor_inputs = [] self.serialize_engine = serialize_engine self.verbose = verbose + self.fp16 = fp16 self.dynamic = False + if self.fp16: + self.config.set_flag(trt.BuilderFlag.FP16) + if self.verbose: print(f'\nRunning {model.graph.name}...') TRT_LOGGER.min_severity = trt.Logger.VERBOSE