Skip to content

Commit

Permalink
Remove workspace parameter from generate_trt_engine.py (nvidia-holosc…
Browse files Browse the repository at this point in the history
…an#532)

* Remove workspace parameter as it is no longer supported in tensorrt:24.08

Signed-off-by: Victor Chang <[email protected]>

* Fix lint error

Signed-off-by: Victor Chang <[email protected]>

---------

Signed-off-by: Victor Chang <[email protected]>
  • Loading branch information
mocsharp authored Oct 11, 2024
1 parent 3f35dd7 commit e7eefc8
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions utilities/generate_trt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,19 @@ def get_gpu_info():
return gpu_info_list


def convert_onnx(input_file, output_file, workspace_size, fp16_enabled):
def convert_onnx(input_file, output_file, fp16_enabled):
"""
Convert ONNX model to TensorRT engine using trtexec command.
Args:
input_file (str): Path to the input ONNX file.
output_file (str): Path to the output engine file.
workspace_size (int): Workspace size for TensorRT engine (in MiB).
fp16_enabled (bool): Flag indicating whether to enable FP16 mode.
Returns:
None
"""
trtexec_cmd = (
f"trtexec --onnx='{input_file}' --workspace={workspace_size} --saveEngine='{output_file}'"
)
trtexec_cmd = f"trtexec --onnx='{input_file}' --saveEngine='{output_file}'"
if fp16_enabled:
trtexec_cmd += " --fp16"
status = os.system(trtexec_cmd)
Expand All @@ -80,7 +77,6 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="Input file path")
parser.add_argument("--output", help="Output file path")
parser.add_argument("--workspace", type=int, default=2048, help="Workspace size in MiB")
parser.add_argument("--fp16", default=False, action="store_true", help="Enable FP16 mode")
parser.add_argument(
"--force", default=False, action="store_true", help="Force overwrite existing output file"
Expand All @@ -103,10 +99,9 @@ def main():
file=sys.stderr,
)
return
workspace_size = args.workspace
fp16_enabled = args.fp16

result_code = convert_onnx(input_file, output_file, workspace_size, fp16_enabled)
result_code = convert_onnx(input_file, output_file, fp16_enabled)
if result_code != 0:
print("TensorRT engine generation failed.", file=sys.stderr)
print("Exit code is ", result_code)
Expand Down

0 comments on commit e7eefc8

Please sign in to comment.