Skip to content

Commit

Permalink
Fix issue with black images.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous authored and yusing committed Jun 17, 2024
1 parent 0f2d14e commit 505b204
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tensorrt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(self, engine_path):
self.engine = runtime.deserialize_cuda_engine(f.read())
self.context = self.engine.create_execution_context()
self.dtype = torch.float16
self.stream = torch.cuda.Stream()

def set_bindings_shape(self, inputs, split_batch):
for k in inputs:
Expand Down Expand Up @@ -91,12 +90,13 @@ def __call__(self, x, timesteps, context, y=None, control=None, transformer_opti
dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype(output_binding_name)))
model_inputs_converted[output_binding_name] = out

stream = torch.cuda.default_stream(x.device)
for i in range(curr_split_batch):
for k in model_inputs_converted:
x = model_inputs_converted[k]
self.context.set_tensor_address(k, x[(x.shape[0] // curr_split_batch) * i:].data_ptr())
self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
self.stream.synchronize()
self.context.execute_async_v3(stream_handle=stream.cuda_stream)
stream.synchronize()
return out

def load_state_dict(self, sd, strict=False):
Expand Down

0 comments on commit 505b204

Please sign in to comment.