From bdfcbdb59a5483b6514001629869594a5f61ce4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Fatih=20C=C4=B1r=C4=B1t?= Date: Tue, 20 Dec 2022 00:56:23 +0300 Subject: [PATCH 1/2] fix(tensorrt): update tensorrt code of traffic_light_classifier MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: M. Fatih Cırıt --- .../utils/trt_common.cpp | 16 ++++++++++------ .../utils/trt_common.hpp | 2 -- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/perception/traffic_light_classifier/utils/trt_common.cpp b/perception/traffic_light_classifier/utils/trt_common.cpp index 4aae3e0ece5e1..2cba60e0d7b6c 100644 --- a/perception/traffic_light_classifier/utils/trt_common.cpp +++ b/perception/traffic_light_classifier/utils/trt_common.cpp @@ -75,8 +75,16 @@ void TrtCommon::setup() } context_ = UniquePtr(engine_->createExecutionContext()); - input_dims_ = engine_->getBindingDimensions(getInputBindingIndex()); - output_dims_ = engine_->getBindingDimensions(getOutputBindingIndex()); + +#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500 + input_dims_ = engine_->getTensorShape(input_name_.c_str()); + output_dims_ = engine_->getTensorShape(output_name_.c_str()); +#else + // Deprecated since 8.5 + input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str()); + output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str()); +#endif + is_initialized_ = true; } @@ -155,8 +163,4 @@ int TrtCommon::getNumOutput() output_dims_.d, output_dims_.d + output_dims_.nbDims, 1, std::multiplies()); } -int TrtCommon::getInputBindingIndex() { return engine_->getBindingIndex(input_name_.c_str()); } - -int TrtCommon::getOutputBindingIndex() { return engine_->getBindingIndex(output_name_.c_str()); } - } // namespace Tn diff --git a/perception/traffic_light_classifier/utils/trt_common.hpp b/perception/traffic_light_classifier/utils/trt_common.hpp index 9577dfd2262a6..7fc3d3b3e46d9 100644 --- a/perception/traffic_light_classifier/utils/trt_common.hpp +++ b/perception/traffic_light_classifier/utils/trt_common.hpp @@ -118,8 +118,6 @@ class TrtCommon bool isInitialized(); int getNumInput(); int getNumOutput(); - int getInputBindingIndex(); - int getOutputBindingIndex(); UniquePtr context_; From 825ef165ae0abbf8ddbc303ea7a65e04a75055b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Fatih=20C=C4=B1r=C4=B1t?= Date: Tue, 20 Dec 2022 10:39:48 +0300 Subject: [PATCH 2/2] missing parentheses :( MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: M. Fatih Cırıt --- perception/traffic_light_classifier/utils/trt_common.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/perception/traffic_light_classifier/utils/trt_common.cpp b/perception/traffic_light_classifier/utils/trt_common.cpp index 2cba60e0d7b6c..adb2fbe037a31 100644 --- a/perception/traffic_light_classifier/utils/trt_common.cpp +++ b/perception/traffic_light_classifier/utils/trt_common.cpp @@ -81,8 +81,8 @@ void TrtCommon::setup() output_dims_ = engine_->getTensorShape(output_name_.c_str()); #else // Deprecated since 8.5 - input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str()); - output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str()); + input_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(input_name_.c_str())); + output_dims_ = engine_->getBindingDimensions(engine_->getBindingIndex(output_name_.c_str())); #endif is_initialized_ = true;