Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorRT IErrorRecorder Implementation #54

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorrt_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import comfy.model_management

import tensorrt as trt
from .tensorrt_error_recorder import TrTErrorRecorder
import folder_paths
from tqdm import tqdm

Expand Down Expand Up @@ -284,6 +285,7 @@ def forward(self, x, timesteps, context, y=None):
# TRT conversion starts here
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
builder.error_recorder = TrTErrorRecorder()

network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
Expand Down
47 changes: 47 additions & 0 deletions tensorrt_error_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
import threading

import tensorrt as trt

# TensorRT Python API Docs: https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/ErrorRecorder.html
#
# NOTE: TensorRT does not attempt to marshall errors across threads, so IErrorRecorder implementations must be thread-safe to allow for the possibility of errors occurring on multiple threads.
class TrTErrorRecorder(trt.IErrorRecorder):
    """Thread-safe implementation of TensorRT's IErrorRecorder interface.

    Accumulates every error TensorRT reports so callers can inspect them
    after an API call instead of scraping the logger output. All access to
    the error list is guarded by a lock because TensorRT may report errors
    from multiple threads (see the interface docs).
    """

    def __init__(self):
        # Guards error_list; trt.Lock is non-reentrant, so each public
        # method takes it exactly once.
        self.lock = threading.Lock()
        self.error_list = []
        super().__init__()

    def clear(self):
        """Discard all recorded errors."""
        with self.lock:
            self.error_list = []

    def _entry(self, error_index):
        """Return the recorded entry at error_index.

        Raises IndexError when error_index is out of range. Takes the lock
        itself, so callers must NOT already hold it.
        """
        with self.lock:
            if error_index < 0 or error_index >= len(self.error_list):
                raise IndexError(f'Invalid error index "{error_index}"')
            return self.error_list[error_index]

    def get_error_code(self, error_index):
        """Return the trt.ErrorCode of the error at error_index."""
        return self._entry(error_index)['error_code']

    def get_error_desc(self, error_index):
        """Return the human-readable description of the error at error_index."""
        return self._entry(error_index)['error_desc']

    def has_overflowed(self):
        # The list is unbounded, so no recorded error is ever dropped.
        return False

    def num_errors(self):
        """Return the number of errors recorded since the last clear()."""
        with self.lock:
            return len(self.error_list)

    def report_error(self, error_code, error_desc):
        """Record an error reported by TensorRT; called by TensorRT itself."""
        # Lazy %-style args: formatting is skipped when ERROR logging is disabled.
        logging.error(
            "TensorRT has encountered an error. ErrorCode: %s. ErrorDesc: %s",
            error_code, error_desc,
        )

        with self.lock:
            self.error_list.append({'error_code': error_code, 'error_desc': error_desc})

        # TODO Future: return True for errors we consider 'fatal', which hints to TensorRT to stop execution.
        return False
19 changes: 19 additions & 0 deletions tensorrt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
[os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"})

import tensorrt as trt
from .tensorrt_error_recorder import TrTErrorRecorder

trt.init_libnvinfer_plugins(None, "")

logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)
runtime.error_recorder = TrTErrorRecorder()

# Is there a function that already exists for this?
def trt_datatype_to_torch(datatype):
Expand All @@ -35,11 +37,28 @@ def trt_datatype_to_torch(datatype):
elif datatype == trt.bfloat16:
return torch.bfloat16

def check_for_trt_errors(runtime):
    """Raise RuntimeError if runtime's error recorder holds any errors.

    Drains (clears) the recorder before raising so a later call reports
    only new errors. Returns None when the recorder is empty.

    NOTE: this helper is called after deserialize_cuda_engine() AND after
    create_execution_context(), so the message must not claim
    deserialization specifically (the previous wording did).
    """
    recorder = runtime.error_recorder
    num_errors = recorder.num_errors()
    if num_errors == 0:
        return

    descriptions = [recorder.get_error_desc(i) for i in range(num_errors)]
    recorder.clear()
    raise RuntimeError(
        f'TensorRT reported {num_errors} error(s):\n' + "\n".join(descriptions)
    )
maedtb marked this conversation as resolved.
Show resolved Hide resolved

class TrTUnet:
def __init__(self, engine_path):
    """Load a serialized TensorRT engine from engine_path and create its
    execution context, surfacing any recorded TensorRT errors as
    RuntimeError via check_for_trt_errors()."""
    with open(engine_path, "rb") as engine_file:
        serialized_engine = engine_file.read()

    self.engine = runtime.deserialize_cuda_engine(serialized_engine)
    check_for_trt_errors(runtime)

    self.context = self.engine.create_execution_context()
    check_for_trt_errors(runtime)

    # NOTE(review): fp16 appears to be assumed for all engines — confirm.
    self.dtype = torch.float16

def set_bindings_shape(self, inputs, split_batch):
Expand Down