Refine the FLOPs and memory-access count calculations
Qi-qi0317 committed Sep 13, 2024
1 parent 2c57f5c commit 56e5325
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions tools/pnnx/src/ir.cpp
@@ -2881,7 +2881,7 @@ void calculate_conv_flops_and_memory(const pnnx::Operator& op)
int64_t flops = output_height * output_width * output_channels * input_channels * kernel_size * kernel_size;
op.attrs["flops"] = pnnx::Attribute(flops); //属性初始化

int64_t memory_ops = flops;
int64_t memory_ops = 2 * (output_height * output_width * output_channels * input_channels * kernel_size * kernel_size) + (output_channels * output_height * output_width);
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
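
As a sanity check of the convolution formulas above, here is a minimal standalone sketch; the layer shape (112x112x64 output from 3 input channels with a 3x3 kernel) is illustrative and not taken from the commit:

#include <cstdint>
#include <cstdio>

int main()
{
    // Illustrative layer: 112x112x64 output, 3 input channels, 3x3 kernel (hypothetical values).
    int64_t output_height = 112, output_width = 112;
    int64_t output_channels = 64, input_channels = 3, kernel_size = 3;

    // One multiply-accumulate per kernel element per output element.
    int64_t flops = output_height * output_width * output_channels * input_channels * kernel_size * kernel_size; // 21676032

    // One input read and one weight read per MAC, plus one write per output element.
    int64_t memory_ops = 2 * flops + output_channels * output_height * output_width; // 44154880

    printf("conv flops=%lld memory_ops=%lld\n", (long long)flops, (long long)memory_ops);
    return 0;
}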

@@ -2891,28 +2891,53 @@ void calculate_fc_flops_and_memory(const pnnx::Operator& op)
int input_size = op.params.at("input_size").i;
int output_size = op.params.at("output_size").i;

// FLOPS = 2 * (input_size * output_size) (multiplications and additions)
// FLOPS = 2 * (input_size * output_size) (multiplications and additions); the activation function is not counted here, it is handled by the later activation layer
// FLOPS = 2 * (input_size * output_size) + output_size
int64_t flops = 2 * input_size * output_size;
op.attrs["flops"] = pnnx::Attribute(flops);
// account for weights, bias, input and output
int64_t weight_memory_access = input_size * output_size;
int64_t bias_memory_access = output_size;
int64_t input_memory_access = input_size;
int64_t output_memory_access = output_size;

int64_t memory_ops = flops;
int64_t memory_ops = weight_memory_access + bias_memory_access + input_memory_access + output_memory_access;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
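
Worked numbers for the fully connected formulas above, with hypothetical sizes not taken from the commit: for input_size = 1024 and output_size = 1000, flops = 2 * 1024 * 1000 = 2,048,000 and memory_ops = 1,024,000 (weights) + 1,000 (bias) + 1,024 (input) + 1,000 (output) = 1,027,024. Note that the bias-add term (+ output_size) mentioned in the second comment is not included in the computed flops value.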

// Pooling layers
void calculate_pool_flops_and_memory(const pnnx::Operator& op)
void calculate_max_pool_flops_and_memory(const pnnx::Operator& op)
{
int input_height = op.params.at("input_height").i;
int input_width = op.params.at("input_width").i;
int channels = op.params.at("channels").i;
int kernel_size = op.params.at("kernel_size").i;
int stride = op.params.at("stride").i;
int padding = op.params.at("padding").i;
int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;

int64_t flops = output_height * output_width;
// assume each comparison counts as one FLOP
int64_t flops = output_height * output_width * channels * kernel_size * kernel_size * 1;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = flops;
int64_t memory_ops = input_height * input_width * channels + output_height * output_width * channels;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
void calculate_avg_pool_flops_and_memory(const pnnx::Operator& op)
{
int input_height = op.params.at("input_height").i;
int input_width = op.params.at("input_width").i;
int channels = op.params.at("channels").i;
int kernel_size = op.params.at("kernel_size").i;
int stride = op.params.at("stride").i;
int padding = op.params.at("padding").i;
int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;

// Additions per pooling window plus one division (the division's FLOP cost is negligible)
int64_t flops = output_height * output_width * channels * kernel_size * kernel_size;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = input_height * input_width * channels + output_height * output_width * channels;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
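
Both pooling variants use the same memory formula: read the input feature map once and write the output once. A minimal sketch with an illustrative 2x2, stride-2 max pool over a 112x112x64 feature map (values are hypothetical, not from the commit):

#include <cstdint>
#include <cstdio>

int main()
{
    // Illustrative pooling configuration (hypothetical values).
    int input_height = 112, input_width = 112, channels = 64;
    int kernel_size = 2, stride = 2, padding = 0;

    int output_height = (input_height + 2 * padding - kernel_size) / stride + 1; // 56
    int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;   // 56

    // One comparison per window element.
    int64_t flops = (int64_t)output_height * output_width * channels * kernel_size * kernel_size; // 802816

    // Read the whole input once, write the whole output once.
    int64_t memory_ops = (int64_t)input_height * input_width * channels
                       + (int64_t)output_height * output_width * channels;                        // 1003520

    printf("maxpool flops=%lld memory_ops=%lld\n", (long long)flops, (long long)memory_ops);
    return 0;
}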

@@ -2922,7 +2947,7 @@ void calculate_activation_flops_and_memory(const pnnx::Operator& op)
int input_size = op.params.at("input_size").i;
int64_t flops = input_size;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = flops;
int64_t memory_ops = flops * 2; // read the input and write the output
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}

@@ -2935,16 +2960,16 @@ void calculate_bn_flops_and_memory(const pnnx::Operator& op)
// FLOPS = 5 * input_channels * (input_height * input_width) (normalization computation)
int64_t flops = 5 * input_channels * (input_height * input_width);
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = flops;
int64_t memory_ops = 6 * input_channels * (input_height * input_width); // accounts for reading the input, mean, variance, scale factor, shift, and writing the output
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
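
Worked numbers for the batch-norm formulas, with hypothetical dimensions not taken from the commit: for a 64-channel 56x56 feature map, flops = 5 * 64 * 56 * 56 = 1,003,520 and memory_ops = 6 * 64 * 56 * 56 = 1,204,224, i.e. five arithmetic operations and six memory accesses (input, mean, variance, scale, shift, output) per element.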

// Dropout layer
void calculate_dropout_flops_and_memory(const pnnx::Operator& op)
{
int64_t flops = 0;
int64_t flops = op.params.at("input_size").i;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = op.params.at("input_size").i;
int64_t memory_ops = op.params.at("input_size").i * 2;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}

@@ -2953,9 +2978,9 @@ void calculate_lstm_flops_and_memory(const pnnx::Operator& op)
{
int input_size = op.params.at("input_size").i;
int hidden_size = op.params.at("hidden_size").i;
int64_t flops = 4 * input_size * hidden_size;
int64_t flops = 10 * input_size * hidden_size;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = flops;
int64_t memory_ops = input_size + 2 * hidden_size + 4 * hidden_size * (input_size + hidden_size) + 2 * hidden_size;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
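
The memory_ops expression appears to sum, per time step, the input vector read (input_size), the hidden/cell state reads (2 * hidden_size), the four gate weight matrices (4 * hidden_size * (input_size + hidden_size)), and the hidden/cell state writes (2 * hidden_size); this reading is an assumption, not stated in the commit. A minimal sketch with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical sizes, not taken from the commit.
    int64_t input_size = 128, hidden_size = 256;

    int64_t flops = 10 * input_size * hidden_size; // 327680

    // Input read, hidden/cell state reads, four gate weight matrices, hidden/cell state writes.
    int64_t memory_ops = input_size + 2 * hidden_size
                       + 4 * hidden_size * (input_size + hidden_size)
                       + 2 * hidden_size;          // 394368

    printf("lstm flops=%lld memory_ops=%lld\n", (long long)flops, (long long)memory_ops);
    return 0;
}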

@@ -2966,18 +2991,18 @@ void calculate_embedding_flops_and_memory(const pnnx::Operator& op)
int embedding_size = op.params.at("embedding_size").i;
int64_t flops = input_vocab_size * embedding_size;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = flops;
int64_t memory_ops = input_vocab_size + input_vocab_size * embedding_size;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}

// Layer Normalization Layer
void calculate_layer_norm_flops_and_memory(const pnnx::Operator& op)
{
int input_elements = op.params.at("input_elements").i;
int input_size = op.params.at("input_size").i;
// Layer normalization involves normalization and scaling: two operations per element (one to subtract the mean, one to divide by the variance)
int64_t flops = 2 * input_elements;
int64_t flops = 5 * input_size;
op.attrs["flops"] = pnnx::Attribute(flops);
int64_t memory_ops = flops;
int64_t memory_ops = 4 * input_size;
op.attrs["memory_ops"] = pnnx::Attribute(memory_ops);
}
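
Worked numbers for the layer-norm formulas, assuming a hypothetical input_size of 768: flops = 5 * 768 = 3,840 and memory_ops = 4 * 768 = 3,072. The factor of 5 presumably counts mean subtraction, variance accumulation, normalization, scale, and shift per element; the factor of 4 presumably covers reading the input and the per-element scale/shift parameters and writing the output.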
