Skip to content

Commit

Permalink
[ATLAS] Refact TNN ATLAS for Multiple Model
Browse files Browse the repository at this point in the history
  • Loading branch information
doxutx committed Apr 12, 2024
1 parent 7cea315 commit 3bf8a7b
Show file tree
Hide file tree
Showing 21 changed files with 1,219 additions and 1,142 deletions.
112 changes: 61 additions & 51 deletions source/tnn/device/atlas/atlas_blob_converter.cc

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions source/tnn/device/atlas/atlas_blob_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class AtlasBlobConverterAcc : public BlobConverterAcc {
virtual Status ConvertFromMatAsync(Mat& mat, MatConvertParam param, void* command_queue = NULL);

private:
Status ConvertFromMatAsyncWithoutAipp(Mat& mat, MatConvertParam param, AtlasCommandQueue* atlas_cmd_queue);
Status ConvertFromMatAsyncWithStaticAipp(Mat& mat, MatConvertParam param, AtlasCommandQueue* atlas_cmd_queue);
Status ConvertFromMatAsyncWithDynamicAipp(Mat& mat, MatConvertParam param, AtlasCommandQueue* atlas_cmd_queue);
Status ConvertFromMatAsyncWithoutAipp(Mat& mat, MatConvertParam param, const aclrtStream& aclrt_stream);
Status ConvertFromMatAsyncWithStaticAipp(Mat& mat, MatConvertParam param, const aclrtStream& aclrt_stream);
Status ConvertFromMatAsyncWithDynamicAipp(Mat& mat, MatConvertParam param, const aclrtStream& aclrt_stream);

bool NeedDoScaleBias(MatConvertParam& param);
Status AtlasMemoryCopyAsync(void* dst, void* src, DeviceType mat_device_type, int bytes, void* stream,
Expand All @@ -55,7 +55,7 @@ class AtlasBlobConverterAcc : public BlobConverterAcc {
AippType aipp_type_ = AIPP_NONE;
int aipp_mat_batchsize_ = 0;
size_t dynamic_aipp_index_ = 0;
AtlasModelInfo model_info_;
std::shared_ptr<AtlasOMModelInfo> om_model_info_;
};

} // namespace TNN_NS
Expand Down
5 changes: 0 additions & 5 deletions source/tnn/device/atlas/atlas_common_types.cc

This file was deleted.

68 changes: 39 additions & 29 deletions source/tnn/device/atlas/atlas_common_types.h
Original file line number Diff line number Diff line change
@@ -1,53 +1,63 @@
// Copyright 2019 Tencent. All Rights Reserved
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_
#define TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_

#include <climits>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>

#include "acl/acl.h"
#include "tnn/core/blob.h"
#include "tnn/core/macro.h"

///////////////////////
#include <iostream>
///////////////////////

namespace TNN_NS {

enum ImageTypeT {
IMAGE_TYPE_RAW = -1,
IMAGE_TYPE_NV12 = 0,
IMAGE_TYPE_JPEG,
IMAGE_TYPE_PNG,
IMAGE_TYPE_BMP,
IMAGE_TYPE_TIFF,
IMAGE_TYPE_VIDEO = 100
enum class AtlasOmModelDynamicMode {
Static = 0,
DynamicBatch = 1,
DynamicHW = 2,
GenericDynamic = 3, // New Dynamic Mode, convert by input_shape_range or input_shape without dynamic dim/hw specified.
};

struct AtlasModelConfig {
std::string om_str = "";
bool is_path = false;
};
struct AtlasOMModelInfo {
aclmdlDesc* model_desc = nullptr;
uint32_t model_id = INT_MAX;
aclmdlDataset* input_dataset = nullptr;
aclrtContext aclrt_context = nullptr;

struct DimInfo {
uint32_t batch = 0;
uint32_t channel = 0;
uint32_t height = 0;
uint32_t width = 0;
};
size_t memory_size = 0;
size_t weight_size = 0;

struct AtlasCommandQueue {
void* context;
void* stream;
};
// Dynamic Input
AtlasOmModelDynamicMode dynamic_mode;
std::unordered_set<std::string> generic_dynamic_input_names;

struct AtlasModelInfo {
aclmdlDesc* model_desc = nullptr;
uint32_t model_id = 0;
aclmdlDataset* input_dataset = nullptr;
bool has_aipp = false;
aclAippInputFormat aipp_input_format = ACL_AIPP_RESERVED;
// AIPP Input
std::map<std::string, aclAippInputFormat> aipp_input_format_map;
};

extern std::map<Blob*, std::shared_ptr<AtlasOMModelInfo>> global_blob_om_model_info_map;
extern std::map<aclrtStream, aclrtContext> global_stream_context_map;

} // namespace TNN_NS

#endif // TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_
110 changes: 110 additions & 0 deletions source/tnn/device/atlas/atlas_context.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/device/atlas/atlas_context.h"

namespace TNN_NS {

AtlasContext::~AtlasContext() {
// Aclrt Stream is created and maintained by AtlasNetwork
// Do not Destroy aclrtStream HERE.
//if (this->aclrt_stream_ != nullptr) {
// ret = aclrtDestroyStream(this->aclrt_stream_);
// this->aclrt_stream_ = nullptr;
//}
}

Status AtlasContext::Setup(int device_id) {
this->device_id_ = device_id;
return TNN_OK;
}

Status AtlasContext::LoadLibrary(std::vector<std::string> path) {
return TNN_OK;
}

Status AtlasContext::GetCommandQueue(void** command_queue) {
// Reshape Model For different Model Types
if (this->model_type_ == MODEL_TYPE_TORCHSCRIPT) {
LOGE("Fail to GetCommandQueue, MODEL_TYPE_TORCHSCRIPT not supported YET.\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, MODEL_TYPE_TORCHSCRIPT not supported YET");
} else if (this->model_type_ == MODEL_TYPE_TNN || this->model_type_ == MODEL_TYPE_RAPIDNET) {
LOGE("Fail to GetCommandQueue, MODEL_TYPE_TNN not supported YET.\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, MODEL_TYPE_TNN not supported YET");
} else if (this->model_type_ == MODEL_TYPE_ATLAS) {
*command_queue = this->aclrt_stream_;
} else {
LOGE("Fail to GetCommandQueue, model type not supported.\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, model type not supported");
}

return TNN_OK;
}

Status AtlasContext::SetCommandQueue(void* command_queue) {
return TNN_OK;
}

Status AtlasContext::ShareCommandQueue(Context* context) {
return TNN_OK;
}

Status AtlasContext::OnInstanceForwardBegin() {
return TNN_OK;
}

Status AtlasContext::OnInstanceForwardEnd() {
return TNN_OK;
}

Status AtlasContext::Synchronize() {
if (model_type_ == MODEL_TYPE_TNN || model_type_ == MODEL_TYPE_RAPIDNET ||
model_type_ == MODEL_TYPE_ATLAS) {
aclError acl_ret = aclrtSynchronizeStream(this->aclrt_stream_);
if (acl_ret != ACL_ERROR_NONE) {
LOGE("before forward synchronize stream failed\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "before forward synchronize stream failed");
}
}
return TNN_OK;
}

aclrtStream& AtlasContext::GetAclrtStream() {
return this->aclrt_stream_;
}

void AtlasContext::SetAclrtStream(const aclrtStream& stream) {
this->aclrt_stream_ = stream;
}

Status AtlasContext::CreateAclrtStream() {
// Create aclrt Stream
aclError acl_ret = aclrtCreateStream(&aclrt_stream_);
if (acl_ret != ACL_ERROR_NONE) {
LOGE("acl create stream failed (acl error code: %d)\n", acl_ret);
}
return TNN_OK;
}

ModelType& AtlasContext::GetModelType() {
return this->model_type_;
}

void AtlasContext::SetModelType(ModelType model_type) {
this->model_type_ = model_type;
}



} // namespace TNN_NS
80 changes: 80 additions & 0 deletions source/tnn/device/atlas/atlas_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_
#define TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_

#include "tnn/core/context.h"
#include "tnn/device/atlas/atlas_common_types.h"
#include "tnn/interpreter/raw_buffer.h"

namespace TNN_NS {

class AtlasContext : public Context {
public:
// @brief deconstructor
~AtlasContext();

// @brief setup with specified device id
Status Setup(int device_id);

// @brief load library
virtual Status LoadLibrary(std::vector<std::string> path) override;

// @brief get tnn command queue
// @param command_queue device command queue for forward
virtual Status GetCommandQueue(void** command_queue) override;

// @brief set tnn command queue
// @param command_queue device command queue for forward
virtual Status SetCommandQueue(void* command_queue) override;

// @brief share tnn command queue to another context
virtual Status ShareCommandQueue(Context* context);

// @brief before instance forward
virtual Status OnInstanceForwardBegin() override;

// @brief after instance forward
virtual Status OnInstanceForwardEnd() override;

// @brief wait for jobs in the current context to complete
virtual Status Synchronize() override;

// @brief get Atlas stream
aclrtStream& GetAclrtStream();

// @brief set Atlas stream
void SetAclrtStream(const aclrtStream& stream);

// @brief create Atlas stream
Status CreateAclrtStream();

// @brief get ModelType
ModelType& GetModelType();

// @brief set ModelType
void SetModelType(ModelType model_type);

private:
ModelType model_type_;
int device_id_ = INT_MAX;

// ACL Runtime Related
aclrtStream aclrt_stream_ = nullptr;
};

} // namespace TNN_NS;

#endif // TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_
18 changes: 14 additions & 4 deletions source/tnn/device/atlas/atlas_device.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
Expand All @@ -12,8 +12,9 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/device/atlas/atlas_device.h"
#include "acl/ops/acl_dvpp.h"
#include "tnn/device/atlas/atlas_context.h"
#include "tnn/device/atlas/atlas_device.h"
#include "tnn/utils/blob_memory_size_utils.h"
#include "tnn/utils/dims_vector_utils.h"

Expand Down Expand Up @@ -107,8 +108,17 @@ AbstractLayerAcc* AtlasDevice::CreateLayerAcc(LayerType type) {
return nullptr;
}

Context* AtlasDevice::CreateContext(int) {
return nullptr;
Context* AtlasDevice::CreateContext(int device_id) {
auto context = new AtlasContext();

Status ret = context->Setup(device_id);
if (ret != TNN_OK) {
LOGE("Cuda context setup failed.");
delete context;
return NULL;
}

return context;
}

NetworkType AtlasDevice::ConvertAutoNetworkType() {
Expand Down
Loading

0 comments on commit 3bf8a7b

Please sign in to comment.