-
Notifications
You must be signed in to change notification settings - Fork 1
/
backend.cc
58 lines (50 loc) · 1.85 KB
/
backend.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <algorithm>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>
#include "backend.h"
#include "status.h"
namespace mlperf_bench {
// Construct a Backend and bind the ORT allocator handle used for
// name/shape queries in LoadModel. (Fix: dropped the stray ';' that
// followed the constructor body.)
Backend::Backend() {
    // NOTE(review): copies the allocator info handle into the allocator
    // slot declared in backend.h — presumably an Ort allocator/info pair;
    // confirm against the header.
    allocator_ = allocator_info_;
}
// Load an ONNX model from `path` and cache its input/output metadata.
//
// @param path    Filesystem path to the .onnx model.
// @param outputs Names of the outputs to keep; an empty vector keeps all.
// @return Status::OK() on success.
//
// NOTE(review): `session_` is assigned via raw `new`; if LoadModel is
// called twice the previous session leaks. Prefer std::unique_ptr in
// backend.h — left as-is here because the member type is declared there.
Status Backend::LoadModel(std::string path, std::vector<std::string> outputs) {
#ifdef _WIN32
    // ORT on Windows takes a wide path. This byte-wise widening is only
    // correct for ASCII paths — TODO confirm callers never pass non-ASCII.
    std::wstring widestr = std::wstring(path.begin(), path.end());
    session_ = new Ort::Session(env_, widestr.c_str(), opt_);
#else
    session_ = new Ort::Session(env_, path.c_str(), opt_);
#endif
    // Cache input names and element types.
    for (size_t i = 0; i < session_->GetInputCount(); i++) {
        input_names_.push_back(session_->GetInputName(i, allocator_));
        // FIXME: GetInputTypeInfo(i).GetElementType() returns junk on
        // linux, so the type is inferred from the model name instead
        // (the original dead read of GetElementType() was removed):
        // ssd models take uint8 image tensors, everything else float32.
        ONNXTensorElementDataType input_type =
            (path.find("ssd") != std::string::npos)
                ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
                : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
        input_type_.push_back(input_type);
    }
    // Cache the requested output names and their static shapes.
    for (size_t i = 0; i < session_->GetOutputCount(); i++) {
        char* name = session_->GetOutputName(i, allocator_);
        // Keep the output when no filter was given, or when it is listed.
        if (outputs.empty() ||
            std::find(outputs.begin(), outputs.end(), name) != outputs.end()) {
            auto ti = session_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo();
            output_shapes_.push_back(ti.GetShape());
            output_names_.push_back(name);
        }
    }
    return Status::OK();
}
// Execute one inference over the loaded session.
//
// @param inputs      Pointer to `input_count` Ort::Value input tensors,
//                    ordered to match `input_names_`.
// @param input_count Number of inputs in the array.
// @return The output tensors for every name in `output_names_`.
std::vector<Ort::Value> Backend::Run(Ort::Value* inputs, size_t input_count) {
    // BUG FIX: the input count was hard-coded to 1, silently ignoring
    // `input_count`; models with multiple inputs dropped all but the first.
    std::vector<Ort::Value> results =
        session_->Run(run_options_, input_names_.data(), inputs, input_count,
                      output_names_.data(), output_names_.size());
    return results;
}
}