Skip to content

Commit

Permalink
pnnx convert torch maximum minimum and torch max min as expression (#4944)

Browse files Browse the repository at this point in the history

* reset device check dtype kind int

* placeholder for ncnn sign

* convert torch maximum minimum

* torch.max as expression

* torch.min as expression
  • Loading branch information
nihui authored Aug 15, 2023
1 parent fed3b43 commit 93e395d
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 13 deletions.
20 changes: 19 additions & 1 deletion tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1298,11 +1298,19 @@ static std::string expand_expression(const Operator* op)
}
else if (t == "atan2"
|| t == "fmod"
|| t == "max"
|| t == "maximum"
|| t == "min"
|| t == "minimum"
|| t == "pow")
{
std::string binaryop;
if (t == "atan2") binaryop = "torch.atan2";
if (t == "fmod") binaryop = "torch.fmod";
if (t == "max") binaryop = "torch.max";
if (t == "maximum") binaryop = "torch.maximum";
if (t == "min") binaryop = "torch.min";
if (t == "minimum") binaryop = "torch.minimum";
if (t == "pow") binaryop = "torch.pow";

std::string a = exprstack.top();
Expand All @@ -1313,7 +1321,17 @@ static std::string expand_expression(const Operator* op)
std::string r = binaryop + "(" + a + ", " + b + ")";
exprstack.push(r);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "remainder" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
else if (t == "add"
|| t == "sub"
|| t == "mul"
|| t == "div"
|| t == "floor_divide"
|| t == "remainder"
|| t == "and"
|| t == "or"
|| t == "xor"
|| t == "lshift"
|| t == "rshift")
{
std::string binaryop;
if (t == "add") binaryop = "+";
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0/reset_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void reset_device(std::shared_ptr<torch::jit::Graph>& graph, const std::string&
if (dtype_node->hasAttribute(torch::jit::attr::value))
{
// change dtype=half to dtype=float
if (dtype_node->i(torch::jit::attr::value) == 5)
if (dtype_node->kindOf(torch::jit::attr::value) == torch::jit::AttributeKind::i && dtype_node->i(torch::jit::attr::value) == 5)
{
dtype_node->i_(torch::jit::attr::value, 6);
}
Expand Down
12 changes: 12 additions & 0 deletions tools/pnnx/src/pass_level3/fuse_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ static bool operand_maybe_tensor(const Operand* operand)
|| op->type == "aten::div"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::max"
|| op->type == "aten::maximum"
|| op->type == "aten::min"
|| op->type == "aten::minimum"
|| op->type == "aten::mul"
|| op->type == "aten::pow"
|| op->type == "aten::remainder")
Expand Down Expand Up @@ -536,6 +540,10 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
else if (op->type == "aten::atan2"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::max"
|| op->type == "aten::maximum"
|| op->type == "aten::min"
|| op->type == "aten::minimum"
|| op->type == "aten::mul"
|| op->type == "aten::pow"
|| op->type == "aten::remainder")
Expand Down Expand Up @@ -729,6 +737,10 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
|| op->type == "aten::fmod"
|| op->type == "aten::log"
|| op->type == "aten::log10"
|| op->type == "aten::max"
|| op->type == "aten::maximum"
|| op->type == "aten::min"
|| op->type == "aten::minimum"
|| op->type == "aten::mul"
|| op->type == "aten::neg"
|| op->type == "aten::pow"
Expand Down
14 changes: 14 additions & 0 deletions tools/pnnx/src/pass_level5/eval_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,10 @@ static std::string eval_expression(const Operator* op)
else if (t == "atan2"
|| t == "add"
|| t == "sub"
|| t == "max"
|| t == "maximum"
|| t == "min"
|| t == "minimum"
|| t == "mul"
|| t == "div"
|| t == "floor_divide"
Expand Down Expand Up @@ -376,6 +380,16 @@ static std::string eval_expression(const Operator* op)
float r = af - bf;
exprstack.push(std::to_string(r));
}
if (t == "max" || t == "maximum")
{
float r = std::max(af, bf);
exprstack.push(std::to_string(r));
}
if (t == "minimum")
{
float r = std::min(af, bf);
exprstack.push(std::to_string(r));
}
if (t == "mul")
{
float r = af * bf;
Expand Down
18 changes: 17 additions & 1 deletion tools/pnnx/src/pass_ncnn/expand_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
|| t == "reciprocal"
|| t == "round"
|| t == "rsqrt"
|| t == "sign"
|| t == "sin"
|| t == "sqrt"
|| t == "square"
Expand Down Expand Up @@ -160,6 +161,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
if (t == "reciprocal") op_unary->params["0"] = 15;
if (t == "round") op_unary->params["0"] = 18;
if (t == "rsqrt") op_unary->params["0"] = 6;
if (t == "sign") fprintf(stderr, "UnaryOp sign not supported yet\n"); // TODO
if (t == "sin") op_unary->params["0"] = 9;
if (t == "sqrt") op_unary->params["0"] = 5;
if (t == "square") op_unary->params["0"] = 4;
Expand All @@ -178,7 +180,19 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
op_unary->inputs.push_back(op_unary_in);
op_unary->outputs.push_back(op_unary_out);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "fmod" || t == "remainder" || t == "pow" || t == "atan2")
else if (t == "add"
|| t == "atan2"
|| t == "div"
|| t == "floor_divide"
|| t == "fmod"
|| t == "max"
|| t == "maximum"
|| t == "min"
|| t == "minimum"
|| t == "mul"
|| t == "pow"
|| t == "remainder"
|| t == "sub")
{
std::string a = exprstack.top();
exprstack.pop();
Expand All @@ -197,6 +211,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
if (t == "sub") op_binary->params["0"] = 1;
if (t == "mul") op_binary->params["0"] = 2;
if (t == "div") op_binary->params["0"] = 3;
if (t == "max" || t == "maximum") op_binary->params["0"] = 4;
if (t == "min" || t == "minimum") op_binary->params["0"] = 5;
if (t == "floor_divide") fprintf(stderr, "BinaryOp floor_divide not supported yet\n"); // TODO
if (t == "fmod") fprintf(stderr, "BinaryOp fmod not supported yet\n"); // TODO
if (t == "remainder") fprintf(stderr, "BinaryOp remainder not supported yet\n"); // TODO
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ pnnx_add_test(torch_floor)
pnnx_add_test(torch_imag)
pnnx_add_test(torch_log)
pnnx_add_test(torch_log10)
pnnx_add_test(torch_maximum)
pnnx_add_test(torch_minimum)
pnnx_add_test(torch_neg)
pnnx_add_test(torch_pow)
pnnx_add_test(torch_real)
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/tests/ncnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ pnnx_ncnn_add_test(torch_exp)
pnnx_ncnn_add_test(torch_floor)
pnnx_ncnn_add_test(torch_log)
pnnx_ncnn_add_test(torch_log10)
pnnx_ncnn_add_test(torch_maximum)
pnnx_ncnn_add_test(torch_minimum)
pnnx_ncnn_add_test(torch_neg)
pnnx_ncnn_add_test(torch_pow)
pnnx_ncnn_add_test(torch_reciprocal)
Expand Down
12 changes: 7 additions & 5 deletions tools/pnnx/tests/ncnn/test_torch_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
def forward(self, x, y, z, w):
x, x_indices = torch.max(x, dim=1, keepdim=False)
y = torch.max(y)
w = torch.max(z, w)
z, z_indices = torch.max(z, dim=0, keepdim=True)
return x, y, z
return x, y, z, w

def test():
net = Model()
Expand All @@ -34,16 +35,17 @@ def test():
x = torch.rand(3, 16)
y = torch.rand(5, 9, 11)
z = torch.rand(8, 5, 9, 10)
w = torch.rand(5, 9, 10)

a = net(x, y, z)
a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_max.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_max.pt inputshape=[3,16],[5,9,11],[8,5,9,10]")
os.system("../../src/pnnx test_torch_max.pt inputshape=[3,16],[5,9,11],[8,5,9,10],[5,9,10]")

# ncnn inference
import test_torch_max_ncnn
Expand Down
61 changes: 61 additions & 0 deletions tools/pnnx/tests/ncnn/test_torch_maximum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Exercises torch.maximum in three ways: two distinct tensors, a tensor
    with itself, and a tensor against a constant-filled operand."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        # Element-wise maximum of two different tensors of the same shape.
        pairwise = torch.maximum(x, y)
        # Maximum of a tensor with itself — identity case.
        self_max = torch.maximum(y, y)
        # Maximum against a tensor filled with the constant 1.1.
        const_max = torch.maximum(z, torch.ones_like(z) + 0.1)
        return pairwise, self_max, const_max

def test():
    """Trace the model to TorchScript, convert it with pnnx, run the generated
    ncnn model, and return True iff ncnn outputs match PyTorch outputs."""
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 16)
    y = torch.rand(3, 16)
    z = torch.rand(5, 9, 3)

    expected = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_maximum.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_torch_maximum.pt inputshape=[3,16],[3,16],[5,9,3]")

    # ncnn inference
    import test_torch_maximum_ncnn
    actual = test_torch_maximum_ncnn.test_inference()

    # all() short-circuits on the first mismatch, like the original early return
    return all(torch.allclose(e, g, 1e-4, 1e-4) for e, g in zip(expected, actual))

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
12 changes: 7 additions & 5 deletions tools/pnnx/tests/ncnn/test_torch_min.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
def forward(self, x, y, z, w):
x, x_indices = torch.min(x, dim=1, keepdim=False)
y = torch.min(y)
w = torch.min(z, w)
z, z_indices = torch.min(z, dim=0, keepdim=True)
return x, y, z
return x, y, z, w

def test():
net = Model()
Expand All @@ -34,16 +35,17 @@ def test():
x = torch.rand(3, 16)
y = torch.rand(5, 9, 11)
z = torch.rand(8, 5, 9, 10)
w = torch.rand(5, 9, 10)

a = net(x, y, z)
a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_min.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_min.pt inputshape=[3,16],[5,9,11],[8,5,9,10]")
os.system("../../src/pnnx test_torch_min.pt inputshape=[3,16],[5,9,11],[8,5,9,10],[5,9,10]")

# ncnn inference
import test_torch_min_ncnn
Expand Down
61 changes: 61 additions & 0 deletions tools/pnnx/tests/ncnn/test_torch_minimum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Exercises torch.minimum in three ways: two distinct tensors, a tensor
    with itself, and a tensor against a constant-filled operand."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        # Element-wise minimum of two different tensors of the same shape.
        pairwise = torch.minimum(x, y)
        # Minimum of a tensor with itself — identity case.
        self_min = torch.minimum(y, y)
        # Minimum against a tensor filled with the constant 1.1.
        const_min = torch.minimum(z, torch.ones_like(z) + 0.1)
        return pairwise, self_min, const_min

def test():
    """Trace the model to TorchScript, convert it with pnnx, run the generated
    ncnn model, and return True iff ncnn outputs match PyTorch outputs."""
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 16)
    y = torch.rand(3, 16)
    z = torch.rand(5, 9, 3)

    expected = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_minimum.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_torch_minimum.pt inputshape=[3,16],[3,16],[5,9,3]")

    # ncnn inference
    import test_torch_minimum_ncnn
    actual = test_torch_minimum_ncnn.test_inference()

    # all() short-circuits on the first mismatch, like the original early return
    return all(torch.allclose(e, g, 1e-4, 1e-4) for e, g in zip(expected, actual))

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
Loading

0 comments on commit 93e395d

Please sign in to comment.