diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index 308e37aff..a53f5bdc3 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -1161,6 +1161,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 auto head_ignore_num = interleave_n_offset % 128; const __m256i lowMask = _mm256_set1_epi8(0x03); const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); @@ -1170,6 +1171,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 for (int i = 0; i < 4; i++) { auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); bit1x32 = _mm256_and_si256(highMask, bit1x32); diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index 027d263c8..b69d22563 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -259,7 +259,7 @@ class DecompresssS3 { void *bit2ptr, *bit1ptr, *dstptr, *tmpbuf; int unpack_elt; const int8_t lowMask = 3, highMask = 4; - const int32_t bit1Shift2 = (1 << 23) + (1 << 16) + (1 << 9) + (1 << 2); + const int32_t bit1Mask = 0x0F, bit1Shift2 = (1 << 23) + (1 << 16) + (1 << 9) + (1 << 2); const int32_t bit1Shift1[8] = {0, 4, 8, 12, 16, 20, 24, 28}; }; typedef long long (*func_t)(params*); @@ -288,10 +288,12 @@ class DecompresssS3 { xor_(reg_iter, reg_iter); Xbyak::Ymm lowMask = ymm15; Xbyak::Ymm highMask = ymm14; - Xbyak::Ymm bit1Shift1 = ymm13; - Xbyak::Ymm bit1Shift2 = ymm12; + Xbyak::Ymm bit1Mask = ymm13; + Xbyak::Ymm bit1Shift1 = ymm12; + Xbyak::Ymm bit1Shift2 = ymm11; vpbroadcastb(lowMask, byte[parambase + OFFSET(lowMask)]); vpbroadcastb(highMask, byte[parambase + OFFSET(highMask)]); + vpbroadcastd(bit1Mask, dword[parambase + OFFSET(bit1Mask)]); vpbroadcastd(bit1Shift2, dword[parambase + OFFSET(bit1Shift2)]); vmovdqu(bit1Shift1, ptr[parambase + OFFSET(bit1Shift1)]); mov(reg_bit1ptr, ptr[parambase + OFFSET(bit1ptr)]); @@ -304,6 +306,7 @@ class DecompresssS3 { for (int i = 0; i < 4; i++) { vpbroadcastd(Xbyak::Ymm(i), dword[reg_bit1ptr + 4 * i]); vpsrlvd(Xbyak::Ymm(i), Xbyak::Ymm(i), bit1Shift1); // todo : check m256 + vpand(Xbyak::Ymm(i), Xbyak::Ymm(i), bit1Mask); vpmulld(Xbyak::Ymm(i), Xbyak::Ymm(i), bit1Shift2); vpand(Xbyak::Ymm(i), Xbyak::Ymm(i), highMask); vpsrlw(Xbyak::Ymm(4 + i), bit2_data, 2 * i); @@ -328,6 +331,27 @@ class DecompresssS3 { vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i)); vmovups(ptr[reg_dst + 4 * (32 * i + 16)], Xbyak::Ymm(i)); vmovups(ptr[reg_dst + 4 * (32 * i + 24)], Xbyak::Ymm(4 + i)); + + // vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(i)); + // vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i)); + // vmovups(ptr[reg_dst + 4 * (32 * i + 0)], Xbyak::Ymm(4 + i)); + + // vextracti128(Xbyak::Xmm(4 + i), Xbyak::Ymm(i), 1); + // vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(4 + i)); + // vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i)); + // vmovups(ptr[reg_dst + 4 * (32 * i + 16)], Xbyak::Ymm(4 + i)); + + // vpsrldq(Xbyak::Ymm(i), Xbyak::Ymm(i), 8); + + // vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(i)); + // vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i)); + // vmovups(ptr[reg_dst + 4 * (32 * i + 8)], Xbyak::Ymm(4 + i)); + + // vextracti128(Xbyak::Xmm(4 + i), Xbyak::Ymm(i), 1); + // vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(4 + i)); + // vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i)); + // vmovups(ptr[reg_dst + 4 * (32 * i + 24)], Xbyak::Ymm(4 + i)); + } else { assert(0); }