Skip to content

Commit

Permalink
Keep input/output same order with paddle (#789)
Browse files Browse the repository at this point in the history
* keep input/output same order with paddle

* keep input/output same order with paddle

Co-authored-by: yeliang2258 <[email protected]>
  • Loading branch information
jiangjiajun and yeliang2258 authored Jul 5, 2022
1 parent 99010dd commit a092257
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
26 changes: 24 additions & 2 deletions paddle2onnx/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -719,13 +719,25 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
void PaddleParser::GetGlobalBlockInputOutputInfo() {
inputs.clear();
outputs.clear();
// record the origin order of Paddle model
std::vector<TensorInfo> inputs_with_no_order;
std::vector<TensorInfo> outputs_with_no_order;
std::vector<int64_t> input_order;
std::vector<int64_t> output_order;

for (auto i = 0; i < prog->blocks(0).ops_size(); ++i) {
if (prog->blocks(0).ops(i).type() == "fetch") {
std::string name = prog->blocks(0).ops(i).inputs(0).arguments(0);
outputs.push_back(GetTensorInfo(name, prog->blocks(0)));
outputs_with_no_order.push_back(GetTensorInfo(name, prog->blocks(0)));
int64_t order = -1;
GetOpAttr(prog->blocks(0).ops(i), "col", &order);
output_order.push_back(order);
} else if (prog->blocks(0).ops(i).type() == "feed") {
std::string name = prog->blocks(0).ops(i).outputs(0).arguments(0);
inputs.push_back(GetTensorInfo(name, prog->blocks(0)));
inputs_with_no_order.push_back(GetTensorInfo(name, prog->blocks(0)));
int64_t order = -1;
GetOpAttr(prog->blocks(0).ops(i), "col", &order);
input_order.push_back(order);
}

// This is a trick check, due to the uncorrect shape inference of Paddle
Expand All @@ -736,6 +748,16 @@ void PaddleParser::GetGlobalBlockInputOutputInfo() {
}
}

// Reorder the inputs and outputs to keep same with the original Paddle model
inputs.resize(input_order.size());
for (size_t i = 0; i < input_order.size(); ++i) {
inputs[input_order[i]] = inputs_with_no_order[i];
}
outputs.resize(output_order.size());
for (size_t i = 0; i < output_order.size(); ++i) {
outputs[output_order[i]] = outputs_with_no_order[i];
}

// Trick setting for nms, remove this after shape inference fixed
if (_has_nms) {
for (size_t i = 0; i < outputs.size(); ++i) {
Expand Down
7 changes: 7 additions & 0 deletions paddle2onnx/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ struct TensorInfo {
shape.assign(_shape.begin(), _shape.end());
dtype = _dtype;
}

TensorInfo(const TensorInfo& info) {
name = info.name;
shape.assign(info.shape.begin(), info.shape.end());
dtype = info.dtype;
is_tensor_array = info.is_tensor_array;
}
};

struct Weight {
Expand Down

0 comments on commit a092257

Please sign in to comment.