Skip to content

Commit

Permalink
fix vulkan winograd weight layout with cooperative matrix enabled (#4093
Browse files Browse the repository at this point in the history
)
  • Loading branch information
nihui authored Jul 28, 2022
1 parent 720f3c9 commit 0666143
Showing 1 changed file with 60 additions and 4 deletions.
64 changes: 60 additions & 4 deletions src/layer/vulkan/convolution_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,37 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
}
}

// src = 36-inch-outch
// dst = 8a-8b-inch/8a-outch/8b-36
if (use_cooperative_matrix)
{
// src = 36-inch-outch
// dst = 8b-8a-inch/8a-outch/8b-36
weight_winograd43_data_packed.create(num_input / 8, num_output / 8, 36, (size_t)4 * 8 * 8, 8 * 8);

for (int k = 0; k < 36; k++)
{
float* g00 = weight_winograd43_data_packed.channel(k);

for (int q = 0; q + (8 - 1) < num_output; q += 8)
{
for (int p = 0; p + (8 - 1) < num_input; p += 8)
{
for (int i = 0; i < 8; i++)
{
for (int j = 0; j < 8; j++)
{
const float* k00 = weight_data_tm.channel(q + j).row(p + i);
g00[0] = k00[k];
g00++;
}
}
}
}
}
}
else
{
// src = 36-inch-outch
// dst = 8a-8b-inch/8a-outch/8b-36
weight_winograd43_data_packed.create(num_input / elempack, num_output / out_elempack, 36, (size_t)4 * elempack * out_elempack, elempack * out_elempack);

for (int k = 0; k < 36; k++)
Expand Down Expand Up @@ -442,9 +470,37 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
}
}

// src = 16-inch-outch
// dst = 8a-8b-inch/8a-outch/8b-16
if (use_cooperative_matrix)
{
// src = 16-inch-outch
// dst = 8b-8a-inch/8a-outch/8b-16
weight_winograd23_data_packed.create(num_input / 8, num_output / 8, 16, (size_t)4 * 8 * 8, 8 * 8);

for (int k = 0; k < 16; k++)
{
float* g00 = weight_winograd23_data_packed.channel(k);

for (int q = 0; q + (8 - 1) < num_output; q += 8)
{
for (int p = 0; p + (8 - 1) < num_input; p += 8)
{
for (int i = 0; i < 8; i++)
{
for (int j = 0; j < 8; j++)
{
const float* k00 = weight_data_tm.channel(q + j).row(p + i);
g00[0] = k00[k];
g00++;
}
}
}
}
}
}
else
{
// src = 16-inch-outch
// dst = 8a-8b-inch/8a-outch/8b-16
weight_winograd23_data_packed.create(num_input / elempack, num_output / out_elempack, 16, (size_t)4 * elempack * out_elempack, elempack * out_elempack);

for (int k = 0; k < 16; k++)
Expand Down

0 comments on commit 0666143

Please sign in to comment.