Skip to content

Commit

Permalink
ADD: support for fp16 in the Python inference backend
Browse files Browse the repository at this point in the history
Signed-off-by: Sayan Protasov <[email protected]>
  • Loading branch information
Sayan Protasov committed Nov 18, 2021
1 parent b5b9be2 commit 67f35a2
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion onnx_tensorrt/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def count_trailing_ones(vals):

class TensorRTBackendRep(BackendRep):
def __init__(self, model, device,
max_workspace_size=None, serialize_engine=False, verbose=False, **kwargs):
max_workspace_size=None, serialize_engine=False, verbose=False, fp16=False, **kwargs):
if not isinstance(device, Device):
device = Device(device)
self._set_device(device)
Expand All @@ -44,8 +44,12 @@ def __init__(self, model, device,
self.shape_tensor_inputs = []
self.serialize_engine = serialize_engine
self.verbose = verbose
self.fp16 = fp16
self.dynamic = False

if self.fp16:
self.config.set_flag(trt.BuilderFlag.FP16)

if self.verbose:
print(f'\nRunning {model.graph.name}...')
TRT_LOGGER.min_severity = trt.Logger.VERBOSE
Expand Down

0 comments on commit 67f35a2

Please sign in to comment.