Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT] Support Multiple EP Context #23294

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3365,6 +3365,7 @@ void RegisterContribSchemas() {
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
Expand Down
80 changes: 28 additions & 52 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
return false;
}

int FindCtxNodeInGraph(const GraphViewer& graph_viewer) {
// Assumes there's only 1 context node in this subgraph (graph_viewer)
// Returns index of node
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
LOGS_DEFAULT(VERBOSE) << "*#* context node found at index=" << i;
return i;
}
}
return -1;
}

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
Expand All @@ -40,31 +53,11 @@ const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
return main_graph.ModelPath();
}

/*
* Update ep_cache_context attribute of the EP context node with the given engine binary data
*/
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size) {
ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);

for (int i = 0; i < node_proto->attribute_size(); ++i) {
ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
if (attribute_proto->name() == EP_CACHE_CONTEXT) {
std::string engine_data_str = "";
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attribute_proto->set_s(engine_data_str);
}
}
}

/*
* Create "EP context node" model where engine information is embedded
*/
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
Expand Down Expand Up @@ -123,18 +116,11 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3);

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
ORT_ENFORCE(graph_build.Resolve().IsOK());

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto& metadata = graph_viewer.GetGraph().GetModel().MetaData();
auto model = new_graph_viewer->CreateModel(*logger, metadata);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

return model_proto.release();
graph_build.AddNode(fused_subgraph_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
auto status = graph_build.Resolve();
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());

return model_build;
}

/*
Expand Down Expand Up @@ -203,17 +189,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path,
return ctx_model_path;
}

/*
* Dump "EP context" model
*
*/
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path) {
std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}

bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
Expand Down Expand Up @@ -267,11 +242,11 @@ bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
return engine_cache_path.stem().extension().string() == ".stripped";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx) {
if (!ValidateEPCtxNode(graph_viewer, ctx_node_idx)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(0);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
Expand Down Expand Up @@ -381,14 +356,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
assert(graph_viewer.GetNode(ctx_node_idx)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

// Show the warning if compute capability is not matched
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
if (attrs.find(COMPUTE_CAPABILITY)!=attrs.end() && attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
// Verify if engine was compiled with ampere+ hardware compatibility enabled
if (model_compute_capability == "80+") {
Expand All @@ -415,4 +390,5 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe

return true;
}

} // namespace onnxruntime
14 changes: 6 additions & 8 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ static const std::string EPCONTEXT_WARNING =
for the best model loading time";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
int FindCtxNodeInGraph(const GraphViewer& graph_viewer);

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
Expand All @@ -38,11 +41,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path,
const std::string& original_model_path);
bool IsAbsolutePath(const std::string& path_string);
bool IsRelativePathToParentPath(const std::string& path_string);
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);

class TensorRTCacheModelHandler {
public:
Expand All @@ -67,9 +65,9 @@ class TensorRTCacheModelHandler {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);
bool ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);
Status GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
Expand Down
Loading
Loading