-
Notifications
You must be signed in to change notification settings - Fork 771
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ATLAS] Refact TNN ATLAS for Multiple Model
- Loading branch information
Showing
21 changed files
with
1,219 additions
and
1,142 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,63 @@ | ||
// Copyright 2019 Tencent. All Rights Reserved | ||
// Tencent is pleased to support the open source community by making TNN available. | ||
// | ||
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#ifndef TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_ | ||
#define TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_ | ||
|
||
#include <climits> | ||
#include <map> | ||
#include <memory> | ||
#include <string> | ||
#include <unordered_set> | ||
|
||
#include "acl/acl.h" | ||
#include "tnn/core/blob.h" | ||
#include "tnn/core/macro.h" | ||
|
||
/////////////////////// | ||
#include <iostream> | ||
/////////////////////// | ||
|
||
namespace TNN_NS { | ||
|
||
enum ImageTypeT { | ||
IMAGE_TYPE_RAW = -1, | ||
IMAGE_TYPE_NV12 = 0, | ||
IMAGE_TYPE_JPEG, | ||
IMAGE_TYPE_PNG, | ||
IMAGE_TYPE_BMP, | ||
IMAGE_TYPE_TIFF, | ||
IMAGE_TYPE_VIDEO = 100 | ||
enum class AtlasOmModelDynamicMode { | ||
Static = 0, | ||
DynamicBatch = 1, | ||
DynamicHW = 2, | ||
GenericDynamic = 3, // New Dynamic Mode, convert by input_shape_range or input_shape without dynamic dim/hw specified. | ||
}; | ||
|
||
struct AtlasModelConfig { | ||
std::string om_str = ""; | ||
bool is_path = false; | ||
}; | ||
struct AtlasOMModelInfo { | ||
aclmdlDesc* model_desc = nullptr; | ||
uint32_t model_id = INT_MAX; | ||
aclmdlDataset* input_dataset = nullptr; | ||
aclrtContext aclrt_context = nullptr; | ||
|
||
struct DimInfo { | ||
uint32_t batch = 0; | ||
uint32_t channel = 0; | ||
uint32_t height = 0; | ||
uint32_t width = 0; | ||
}; | ||
size_t memory_size = 0; | ||
size_t weight_size = 0; | ||
|
||
struct AtlasCommandQueue { | ||
void* context; | ||
void* stream; | ||
}; | ||
// Dynamic Input | ||
AtlasOmModelDynamicMode dynamic_mode; | ||
std::unordered_set<std::string> generic_dynamic_input_names; | ||
|
||
struct AtlasModelInfo { | ||
aclmdlDesc* model_desc = nullptr; | ||
uint32_t model_id = 0; | ||
aclmdlDataset* input_dataset = nullptr; | ||
bool has_aipp = false; | ||
aclAippInputFormat aipp_input_format = ACL_AIPP_RESERVED; | ||
// AIPP Input | ||
std::map<std::string, aclAippInputFormat> aipp_input_format_map; | ||
}; | ||
|
||
extern std::map<Blob*, std::shared_ptr<AtlasOMModelInfo>> global_blob_om_model_info_map; | ||
extern std::map<aclrtStream, aclrtContext> global_stream_context_map; | ||
|
||
} // namespace TNN_NS | ||
|
||
#endif // TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
// Tencent is pleased to support the open source community by making TNN available. | ||
// | ||
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#include "tnn/device/atlas/atlas_context.h" | ||
|
||
namespace TNN_NS { | ||
|
||
AtlasContext::~AtlasContext() { | ||
// Aclrt Stream is created and maintained by AtlasNetwork | ||
// Do not Destroy aclrtStream HERE. | ||
//if (this->aclrt_stream_ != nullptr) { | ||
// ret = aclrtDestroyStream(this->aclrt_stream_); | ||
// this->aclrt_stream_ = nullptr; | ||
//} | ||
} | ||
|
||
Status AtlasContext::Setup(int device_id) { | ||
this->device_id_ = device_id; | ||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::LoadLibrary(std::vector<std::string> path) { | ||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::GetCommandQueue(void** command_queue) { | ||
// Reshape Model For different Model Types | ||
if (this->model_type_ == MODEL_TYPE_TORCHSCRIPT) { | ||
LOGE("Fail to GetCommandQueue, MODEL_TYPE_TORCHSCRIPT not supported YET.\n"); | ||
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, MODEL_TYPE_TORCHSCRIPT not supported YET"); | ||
} else if (this->model_type_ == MODEL_TYPE_TNN || this->model_type_ == MODEL_TYPE_RAPIDNET) { | ||
LOGE("Fail to GetCommandQueue, MODEL_TYPE_TNN not supported YET.\n"); | ||
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, MODEL_TYPE_TNN not supported YET"); | ||
} else if (this->model_type_ == MODEL_TYPE_ATLAS) { | ||
*command_queue = this->aclrt_stream_; | ||
} else { | ||
LOGE("Fail to GetCommandQueue, model type not supported.\n"); | ||
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, model type not supported"); | ||
} | ||
|
||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::SetCommandQueue(void* command_queue) { | ||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::ShareCommandQueue(Context* context) { | ||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::OnInstanceForwardBegin() { | ||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::OnInstanceForwardEnd() { | ||
return TNN_OK; | ||
} | ||
|
||
Status AtlasContext::Synchronize() { | ||
if (model_type_ == MODEL_TYPE_TNN || model_type_ == MODEL_TYPE_RAPIDNET || | ||
model_type_ == MODEL_TYPE_ATLAS) { | ||
aclError acl_ret = aclrtSynchronizeStream(this->aclrt_stream_); | ||
if (acl_ret != ACL_ERROR_NONE) { | ||
LOGE("before forward synchronize stream failed\n"); | ||
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "before forward synchronize stream failed"); | ||
} | ||
} | ||
return TNN_OK; | ||
} | ||
|
||
aclrtStream& AtlasContext::GetAclrtStream() { | ||
return this->aclrt_stream_; | ||
} | ||
|
||
void AtlasContext::SetAclrtStream(const aclrtStream& stream) { | ||
this->aclrt_stream_ = stream; | ||
} | ||
|
||
Status AtlasContext::CreateAclrtStream() { | ||
// Create aclrt Stream | ||
aclError acl_ret = aclrtCreateStream(&aclrt_stream_); | ||
if (acl_ret != ACL_ERROR_NONE) { | ||
LOGE("acl create stream failed (acl error code: %d)\n", acl_ret); | ||
} | ||
return TNN_OK; | ||
} | ||
|
||
ModelType& AtlasContext::GetModelType() { | ||
return this->model_type_; | ||
} | ||
|
||
void AtlasContext::SetModelType(ModelType model_type) { | ||
this->model_type_ = model_type; | ||
} | ||
|
||
|
||
|
||
} // namespace TNN_NS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// Tencent is pleased to support the open source community by making TNN available. | ||
// | ||
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#ifndef TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_ | ||
#define TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_ | ||
|
||
#include "tnn/core/context.h" | ||
#include "tnn/device/atlas/atlas_common_types.h" | ||
#include "tnn/interpreter/raw_buffer.h" | ||
|
||
namespace TNN_NS { | ||
|
||
class AtlasContext : public Context { | ||
public: | ||
// @brief deconstructor | ||
~AtlasContext(); | ||
|
||
// @brief setup with specified device id | ||
Status Setup(int device_id); | ||
|
||
// @brief load library | ||
virtual Status LoadLibrary(std::vector<std::string> path) override; | ||
|
||
// @brief get tnn command queue | ||
// @param command_queue device command queue for forward | ||
virtual Status GetCommandQueue(void** command_queue) override; | ||
|
||
// @brief set tnn command queue | ||
// @param command_queue device command queue for forward | ||
virtual Status SetCommandQueue(void* command_queue) override; | ||
|
||
// @brief share tnn command queue to another context | ||
virtual Status ShareCommandQueue(Context* context); | ||
|
||
// @brief before instance forward | ||
virtual Status OnInstanceForwardBegin() override; | ||
|
||
// @brief after instance forward | ||
virtual Status OnInstanceForwardEnd() override; | ||
|
||
// @brief wait for jobs in the current context to complete | ||
virtual Status Synchronize() override; | ||
|
||
// @brief get Atlas stream | ||
aclrtStream& GetAclrtStream(); | ||
|
||
// @brief set Atlas stream | ||
void SetAclrtStream(const aclrtStream& stream); | ||
|
||
// @brief create Atlas stream | ||
Status CreateAclrtStream(); | ||
|
||
// @brief get ModelType | ||
ModelType& GetModelType(); | ||
|
||
// @brief set ModelType | ||
void SetModelType(ModelType model_type); | ||
|
||
private: | ||
ModelType model_type_; | ||
int device_id_ = INT_MAX; | ||
|
||
// ACL Runtime Related | ||
aclrtStream aclrt_stream_ = nullptr; | ||
}; | ||
|
||
} // namespace TNN_NS; | ||
|
||
#endif // TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.