Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix accuracy
Browse files Browse the repository at this point in the history
yuchengliu1 committed Apr 1, 2024
1 parent a0b3e12 commit efd2dd4
Showing 2 changed files with 29 additions and 3 deletions.
2 changes: 2 additions & 0 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
@@ -1161,6 +1161,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
auto head_ignore_num = interleave_n_offset % 128;
const __m256i lowMask = _mm256_set1_epi8(0x03);
const __m256i highMask = _mm256_set1_epi8(0x04);
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));

@@ -1170,6 +1171,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
for (int i = 0; i < 4; i++) {
auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]);
bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1);
bit1x32 = _mm256_and_si256(bit1x32, bit1Mask);
bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2);
bit1x32 = _mm256_and_si256(highMask, bit1x32);

30 changes: 27 additions & 3 deletions bestla/bestla/kernel_jit.h
Original file line number Diff line number Diff line change
@@ -259,7 +259,7 @@ class DecompresssS3 {
void *bit2ptr, *bit1ptr, *dstptr, *tmpbuf;
int unpack_elt;
const int8_t lowMask = 3, highMask = 4;
const int32_t bit1Shift2 = (1 << 23) + (1 << 16) + (1 << 9) + (1 << 2);
const int32_t bit1Mask = 0x0F, bit1Shift2 = (1 << 23) + (1 << 16) + (1 << 9) + (1 << 2);
const int32_t bit1Shift1[8] = {0, 4, 8, 12, 16, 20, 24, 28};
};
typedef long long (*func_t)(params*);
@@ -288,10 +288,12 @@ class DecompresssS3 {
xor_(reg_iter, reg_iter);
Xbyak::Ymm lowMask = ymm15;
Xbyak::Ymm highMask = ymm14;
Xbyak::Ymm bit1Shift1 = ymm13;
Xbyak::Ymm bit1Shift2 = ymm12;
Xbyak::Ymm bit1Mask = ymm13;
Xbyak::Ymm bit1Shift1 = ymm12;
Xbyak::Ymm bit1Shift2 = ymm11;
vpbroadcastb(lowMask, byte[parambase + OFFSET(lowMask)]);
vpbroadcastb(highMask, byte[parambase + OFFSET(highMask)]);
vpbroadcastd(bit1Mask, dword[parambase + OFFSET(bit1Mask)]);
vpbroadcastd(bit1Shift2, dword[parambase + OFFSET(bit1Shift2)]);
vmovdqu(bit1Shift1, ptr[parambase + OFFSET(bit1Shift1)]);
mov(reg_bit1ptr, ptr[parambase + OFFSET(bit1ptr)]);
@@ -304,6 +306,7 @@ class DecompresssS3 {
for (int i = 0; i < 4; i++) {
vpbroadcastd(Xbyak::Ymm(i), dword[reg_bit1ptr + 4 * i]);
vpsrlvd(Xbyak::Ymm(i), Xbyak::Ymm(i), bit1Shift1); // todo : check m256
vpand(Xbyak::Ymm(i), Xbyak::Ymm(i), bit1Mask);
vpmulld(Xbyak::Ymm(i), Xbyak::Ymm(i), bit1Shift2);
vpand(Xbyak::Ymm(i), Xbyak::Ymm(i), highMask);
vpsrlw(Xbyak::Ymm(4 + i), bit2_data, 2 * i);
@@ -328,6 +331,27 @@ class DecompresssS3 {
vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i));
vmovups(ptr[reg_dst + 4 * (32 * i + 16)], Xbyak::Ymm(i));
vmovups(ptr[reg_dst + 4 * (32 * i + 24)], Xbyak::Ymm(4 + i));

// vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(i));
// vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i));
// vmovups(ptr[reg_dst + 4 * (32 * i + 0)], Xbyak::Ymm(4 + i));

// vextracti128(Xbyak::Xmm(4 + i), Xbyak::Ymm(i), 1);
// vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(4 + i));
// vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i));
// vmovups(ptr[reg_dst + 4 * (32 * i + 16)], Xbyak::Ymm(4 + i));

// vpsrldq(Xbyak::Ymm(i), Xbyak::Ymm(i), 8);

// vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(i));
// vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i));
// vmovups(ptr[reg_dst + 4 * (32 * i + 8)], Xbyak::Ymm(4 + i));

// vextracti128(Xbyak::Xmm(4 + i), Xbyak::Ymm(i), 1);
// vpmovsxbd(Xbyak::Ymm(4 + i), Xbyak::Xmm(4 + i));
// vcvtdq2ps(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i));
// vmovups(ptr[reg_dst + 4 * (32 * i + 24)], Xbyak::Ymm(4 + i));

} else {
assert(0);
}

0 comments on commit efd2dd4

Please sign in to comment.