From 56e53259b4e11922f7fc817494012634df9b6ad1 Mon Sep 17 00:00:00 2001
From: Qi <1825013335@qq.com>
Date: Fri, 13 Sep 2024 16:28:47 +0800
Subject: [PATCH] Improve the FLOPs and memory-access calculations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 tools/pnnx/src/ir.cpp | 57 +++++++++++++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp
index 0c5ff108588..229acbd39fd 100644
--- a/tools/pnnx/src/ir.cpp
+++ b/tools/pnnx/src/ir.cpp
@@ -2881,7 +2881,7 @@ void calculate_conv_flops_and_memory(const pnnx::Operator& op)
     int64_t flops = output_height * output_width * output_channels * input_channels * kernel_size * kernel_size;
     op.attrs["flops"] = pnnx::Attribute(flops); // initialize the attribute

-    int64_t memory_ops = flops;
+    int64_t memory_ops = 2 * (output_height * output_width * output_channels * input_channels * kernel_size * kernel_size) + (output_channels * output_height * output_width); // one input read and one weight read per multiply, plus the output writes
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

@@ -2891,28 +2891,53 @@ void calculate_fc_flops_and_memory(const pnnx::Operator& op)
     int input_size = op.params.at("input_size").i;
     int output_size = op.params.at("output_size").i;

-    // FLOPS = 2 * (input_size * output_size) (multiply and add)
+    // FLOPS = 2 * (input_size * output_size) (one multiply and one add per weight); the activation is not counted here, it is handled later by the activation layer
+    // with the bias additions the count would be 2 * (input_size * output_size) + output_size
     int64_t flops = 2 * input_size * output_size;
     op.attrs["flops"] = pnnx::Attribute(flops);
+    // account for the weights, bias, input and output
+    int64_t weight_memory_access = input_size * output_size;
+    int64_t bias_memory_access = output_size;
+    int64_t input_memory_access = input_size;
+    int64_t output_memory_access = output_size;

-    int64_t memory_ops = flops;
+    int64_t memory_ops = weight_memory_access + bias_memory_access + input_memory_access + output_memory_access;
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

 // pooling layers
-void calculate_pool_flops_and_memory(const pnnx::Operator& op)
+void calculate_max_pool_flops_and_memory(const pnnx::Operator& op)
 {
     int input_height = op.params.at("input_height").i;
     int input_width = op.params.at("input_width").i;
+    int channels = op.params.at("channels").i;
     int kernel_size = op.params.at("kernel_size").i;
     int stride = op.params.at("stride").i;
     int padding = op.params.at("padding").i;
     int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
     int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;

-    int64_t flops = output_height * output_width;
+    // count each comparison as one FLOP
+    int64_t flops = output_height * output_width * channels * kernel_size * kernel_size;
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = flops;
+    int64_t memory_ops = input_height * input_width * channels + output_height * output_width * channels;
+    op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
+}
+void calculate_avg_pool_flops_and_memory(const pnnx::Operator& op)
+{
+    int input_height = op.params.at("input_height").i;
+    int input_width = op.params.at("input_width").i;
+    int channels = op.params.at("channels").i;
+    int kernel_size = op.params.at("kernel_size").i;
+    int stride = op.params.at("stride").i;
+    int padding = op.params.at("padding").i;
+    int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
+    int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;
+
+    // one addition per kernel element, plus one division per output (the division cost is negligible)
+    int64_t flops = output_height * output_width * channels * kernel_size * kernel_size;
+    op.attrs["flops"] = pnnx::Attribute(flops);
+    int64_t memory_ops = input_height * input_width * channels + output_height * output_width * channels;
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

@@ -2922,7 +2947,7 @@ void calculate_activation_flops_and_memory(const pnnx::Operator& op)
     int input_size = op.params.at("input_size").i;
     int64_t flops = input_size;
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = flops;
+    int64_t memory_ops = flops * 2; // read the input and write the output
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

@@ -2935,16 +2960,16 @@ void calculate_bn_flops_and_memory(const pnnx::Operator& op)
     // FLOPS = 5 * input_channels * (input_height * input_width) (the normalization arithmetic)
     int64_t flops = 5 * input_channels * (input_height * input_width);
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = flops;
+    int64_t memory_ops = 6 * input_channels * (input_height * input_width); // read the input, mean, variance, scale and shift; write the output
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

 // dropout layer
 void calculate_dropout_flops_and_memory(const pnnx::Operator& op)
 {
-    int64_t flops = 0;
+    int64_t flops = op.params.at("input_size").i; // one mask multiply per element
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = op.params.at("input_size").i;
+    int64_t memory_ops = op.params.at("input_size").i * 2; // read the input and write the output
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

@@ -2953,9 +2978,9 @@ void calculate_lstm_flops_and_memory(const pnnx::Operator& op)
 {
     int input_size = op.params.at("input_size").i;
     int hidden_size = op.params.at("hidden_size").i;
-    int64_t flops = 4 * input_size * hidden_size;
+    int64_t flops = 10 * input_size * hidden_size; // coarse estimate for the four gates plus the elementwise ops
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = flops;
+    int64_t memory_ops = input_size + 2 * hidden_size + 4 * hidden_size * (input_size + hidden_size) + 2 * hidden_size; // read x, h and c and the gate weights; write h and c
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

@@ -2966,18 +2991,18 @@ void calculate_embedding_flops_and_memory(const pnnx::Operator& op)
     int embedding_size = op.params.at("embedding_size").i;
     int64_t flops = input_vocab_size * embedding_size;
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = flops;
+    int64_t memory_ops = input_vocab_size + input_vocab_size * embedding_size; // index reads plus the embedding table
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }

 // Layer Normalization Layer
 void calculate_layer_norm_flops_and_memory(const pnnx::Operator& op)
 {
-    int input_elements = op.params.at("input_elements").i;
+    int input_size = op.params.at("input_size").i;
     // layer normalization normalizes and scales each element (subtract the mean, divide by the standard deviation)
-    int64_t flops = 2 * input_elements;
+    int64_t flops = 5 * input_size; // matches the per-element estimate used for batchnorm
     op.attrs["flops"] = pnnx::Attribute(flops);
-    int64_t memory_ops = flops;
+    int64_t memory_ops = 4 * input_size; // read the input and the scale/shift parameters; write the output
     op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
 }
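---

As a sanity check on the convolution estimate, here is a minimal standalone C++ sketch (not part of the patch) that evaluates the same expressions as calculate_conv_flops_and_memory. The layer shape, a 3x3 convolution taking a 56x56x64 input to 128 channels, is an arbitrary illustrative assumption, not taken from pnnx or any real model.

```cpp
// Minimal sketch mirroring the conv formulas in the patch above.
// All layer dimensions are hypothetical illustration values.
#include <cstdint>
#include <cstdio>

int main()
{
    // hypothetical 3x3 conv: 64 -> 128 channels on a 56x56 feature map
    const int64_t input_height = 56, input_width = 56;
    const int64_t input_channels = 64, output_channels = 128;
    const int64_t kernel_size = 3, stride = 1, padding = 1;

    const int64_t output_height = (input_height + 2 * padding - kernel_size) / stride + 1; // 56
    const int64_t output_width = (input_width + 2 * padding - kernel_size) / stride + 1;   // 56

    // one multiply per (output element x input channel x kernel tap)
    const int64_t flops = output_height * output_width * output_channels
                          * input_channels * kernel_size * kernel_size;

    // one input read and one weight read per multiply, plus the output writes
    const int64_t memory_ops = 2 * flops + output_channels * output_height * output_width;

    std::printf("conv flops:      %lld\n", (long long)flops);      // 231211008
    std::printf("conv memory_ops: %lld\n", (long long)memory_ops); // 462823424
    return 0;
}
```

Note that memory_ops comes out at roughly twice flops: this model charges a memory access to every operand of every multiply, i.e. it assumes no caching or operand reuse, which is worth keeping in mind when comparing layers.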
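The LSTM memory_ops expression is the densest one in the patch. The following sketch unpacks the same arithmetic into named terms so the accounting is visible; the sizes (input_size = 128, hidden_size = 256) are arbitrary illustration values.

```cpp
// Unpacks the LSTM memory_ops expression from the patch into named terms.
// The sizes below are hypothetical illustration values.
#include <cstdint>
#include <cstdio>

int main()
{
    const int64_t input_size = 128, hidden_size = 256; // hypothetical

    const int64_t input_reads  = input_size;                                   // x_t
    const int64_t state_reads  = 2 * hidden_size;                              // h_(t-1) and c_(t-1)
    const int64_t weight_reads = 4 * hidden_size * (input_size + hidden_size); // the four gate matrices
    const int64_t state_writes = 2 * hidden_size;                              // h_t and c_t

    const int64_t memory_ops = input_reads + state_reads + weight_reads + state_writes;
    std::printf("lstm memory_ops: %lld\n", (long long)memory_ops); // 394368
    return 0;
}
```

The weight term dominates for realistic sizes, which matches the usual observation that single-step LSTM inference is memory-bound on the gate matrices.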