Skip to content

Commit

Permalink
fix bug and optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoh-Z committed Jul 21, 2023
1 parent bf224d3 commit 64868ef
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 31 deletions.
10 changes: 6 additions & 4 deletions src/layer/gridsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ int GridSample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
int outh = permute_fusion == 0 ? grid.c : grid.h;

top_blob.create(outw, outh, channels, elemsize, opt.blob_allocator);
if (top_blob.empty())
return -100;

Mat offset_blob;
offset_blob.create(outw, outh, grid.c, elemsize, opt.workspace_allocator);

if (top_blob.empty() || offset_blob.empty())
return -100;

//pre-calculate all interpolation offsets for each x y, unpack grid on-the-fly
if (permute_fusion == 0)
{
Expand Down Expand Up @@ -375,12 +376,13 @@ int GridSample::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
int outd = permute_fusion == 0 ? grid.c : grid.d;

top_blob.create(outw, outh, outd, channels, elemsize, opt.blob_allocator);
if (top_blob.empty())
return -100;

Mat offset_blob;
offset_blob.create(outw, outh, outd, grid.c, elemsize, opt.workspace_allocator);

if (top_blob.empty() || offset_blob.empty())
return -100;

//pre-calculate all interpolation offsets for each x y, unpack grid on-the-fly
if (permute_fusion == 0)
{
Expand Down
12 changes: 8 additions & 4 deletions src/layer/x86/gridsample_bicubic_apply_interpolation.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@ static void gridsample_2d_bicubic_apply_interpolation_p8(const Mat& src, Mat& ds
for (int ii = 0; ii < 4; ii++)
{
int in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 x0_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 x0_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 x1_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 x1_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 x2_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 x2_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 x3_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 x3_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;

value_f[ii] = _mm256_mul_ps(x_coeffs0, x0_val);
value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs1, x1_val, value_f[ii]);
Expand Down
36 changes: 24 additions & 12 deletions src/layer/x86/gridsample_bilinear_apply_interpolation.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,17 @@ static void gridsample_2d_bilinear_apply_interpolation_p8(const Mat& src, Mat& d
for (int i = 0; i < grid_size; i++)
{
int in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v00_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v00_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v01_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v01_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v10_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v10_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v11_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v11_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;

__m256 value = _mm256_set1_ps(*offset_value_ptr++);
__m256 v0 = _mm256_comp_fmadd_ps(v01_val, value, _mm256_comp_fnmadd_ps(v00_val, value, v00_val));
Expand Down Expand Up @@ -179,22 +183,30 @@ static void gridsample_3d_bilinear_apply_interpolation_p8(const Mat& src, Mat& d
for (int i = 0; i < grid_size; i++)
{
int in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v000_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v000_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v001_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v001_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v010_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v010_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v011_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v011_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;

in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v100_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v100_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v101_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v101_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v110_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v110_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;
in_bound = *reinterpret_cast<const int*>(offset_value_ptr) >= 0 ? -1 : 0;
__m256 v111_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr++), _mm256_set1_epi32(in_bound));
__m256 v111_val = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_value_ptr), _mm256_set1_epi32(in_bound));
offset_value_ptr++;

__m256 value = _mm256_set1_ps(*offset_value_ptr++);
__m256 v00 = _mm256_comp_fmadd_ps(v001_val, value, _mm256_comp_fnmadd_ps(v000_val, value, v000_val));
Expand Down
3 changes: 2 additions & 1 deletion src/layer/x86/gridsample_nearest_apply_interpolation.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ static void gridsample_nearest_apply_interpolation_p8(const Mat& src, Mat& dst,
for (int i = 0; i < grid_size; i++)
{
int in_bound = *reinterpret_cast<const int*>(offset_ptr) >= 0 ? -1 : 0;
__m256 _v = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_ptr++), _mm256_set1_epi32(in_bound));
__m256 _v = _mm256_maskload_ps(srcptr + static_cast<int>(*offset_ptr), _mm256_set1_epi32(in_bound));
offset_ptr++;

_mm256_storeu_ps(dstptr, _v);
dstptr += 8;
Expand Down
15 changes: 5 additions & 10 deletions tests/test_gridsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,9 @@ int main()
{
SRAND(7767517);

int ret = 0
|| test_gridsample_0()
|| test_gridsample_1()
|| test_gridsample_2()
|| test_gridsample_3();

getchar();
getchar();

return ret;
return 0
|| test_gridsample_0()
|| test_gridsample_1()
|| test_gridsample_2()
|| test_gridsample_3();
}

0 comments on commit 64868ef

Please sign in to comment.