From baff1f6175f6bda40a2cc3f52b14f6a380178e9a Mon Sep 17 00:00:00 2001 From: Dong Xu Date: Tue, 21 May 2024 20:48:06 +0800 Subject: [PATCH] [ATLAS] Fix TNN Atlas Bugs --- source/tnn/device/atlas/atlas_context.cc | 11 ++++++----- source/tnn/device/atlas/atlas_context.h | 8 +++++--- source/tnn/device/atlas/atlas_device.cc | 9 +-------- source/tnn/device/atlas/atlas_network.cc | 12 +++++++++++- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/source/tnn/device/atlas/atlas_context.cc b/source/tnn/device/atlas/atlas_context.cc index 145e5deec..efa100211 100644 --- a/source/tnn/device/atlas/atlas_context.cc +++ b/source/tnn/device/atlas/atlas_context.cc @@ -25,11 +25,6 @@ AtlasContext::~AtlasContext() { //} } -Status AtlasContext::Setup(int device_id) { - this->device_id_ = device_id; - return TNN_OK; -} - Status AtlasContext::LoadLibrary(std::vector path) { return TNN_OK; } @@ -105,6 +100,12 @@ void AtlasContext::SetModelType(ModelType model_type) { this->model_type_ = model_type; } +void AtlasContext::SetDeviceId(int device_id) { + this->device_id_ = device_id; +} +int AtlasContext::GetDeviceId() { + return this->device_id_; +} } // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_context.h b/source/tnn/device/atlas/atlas_context.h index f57cb7e0a..346362a96 100644 --- a/source/tnn/device/atlas/atlas_context.h +++ b/source/tnn/device/atlas/atlas_context.h @@ -26,9 +26,6 @@ class AtlasContext : public Context { // @brief deconstructor ~AtlasContext(); - // @brief setup with specified device id - Status Setup(int device_id); - // @brief load library virtual Status LoadLibrary(std::vector path) override; @@ -66,6 +63,11 @@ class AtlasContext : public Context { // @brief set ModelType void SetModelType(ModelType model_type); + + // @brief set specific device id + void SetDeviceId(int device_id); + + int GetDeviceId(); private: ModelType model_type_; diff --git a/source/tnn/device/atlas/atlas_device.cc b/source/tnn/device/atlas/atlas_device.cc index 4f57e6d46..e70d67ac2 100644 --- a/source/tnn/device/atlas/atlas_device.cc +++ b/source/tnn/device/atlas/atlas_device.cc @@ -110,14 +110,7 @@ AbstractLayerAcc* AtlasDevice::CreateLayerAcc(LayerType type) { Context* AtlasDevice::CreateContext(int device_id) { auto context = new AtlasContext(); - - Status ret = context->Setup(device_id); - if (ret != TNN_OK) { - LOGE("Cuda context setup failed."); - delete context; - return NULL; - } - + context->SetDeviceId(device_id); return context; } diff --git a/source/tnn/device/atlas/atlas_network.cc b/source/tnn/device/atlas/atlas_network.cc index c23fc84ae..dc5c078aa 100644 --- a/source/tnn/device/atlas/atlas_network.cc +++ b/source/tnn/device/atlas/atlas_network.cc @@ -69,8 +69,9 @@ AtlasNetwork::~AtlasNetwork() { if (acl_ret != ACL_ERROR_NONE) { LOGE("unload model failed, modelId is %u\n", this->om_model_info_->model_id); } + this->om_model_info_->model_id = INT_MAX; } - + if (nullptr != this->om_model_info_->model_desc) { (void)aclmdlDestroyDesc(this->om_model_info_->model_desc); this->om_model_info_->model_desc = nullptr; @@ -115,6 +116,15 @@ AtlasNetwork::~AtlasNetwork() { this->om_model_weight_ptr_ = nullptr; this->om_model_info_->weight_size = 0; } + + // Destroy aclrt Device() + if (tnn_atlas_context->GetDeviceId() != INT_MAX) { + LOGD("Reset aclrt Device.\n"); + acl_ret = aclrtResetDevice(tnn_atlas_context->GetDeviceId()); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("TNN ATLAS Network: aclrtResetDevice() failed\n"); + } + } } // Call DeInit() of DefaultNetwork