diff --git a/bestla/README.md b/bestla/README.md index 8b46f5a9b..73b4ca890 100644 --- a/bestla/README.md +++ b/bestla/README.md @@ -22,7 +22,11 @@ BesTLA provides weight-only linear computational capabilities for LLM inference. | Weight dtype | Compute dtype | Scale dtype | algo | | ---------------------- | :----------------: | :---------------: | :--------: | | INT8 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | -| INT4 (CLIP, FULLRANGE) | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT4 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT3 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT2 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT5 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT6 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | | FP8 (E4M3, E5M2) | BF16 / FP32 | FP32 / FP8 (E8M0) | sym | | FP4 (E2M1) | BF16 / FP32 | BF16 / FP32 | sym | | NF4 | BF16 / FP32 | BF16 / FP32 | sym | @@ -47,11 +51,32 @@ BesTLA provides assembly-level postop-fusion through epilogue to minimize the ov ## Compilation Requirements and Usage Compile: -- GCC version >=8.5.0 -- CMake version >=3.5 +- GCC version >= 9.0 +- CMake version >= 3.12 +- MSVC version >= 1900 +- oneAPI version >= 2024.0 + +Best Performance: + +- GCC >= 11.0.0 +- MSVC >= 1930 +- DPCPP >= 2024.0 + Usage: ```cmake add_subdirectory(bestla) target_link_libraries("${YOUR_PROJECT}" bestla::bestla) ``` + +# Benchmark +Build with: +```shell +mkdir build +cd build +cmake .. -DBTLA_UT_BENCHMARK=ON -DBTLA_UT_ALL=ON -DCMAKE_BUILD_TYPE=Release +cmake --build . -j +./bestla_benchmark +``` + +More template usages, please refer code in [bestla_benchmark](bestla/ut/bestla_benchmark.cpp) diff --git a/bestla/bestla/bestla.h b/bestla/bestla/bestla.h index 512d550fb..49d69c29c 100644 --- a/bestla/bestla/bestla.h +++ b/bestla/bestla/bestla.h @@ -37,9 +37,13 @@ enum class BTLA_DTYPE : uint32_t { EleBitsMask = 0xff, EleBitsShift = 0, EleBitsUndef = 0, + EleBits1 = 1, EleBits2 = 2, EleBits3 = 3, EleBits4 = 4, + EleBits5 = 5, + EleBits6 = 6, + EleBits7 = 7, EleBits8 = 8, EleBits16 = 16, EleBits32 = 32, @@ -66,9 +70,13 @@ enum class BTLA_DTYPE : uint32_t { DQ8_BNB = EleBits8 | TypeFloat | SubType4, S8 = EleBits8 | TypeInt, U8 = EleBits8 | TypeInt | SubType1, + S1_CLIP = EleBits1 | TypeInt, S2_CLIP = EleBits2 | TypeInt, S3_CLIP = EleBits3 | TypeInt, S4_CLIP = EleBits4 | TypeInt, + S5_CLIP = EleBits5 | TypeInt, + S6_CLIP = EleBits6 | TypeInt, + S7_CLIP = EleBits7 | TypeInt, F4_E2M1 = EleBits4 | TypeFloat, F4_BNB = EleBits4 | TypeFloat | SubType1, F4_NF4 = EleBits4 | TypeFloat | SubType2, diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 5a5bd2a24..136cfa35b 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -555,7 +555,7 @@ class WeightKBlockNInteger { }); } - static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, + static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, BTLA_DTYPE qtype, parallel::IThreading* threading) { auto bit1_offset = size_t(N) * K; auto bit2ptr = reinterpret_cast(dstptr); @@ -564,7 +564,25 @@ class WeightKBlockNInteger { assert(ret == BTLA_CODE::Success); } - static void compressBit2Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, + static void compressBit5Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, BTLA_DTYPE qtype, + parallel::IThreading* threading) { + auto bit1_offset = size_t(N) * K; + auto bit4ptr = 
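// Editor's sketch (not part of the patch): compressBit5Weight/compressBit6Weight
// split every quantized value into a low 4-bit plane plus a high 1-bit (S5) or
// 2-bit (S6) plane stored back to back in dstptr. The standalone scalar routine
// below illustrates that layout for S5; the helper name, the +16 re-bias, and the
// bit ordering inside the 1-bit plane are editor assumptions for illustration,
// not the kernel::wrapper::CompressBit5 implementation.
#include <cstddef>
#include <cstdint>
inline void pack_s5_planes(const int8_t* src, uint8_t* bit4_plane, uint8_t* bit1_plane, size_t n) {
  // src holds signed values in [-16, 15]; store them biased to [0, 31], matching
  // the bias (FullRange = 1 << 4) that the decompress kernels subtract again.
  for (size_t i = 0; i < n; ++i) {
    uint8_t biased = static_cast<uint8_t>(src[i] + 16);
    uint8_t lo = biased & 0x0f;         // low 4 bits -> nibble plane, 2 elements per byte
    uint8_t hi = (biased >> 4) & 0x01;  // 5th bit -> 1-bit plane, 8 elements per byte
    if (i % 2 == 0) bit4_plane[i / 2] = lo;
    else bit4_plane[i / 2] |= static_cast<uint8_t>(lo << 4);
    if (i % 8 == 0) bit1_plane[i / 8] = 0;
    bit1_plane[i / 8] |= static_cast<uint8_t>(hi << (i % 8));
  }
}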
reinterpret_cast(dstptr); + auto bit1ptr = reinterpret_cast(dstptr + bit1_offset / 2); + auto ret = kernel::wrapper::CompressBit5::forward(B, bit4ptr, bit1ptr, bit1_offset); + assert(ret == BTLA_CODE::Success); + } + + static void compressBit6Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, BTLA_DTYPE qtype, + parallel::IThreading* threading) { + auto bit2_offset = size_t(N) * K; + auto bit4ptr = reinterpret_cast(dstptr); + auto bit2ptr = reinterpret_cast(dstptr + bit2_offset / 2); + auto ret = kernel::wrapper::CompressBit6::forward(B, bit4ptr, bit2ptr, bit2_offset); + assert(ret == BTLA_CODE::Success); + } + + static void compressBit2Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, BTLA_DTYPE qtype, parallel::IThreading* threading) { // TODO(zhe): 1D parallel compress parallel::Scheduler2D _para({threading->num_threads(), 1, K * N, 1, 64}); @@ -581,22 +599,38 @@ class WeightKBlockNInteger { }); } - static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype, - parallel::IThreading* threading) { - if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, dstptr, threading); - if (qtype == BTLA_DTYPE::S2_CLIP) return compressBit2Weight(N, K, B, dstptr, threading); - parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); + static void compressBit4Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, BTLA_DTYPE qtype, + parallel::IThreading* threading) { + parallel::Scheduler2D _para({threading->num_threads(), 1, K * N, 1, 64}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp({tidx}); _para.getIndex(thdp); if (thdp.valid) { - auto ret = doCompress(B + thdp.loc[0] * ldb + thdp.loc[1], dstptr + thdp.loc[0] * ldb / 2 + thdp.loc[1] / 2, - thdp.size[0], thdp.size[1], ldb, ldb, qtype); + BTLA_CODE ret = BTLA_CODE::NotSupport; + if (qtype == BTLA_DTYPE::S4_CLIP) { + auto bit4ptr = reinterpret_cast(dstptr); + ret = kernel::wrapper::CompressS8S4::forward(B + thdp.loc[1], bit4ptr + thdp.loc[1] / 2, thdp.size[1]); + } else if (qtype == BTLA_DTYPE::F4_BNB || qtype == BTLA_DTYPE::F4_NF4 || qtype == BTLA_DTYPE::F4_E2M1) { + auto bit4ptr = reinterpret_cast(dstptr); + ret = kernel::wrapper::CompressFp4::forward(B + thdp.loc[1], bit4ptr + thdp.loc[1] / 2, thdp.size[1]); + } else { + assert(0); + } assert(ret == BTLA_CODE::Success); (void)ret; } }); } + static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype, + parallel::IThreading* threading) { + if (qtype == BTLA_DTYPE::S6_CLIP) return compressBit6Weight(N, K, B, dstptr, qtype, threading); + if (qtype == BTLA_DTYPE::S5_CLIP) return compressBit5Weight(N, K, B, dstptr, qtype, threading); + if (qtype == BTLA_DTYPE::S4_CLIP) return compressBit4Weight(N, K, B, dstptr, qtype, threading); + if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, dstptr, qtype, threading); + if (qtype == BTLA_DTYPE::S2_CLIP) return compressBit2Weight(N, K, B, dstptr, qtype, threading); + if (qtype == BTLA_DTYPE::F4_BNB || qtype == BTLA_DTYPE::F4_NF4 || qtype == BTLA_DTYPE::F4_E2M1) + return compressBit4Weight(N, K, B, dstptr, qtype, threading); + } template static void reduce(const int N, const int K, const int KBlock, const float* B, const int ldb, RED_T* rptr, @@ -636,6 +670,10 @@ class WeightKBlockNInteger { auto wptr = _param.packedW; if (wptr->mDType == BTLA_DTYPE::S8) { return getQ8Weight(dstptr, dststep, k_size, n_size, 
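// Editor's sketch (illustration only): getQ5Weight/getQ6Weight and the S5/S6
// branches of getWeight locate the two planes of one NTILE column strip with
// plain offset math: the 4-bit plane sits at elt_offset / 2 bytes from the packed
// base, and the high-bit plane starts after the whole 4-bit plane
// (NPad * KPad / 2 bytes) at elt_offset / 8 (S5) or elt_offset / 4 (S6).
// Struct and function names below are editor-invented, not BesTLA API.
#include <cstddef>
#include <cstdint>
struct S5Planes {
  const uint8_t* bit4;  // low 4 bits, 2 elements per byte
  const uint8_t* bit1;  // 5th bit, 8 elements per byte
};
inline S5Planes locate_s5_strip(const uint8_t* packed, size_t NPad, size_t KPad,
                                size_t n_offset, size_t k_offset, size_t NTILE) {
  size_t elt_offset = n_offset * KPad + k_offset * NTILE;  // elements before this strip
  size_t bit4_plane_bytes = NPad * KPad / 2;               // size of the whole 4-bit plane
  return {packed + elt_offset / 2, packed + bit4_plane_bytes + elt_offset / 8};
}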
k_offset, n_offset, _param, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S6_CLIP) { + return getQ6Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S5_CLIP) { + return getQ5Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { @@ -710,16 +748,14 @@ class WeightKBlockNInteger { if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) { auto internal_n_offset = n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, - wptr->template SPtr(), wptr->template DQPtr(), k_offset / _GemmCore_T::PACK_ROW, - internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize, - wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize); - } + kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, + wptr->template SPtr(), wptr->template DQPtr(), k_offset / _GemmCore_T::PACK_ROW, + internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize, + wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize); } else { auto sptr = wptr->template SPtr(); kernel::wrapper::DecompressKBlockS4Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( @@ -733,11 +769,10 @@ class WeightKBlockNInteger { auto sptr = wptr->template SPtr(); int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; - auto ld_dst = _GemmCore_T::NTILE * KPad; - auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); + size_t bit1_offset = size_t(NPad) * KPad; auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); - auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + auto bit1ptr = reinterpret_cast(bit3_ptr + bit1_offset / 4 + elt_offset / 8); kernel::wrapper::DecompressKBlockS3Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( bit2ptr, bit1ptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, n_offset + i, wptr->mBlockSize, NPad, tmpcache, cachesize); @@ -756,9 +791,29 @@ class WeightKBlockNInteger { kernel::wrapper::DecompressKBlockS8Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( bptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, n_offset + i, wptr->mBlockSize, NPad, tmpcache, cachesize); - } - - else { + } else if (wptr->mDType == BTLA_DTYPE::S5_CLIP) { + auto sptr = wptr->template SPtr(); + int8_t* bit5_ptr = wptr->template WPtr(); + auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; + assert(elt_offset % 8 == 0); + size_t bit1_offset = size_t(NPad) * KPad; + auto bit4ptr = reinterpret_cast(bit5_ptr + elt_offset / 2); + 
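// Editor's sketch: conceptually DecompressKBlockS5Fp rebuilds each weight as
// (low_nibble | high_bit << 4) - 16, subtracts the per-block zero point when the
// layout is asymmetric, and multiplies by the per-block scale. Standalone scalar
// version; the flat row-major layout and the names are assumptions for
// illustration only.
#include <cstddef>
#include <cstdint>
inline void dequant_s5_scalar(const uint8_t* bit4_plane, const uint8_t* bit1_plane,
                              const float* scales, const int8_t* zero_points,
                              float* dst, size_t n, size_t blocksize) {
  for (size_t i = 0; i < n; ++i) {
    uint8_t lo = (bit4_plane[i / 2] >> ((i % 2) * 4)) & 0x0f;
    uint8_t hi = (bit1_plane[i / 8] >> (i % 8)) & 0x01;
    int q = (lo | (hi << 4)) - 16;  // back to signed [-16, 15]
    size_t blk = i / blocksize;
    if (zero_points) q -= zero_points[blk];  // asymmetric layouts only
    dst[i] = static_cast<float>(q) * scales[blk];
  }
}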
auto bit1ptr = reinterpret_cast(bit5_ptr + bit1_offset / 2 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS5Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( + bit4ptr, bit1ptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, + n_offset + i, wptr->mBlockSize, NPad, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S6_CLIP) { + auto sptr = wptr->template SPtr(); + int8_t* bit6_ptr = wptr->template WPtr(); + auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; + assert(elt_offset % 4 == 0); + size_t bit2_offset = size_t(NPad) * KPad; + auto bit4ptr = reinterpret_cast(bit6_ptr + elt_offset / 2); + auto bit2ptr = reinterpret_cast(bit6_ptr + bit2_offset / 2 + elt_offset / 4); + kernel::wrapper::DecompressKBlockS6Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( + bit4ptr, bit2ptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, + n_offset + i, wptr->mBlockSize, NPad, tmpcache, cachesize); + } else { assert(0); } } @@ -804,15 +859,13 @@ class WeightKBlockNInteger { auto zpptr = wptr->template ZPtr(); auto KPad = wptr->mKPad; auto NPad = wptr->mNPad; - int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; - auto row = NPad / _GemmCore_T::NTILE; - auto ld_dst = _GemmCore_T::NTILE * KPad; + size_t bit1_offset = size_t(NPad) * KPad; auto base_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { auto elt_offset = base_offset + i * KPad; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); - auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + auto bit1ptr = reinterpret_cast(bit3_ptr + bit1_offset / 4 + elt_offset / 8); kernel::wrapper::DecompressKBlockS3S8<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE>::template forward( bit2ptr, bit1ptr, wptr->IsAsym() ? zpptr : nullptr, *dstptr + i * k_size, wptr->mBlockSize, wptr->CStep(), n_offset + i, k_offset, k_size, _GemmCore_T::NTILE, tmpcache, cachesize); @@ -821,6 +874,50 @@ class WeightKBlockNInteger { return BTLA_CODE::Success; } + static inline BTLA_CODE getQ5Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { + auto wptr = _param.packedW; + int8_t* bit5_ptr = wptr->template WPtr(); + auto zpptr = wptr->template ZPtr(); + auto KPad = wptr->mKPad; + auto NPad = wptr->mNPad; + size_t bit1_offset = size_t(NPad) * KPad; + auto base_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE; + for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { + auto elt_offset = base_offset + i * KPad; + assert(elt_offset % 8 == 0); + auto bit4ptr = reinterpret_cast(bit5_ptr + elt_offset / 2); + auto bit1ptr = reinterpret_cast(bit5_ptr + bit1_offset / 2 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS5S8<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE>::template forward( + bit4ptr, bit1ptr, wptr->IsAsym() ? 
zpptr : nullptr, *dstptr + i * k_size, wptr->mBlockSize, wptr->CStep(), + n_offset + i, k_offset, k_size, _GemmCore_T::NTILE, tmpcache, cachesize); + } + *dststep = k_size; + return BTLA_CODE::Success; + } + + static inline BTLA_CODE getQ6Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { + auto wptr = _param.packedW; + int8_t* bit6_ptr = wptr->template WPtr(); + auto zpptr = wptr->template ZPtr(); + auto KPad = wptr->mKPad; + auto NPad = wptr->mNPad; + size_t bit2_offset = size_t(NPad) * KPad; + auto base_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE; + for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { + auto elt_offset = base_offset + i * KPad; + assert(elt_offset % 4 == 0); + auto bit4ptr = reinterpret_cast(bit6_ptr + elt_offset / 2); + auto bit2ptr = reinterpret_cast(bit6_ptr + bit2_offset / 2 + elt_offset / 4); + kernel::wrapper::DecompressKBlockS6S8<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE>::template forward( + bit4ptr, bit2ptr, wptr->IsAsym() ? zpptr : nullptr, *dstptr + i * k_size, wptr->mBlockSize, wptr->CStep(), + n_offset + i, k_offset, k_size, _GemmCore_T::NTILE, tmpcache, cachesize); + } + *dststep = k_size; + return BTLA_CODE::Success; + } + static inline BTLA_CODE getQ2Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; @@ -848,6 +945,12 @@ class WeightKBlockNInteger { if (quant_dtype == BTLA_DTYPE::S8) { kernel::wrapper::QuantizeSignIntRowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); + } else if (quant_dtype == BTLA_DTYPE::S6_CLIP) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); + } else if (quant_dtype == BTLA_DTYPE::S5_CLIP) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); } else if (quant_dtype == BTLA_DTYPE::S4_CLIP) { kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); @@ -861,22 +964,6 @@ class WeightKBlockNInteger { assert(0); } } - - static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst, - BTLA_DTYPE quant_dtype) { - if (quant_dtype == BTLA_DTYPE::S4_CLIP) { - return kernel::wrapper::CompressS8S4::forward(srcptr, reinterpret_cast(dstptr), row, col, - ld_src, ld_dst); - } else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 || - quant_dtype == BTLA_DTYPE::F4_E2M1) { - return kernel::wrapper::CompressFp4::forward(srcptr, reinterpret_cast(dstptr), row, col, - ld_src, - ld_dst); // ld_dst here not stride - } else { - assert(0); - return BTLA_CODE::NotSupport; - } - } }; struct ParamWeightKBlockNFloat { diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h index 7f00f9aa0..e5c441087 100644 --- a/bestla/bestla/bestla_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -712,11 +712,20 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { InfoType::resize(NPad, KPad, Block, N, K, qtype); auto bits = utils::bestla_dtype_bits(qtype); auto elesize = static_cast(NPad) * KPad; + auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here if (qtype == BTLA_DTYPE::S3_CLIP) - elesize = - 
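// Editor's sketch: the S3/S5/S6/S7 byte-size formulas in this resize() hunk size
// the packed buffer as the sum of its planes. The helper below mirrors that byte
// math (assuming NPad * KPad is a multiple of 8, which the padding guarantees);
// the function name is editor-invented.
#include <cstddef>
inline size_t packed_weight_bytes(size_t NPad, size_t KPad, int nbits) {
  size_t elems = NPad * KPad;
  switch (nbits) {
    case 2: return elems / 4;                          // single 2-bit plane
    case 3: return elems / 4 + elems / 8;              // 2-bit + 1-bit plane
    case 4: return elems / 2;                          // single 4-bit plane
    case 5: return elems / 2 + elems / 8;              // 4-bit + 1-bit plane
    case 6: return elems / 2 + elems / 4;              // 4-bit + 2-bit plane
    case 7: return elems / 2 + elems / 4 + elems / 8;  // 4-bit + 2-bit + 1-bit plane
    default: return (elems * nbits + 7) / 8;
  }
}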
static_cast(utils::padto(KPad, 128)) * NPad; // pad K-dim to 128 because 128pack round2 interleave. - // round2 interleave ld_dim == pad_to(KPad,128) * NTILE - auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here + bytes = + utils::updiv(static_cast(KPad) * NPad * 2, 8) + utils::updiv(static_cast(KPad) * NPad * 1, 8); + else if (qtype == BTLA_DTYPE::S5_CLIP) + bytes = + utils::updiv(static_cast(KPad) * NPad * 4, 8) + utils::updiv(static_cast(KPad) * NPad * 1, 8); + else if (qtype == BTLA_DTYPE::S6_CLIP) + bytes = + utils::updiv(static_cast(KPad) * NPad * 4, 8) + utils::updiv(static_cast(KPad) * NPad * 2, 8); + else if (qtype == BTLA_DTYPE::S7_CLIP) + bytes = utils::updiv(static_cast(KPad) * NPad * 4, 8) + + utils::updiv(static_cast(KPad) * NPad * 2, 8) + + utils::updiv(static_cast(KPad) * NPad * 1, 8); mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId); diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 57c840ebd..67c59d00f 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -356,12 +356,20 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) { return "signed_int8"; case BTLA_DTYPE::U8: return "unsigned_int8"; + case BTLA_DTYPE::S7_CLIP: + return "int7_clip"; + case BTLA_DTYPE::S6_CLIP: + return "int6_clip"; + case BTLA_DTYPE::S5_CLIP: + return "int5_clip"; case BTLA_DTYPE::S4_CLIP: return "int4_clip"; case BTLA_DTYPE::S3_CLIP: return "int3_clip"; case BTLA_DTYPE::S2_CLIP: return "int2_clip"; + case BTLA_DTYPE::S1_CLIP: + return "int1_clip"; case BTLA_DTYPE::F4_E2M1: return "fp4_e2m1"; case BTLA_DTYPE::F4_BNB: diff --git a/bestla/bestla/bestla_wrapper.h b/bestla/bestla/bestla_wrapper.h index aa603a551..3bc0da0ab 100644 --- a/bestla/bestla/bestla_wrapper.h +++ b/bestla/bestla/bestla_wrapper.h @@ -23,6 +23,58 @@ namespace bestla { namespace wrapper { namespace gemv_nbits { +class S6 { + public: + static int constexpr NBits = 6; + template + static inline utils::GemvParamB createB(storage::gemm::StorageWeightKBlockNInteger* packedW) { + auto isasym = packedW->IsAsym(); + auto bzptr = packedW->template ZPtr(); + int ld_scaleb = packedW->CStep(); + auto bwptr = packedW->template WPtr(); + auto bit2_offset = packedW->mNPad * packedW->mKPad / 2; + utils::GemvParamB paramB{ + bwptr, bwptr + bit2_offset, nullptr, packedW->template SPtr(), isasym ? bzptr : nullptr, + NBits, ld_scaleb, packedW->mKPad}; + return paramB; + } + template + static void updateBNStep(utils::GemvParamB& paramB, int n_offset) { + paramB.b4ptr += n_offset * paramB.kpad / 2; + paramB.b2ptr += n_offset * paramB.kpad / 4; + paramB.sptr += n_offset; + if (paramB.zpptr) { + paramB.zpptr += n_offset; + } + } +}; + +class S5 { + public: + static int constexpr NBits = 5; + template + static inline utils::GemvParamB createB(storage::gemm::StorageWeightKBlockNInteger* packedW) { + auto isasym = packedW->IsAsym(); + auto bzptr = packedW->template ZPtr(); + int ld_scaleb = packedW->CStep(); + auto bwptr = packedW->template WPtr(); + auto bit1_offset = packedW->mNPad * packedW->mKPad / 2; + utils::GemvParamB paramB{ + bwptr, nullptr, bwptr + bit1_offset, packedW->template SPtr(), isasym ? 
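// Editor's sketch: for GEMV the S5/S6 helpers expose the planes through
// GemvParamB: the 4-bit plane starts at the packed weight base, the high-bit
// plane starts mNPad * mKPad / 2 bytes later, and updateBNStep advances each
// plane pointer by n_offset * kpad scaled by its bit width. Standalone
// illustration for S5 (struct and function names are editor-invented):
#include <cstddef>
#include <cstdint>
struct PlaneCursorS5 {
  const uint8_t* b4;  // 4-bit plane: 2 elements per byte
  const uint8_t* b1;  // 1-bit plane: 8 elements per byte
  size_t kpad;
};
inline void advance_columns_s5(PlaneCursorS5& c, size_t n_offset) {
  c.b4 += n_offset * c.kpad / 2;
  c.b1 += n_offset * c.kpad / 8;
  // an S6 cursor would advance its 2-bit plane by n_offset * kpad / 4 instead
}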
bzptr : nullptr, + NBits, ld_scaleb, packedW->mKPad}; + return paramB; + } + template + static void updateBNStep(utils::GemvParamB& paramB, int n_offset) { + paramB.b4ptr += n_offset * paramB.kpad / 2; + paramB.b1ptr += n_offset * paramB.kpad / 8; + paramB.sptr += n_offset; + if (paramB.zpptr) { + paramB.zpptr += n_offset; + } + } +}; + class S4 { public: static int constexpr NBits = 4; @@ -161,6 +213,8 @@ class LauncherBase { bool impl = true; impl &= _param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP || _param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP || _param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP; if constexpr (support()) { impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 || @@ -233,6 +287,36 @@ class LauncherBase { } return; } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) { if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { if (m == 1) gemv_kblock(_param, _config); @@ -418,6 +502,8 @@ class LauncherIntKBlock { static bool implemented(const Param& _param) { bool impl = true; impl &= _param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP || _param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP || _param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP; impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 || @@ -490,7 +576,36 @@ class LauncherIntKBlock { } return; } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) 
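// Editor's note: the gemv_kblock calls in these branches lost their template
// argument lists in this rendering; each branch picks an instantiation from the
// weight dtype (S2..S6), the scale dtype (F32/BF16) and m = 1..4. The toy below
// only illustrates that m -> template-parameter fan-out; everything in it is
// editor-invented.
#include <cstdio>
template <int M>
void toy_gemv_kblock() { std::printf("gemv kernel instantiated for m=%d\n", M); }
inline void toy_dispatch_m(int m) {
  if (m == 1) toy_gemv_kblock<1>();
  if (m == 2) toy_gemv_kblock<2>();
  if (m == 3) toy_gemv_kblock<3>();
  if (m == 4) toy_gemv_kblock<4>();
}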
gemv_kblock(_param, _config); + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) { if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { if (m == 1) gemv_kblock(_param, _config); diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index 8856010b6..6613f6760 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -520,6 +520,73 @@ static inline BTLA_CODE decompress_kblock_s4_s8_pack1_row(utils::int4x2* srcptr, return BTLA_CODE::Success; } +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t elesize, int8_t* tmp, + size_t tmpsize) { + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + size_t velt = utils::padto_le(elesize, 32); + size_t i = 0; + auto vbias = _mm256_set1_epi8(8); + for (; i < velt; i += 32) { + auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); + vout_y = _mm256_sub_epi8(vout_y, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout_y); + } + if (velt < elesize) { + if (elesize >= 32) { + i = elesize - 32; + auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); + vout_y = _mm256_sub_epi8(vout_y, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout_y); + } else { + ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, + int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, + size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::int4x2 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 4) { + func = &decompress_kblock_s4_s8_pack4_row; + } + if constexpr (PackRow == 1) { + func = &decompress_kblock_s4_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s4_s8_pack2_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(srcptr + head_size * NTILE / 2, zpptr, dstptr + head_size * NTILE, blocksize, ldzp, n_offset, + head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s4_s8(srcptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_kblock_s2_s8_pack4_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, @@ -661,72 +728,6 @@ static inline BTLA_CODE decompress_kblock_s2_s8_pack1_row(utils::bit2x4* srcptr, return BTLA_CODE::Success; } -static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t elesize, int8_t* tmp, - size_t 
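// Editor's note: the AVX2 unpackers in this file process 32 elements per
// iteration and handle a ragged tail by re-decoding the last full 32-element
// window (overlapping already-written output), falling back to the ref:: kernel
// only when fewer than 32 elements exist in total. The standalone stand-in below
// shows just that control flow with memcpy in place of the SIMD body; names are
// editor-invented.
#include <cstddef>
#include <cstdint>
#include <cstring>
inline void process_with_overlap_tail(const uint8_t* src, uint8_t* dst, size_t n) {
  const size_t V = 32;  // elements per SIMD iteration
  size_t aligned = n - n % V;
  for (size_t i = 0; i < aligned; i += V) std::memcpy(dst + i, src + i, V);  // SIMD body stand-in
  if (aligned < n) {
    if (n >= V) std::memcpy(dst + n - V, src + n - V, V);         // redo the last full window
    else std::memcpy(dst + aligned, src + aligned, n - aligned);  // scalar/ref fallback
  }
}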
tmpsize) { - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - size_t velt = utils::padto_le(elesize, 32); - size_t i = 0; - auto vbias = _mm256_set1_epi8(8); - for (; i < velt; i += 32) { - auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); - vout_y = _mm256_sub_epi8(vout_y, vbias); - _mm256_storeu_si256((__m256i*)(dstptr + i), vout_y); - } - if (velt < elesize) { - if (elesize >= 32) { - i = elesize - 32; - auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); - vout_y = _mm256_sub_epi8(vout_y, vbias); - _mm256_storeu_si256((__m256i*)(dstptr + i), vout_y); - } else { - ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0); - } - } - return BTLA_CODE::Success; -} - -template -inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, - int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { - if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::int4x2 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, - int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); - decompfunc func = nullptr; - if (col == NTILE) { - if constexpr (PackRow == 4) { - func = &decompress_kblock_s4_s8_pack4_row; - } - if constexpr (PackRow == 1) { - func = &decompress_kblock_s4_s8_pack1_row; - } - if constexpr (PackRow == 2) { - func = &decompress_kblock_s4_s8_pack2_row; - } - if (func) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - (*func)(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - (*func)(srcptr + head_size * NTILE / 2, zpptr, dstptr + head_size * NTILE, blocksize, ldzp, n_offset, - head_end, body_size, tmp, tmpsize); - } - return BTLA_CODE::Success; - } - } - assert(0); - return BTLA_CODE::NotSupport; - } else { - size_t elesize = static_cast(row) * col; - return decompress_s4_s8(srcptr, dstptr, elesize, tmp, tmpsize); - } - return BTLA_CODE::Success; -} - static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr, size_t unpack_elt, int8_t* tmp, size_t tmpsize) { int constexpr VBits = 256; @@ -800,46 +801,6 @@ static inline BTLA_CODE decompress_kblock_s2_s8(utils::bit2x4* bit2ptr, int8_t* return BTLA_CODE::Success; } -static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, - size_t unpack_elt, int8_t* tmp, size_t tmpsize) { - int constexpr VBits = 256; - int constexpr VElt = VBits / 8; - int i = 0; - uint64_t mask0 = 0x0303030303030303; - auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - auto vbias = _mm256_set1_epi8(4); - auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); - - const __m256i highMask = _mm256_set1_epi8(0x04); - const __m256i bit1Mask = _mm256_set1_epi32(0x0F); - const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); - const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); - int elt_pad = utils::padto_le(unpack_elt, VElt); - for (; i < 
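// Editor's sketch: every decompress_kblock_sX_s8 splits the requested rows into
// a "head" (the remainder of a partially started k-block, up to the next
// blocksize boundary) and a block-aligned "body", so per-block zero points stay
// aligned. Standalone version of that split (names are editor-invented):
#include <algorithm>
struct KSplit {
  int head;  // rows until the next blocksize boundary
  int body;  // remaining block-aligned rows
};
inline KSplit split_at_block_boundary(int k_offset, int rows, int blocksize) {
  int head_end = (k_offset + blocksize - 1) / blocksize * blocksize;  // same as utils::padto
  head_end = std::min(head_end, k_offset + rows);
  int head = head_end - k_offset;
  return {head, rows - head};
}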
elt_pad; i += VElt) { - auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vout = _mm256_or_si256(vout, vb1); - vout = _mm256_sub_epi8(vout, vbias); - _mm256_storeu_si256((__m256i*)(dstptr + i), vout); - } - if (elt_pad < unpack_elt) { - if (unpack_elt >= 32) { - i = unpack_elt - 32; - auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vout = _mm256_or_si256(vout, vb1); - vout = _mm256_sub_epi8(vout, vbias); - _mm256_storeu_si256((__m256i*)(dstptr + i), vout); - } else { - ref::decompress_s3_s8(bit2ptr + i / 4, bit1ptr + i / 8, dstptr + i, unpack_elt - i, tmp, tmpsize); - } - } - return BTLA_CODE::Success; -} - template static inline BTLA_CODE decompress_kblock_s3_s8_pack4_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, @@ -1016,6 +977,46 @@ static inline BTLA_CODE decompress_kblock_s3_s8_pack1_row(utils::bit2x4* srcptr, return BTLA_CODE::Success; } +static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, + size_t unpack_elt, int8_t* tmp, size_t tmpsize) { + int constexpr VBits = 256; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(4); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vout = _mm256_or_si256(vout, vb1); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= 32) { + i = unpack_elt - 32; + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vout = _mm256_or_si256(vout, vb1); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } else { + ref::decompress_s3_s8(bit2ptr + i / 4, bit1ptr + i / 8, dstptr + i, unpack_elt - i, tmp, tmpsize); + } + } + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, @@ -1059,1386 +1060,2756 @@ static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::b return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* 
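// Editor's note: all of the split formats in this file reassemble a value as
// (low_plane | high_plane << low_bits) minus a bias of 2^(nbits-1): S3 combines a
// 2-bit and a 1-bit plane with bias 4, S5 a 4-bit and a 1-bit plane with bias 16,
// S6 a 4-bit and a 2-bit plane with bias 32. Scalar reference of that rule
// (editor's illustration, not a kernel in this file):
#include <cstdint>
inline int8_t reassemble_split(uint8_t low, uint8_t high, int low_bits, int nbits) {
  int biased = static_cast<int>(low) | (static_cast<int>(high) << low_bits);
  return static_cast<int8_t>(biased - (1 << (nbits - 1)));
}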
scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { - int align_col = col / 16 * 16; - int col_tail = col - align_col; - auto ebits = utils::bestla_dtype_get_f8_ebits(src_f8_type); - auto mantissabit = 7 - ebits; - auto sign_revert_and_mask = _mm256_set1_epi32(0x80000000); - auto e_revert_and_mask = _mm256_set1_epi32(0x0000007f); - auto e_revert_shift = _mm256_set1_epi32(1); - e_revert_shift = _mm256_slli_epi32(e_revert_shift, ebits - 1); - e_revert_shift = _mm256_sub_epi32(e_revert_shift, _mm256_set1_epi32(128)); - auto mantissa_revert_and_mask = _mm256_set1_epi32(0x007fffff); - auto packrow2_permute_idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3); - for (int i = 0; i < row; i++) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - int j = 0; - auto quant = [&]() { - auto sign_revert = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr + i * ld_src + j))); - auto e_revert = sign_revert; - auto mantissa_revert = sign_revert; - sign_revert = _mm256_slli_epi32(sign_revert, 24); - sign_revert = _mm256_and_si256(sign_revert, sign_revert_and_mask); - e_revert = _mm256_and_si256(e_revert, e_revert_and_mask); - e_revert = _mm256_srli_epi32(e_revert, mantissabit); - if constexpr (WITH_SCALE && std::is_same_v<_S_T, utils::f8>) { - auto scale = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(sptr + j / _PACK_ROW))); - if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_epi32(packrow2_permute_idx, scale); - e_revert = _mm256_add_epi32(e_revert, scale); - } - e_revert = _mm256_sub_epi32(e_revert, e_revert_shift); - e_revert = _mm256_slli_epi32(e_revert, 23); - mantissa_revert = _mm256_slli_epi32(mantissa_revert, 23 - mantissabit); - mantissa_revert = _mm256_and_si256(mantissa_revert, mantissa_revert_and_mask); - auto fp_v = _mm256_or_ps(_mm256_castsi256_ps(sign_revert), _mm256_castsi256_ps(e_revert)); - fp_v = _mm256_or_ps(fp_v, _mm256_castsi256_ps(mantissa_revert)); - if constexpr (WITH_SCALE && std::is_same_v<_S_T, float>) { - auto scale = _mm256_loadu_ps(sptr + j / _PACK_ROW); - if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_ps(scale, packrow2_permute_idx); - fp_v = _mm256_mul_ps(fp_v, scale); - } - if constexpr (std::is_same_v<_DST_T, float>) { - _mm256_storeu_ps(dstptr + i * ld_dst + j, fp_v); - } else { - assert(0); - } - }; - for (; j < align_col; j += 8) quant(); - for (; j < col; j++) { - auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type); - if constexpr (WITH_SCALE) { - if constexpr (std::is_same_v<_S_T, utils::f8>) { - dstptr[i * ld_dst + j] = sptr[j / _PACK_ROW].mul(fp_v); - } else if constexpr (std::is_same_v<_S_T, float>) { - dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW]; - } - } else { - dstptr[i * ld_dst + j] = fp_v; +template +static inline BTLA_CODE decompress_kblock_s5_s8_pack4_row(utils::bit4x2* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 4; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask 
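// Editor's note: with PACK_ROW = 4 the packed layout interleaves 4 consecutive K
// rows per column, so each per-column zero point must be replicated 4x before the
// byte-wise subtraction; that is what load_zp_epi8_broadcast_epi32 with the
// vindex shuffle achieves (and the epi16 variant does for PACK_ROW = 2). Scalar
// equivalent of the replication, for illustration only:
#include <cstddef>
#include <cstdint>
inline void replicate_zp_packrow(const int8_t* zp, int8_t* out, size_t ncols, int pack_row) {
  for (size_t n = 0; n < ncols; ++n)
    for (int r = 0; r < pack_row; ++r) out[n * pack_row + r] = zp[n];
}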
= _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 8, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + auto vb1 = unpack_1bits(b1ptr + i * 4, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); } } } return BTLA_CODE::Success; } -template -static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, - const int dststep, const int M, const int N) { - int constexpr Vlen = 8; - auto vN = utils::padto_le(N, Vlen); - int j = 0; - for (; j < vN; j += Vlen) { - __m256 valpha; - if constexpr (std::is_same_v) { - valpha = _mm256_loadu_ps(alpha + j); - } else if constexpr (std::is_same_v) { - auto tmp = _mm_loadu_si128(reinterpret_cast(alpha + j)); - valpha = ymm_cvt_bf16_fp32(tmp); - } else if constexpr (std::is_same_v) { - auto ebit = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast(alpha + j))); - ebit = _mm256_add_epi32(_mm256_set1_epi32(127), ebit); - valpha = _mm256_castsi256_ps(_mm256_slli_epi32(ebit, 23)); +template +static inline BTLA_CODE decompress_kblock_s5_s8_pack2_row(utils::bit4x2* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0); + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16_v16(tmp + i * 16, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); } - for (size_t i = 0; i < M; i++) { - auto vsrc = _mm256_loadu_ps(srcptr + i * srcstep + j); - auto vsrc1 = _mm256_loadu_ps(dstptr + i * dststep + j); - auto vdst = _mm256_fmadd_ps(valpha, vsrc, vsrc1); - _mm256_storeu_ps(dstptr + i * dststep + j, vdst); + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = 
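// Editor's sketch: the pack1/pack2 kernels decode a fixed Unroll of rows per
// step; a ragged tail is staged in the tmp scratch buffer, decoded as one full
// unroll, and only k_tail * NTILE bytes are copied back out. Stand-in for that
// flow (memcpy replaces the SIMD decode; scratch must hold 2 * full_bytes and
// tail_bytes <= full_bytes; names are editor-invented):
#include <cstddef>
#include <cstring>
inline void decode_tail_via_scratch(const unsigned char* src, unsigned char* dst,
                                    unsigned char* scratch, size_t tail_bytes, size_t full_bytes) {
  std::memcpy(scratch, src, tail_bytes);                          // stage the partial input
  std::memset(scratch + tail_bytes, 0, full_bytes - tail_bytes);  // pad to a full unroll
  std::memcpy(scratch + full_bytes, scratch, full_bytes);         // stand-in for the SIMD decode
  std::memcpy(dst, scratch + full_bytes, tail_bytes);             // write back only the valid rows
}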
utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + auto vb1 = unpack_1bits(b1ptr + i * 4, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } } - } - for (; j < N; j += 1) { - for (size_t i = 0; i < M; i++) { - if constexpr (!std::is_same_v) { - dstptr[i * dststep + j] += alpha[j] * srcptr[i * srcstep + j]; - } else { - dstptr[i * dststep + j] += alpha[j].mul(srcptr[i * srcstep + j]); + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits((utils::bit4x2*)(tmpb4ptr + i * 16), vmask); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 4), bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); } } return BTLA_CODE::Success; } -template -static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256 vLutL, __m256 vLutH) { - static_assert(N % 8 == 0); - int constexpr VLoop = N / 8; - auto v7 = _mm256_set1_epi32(7); - auto v8 = _mm256_set1_epi32(8); - for (int iv = 0; iv < VLoop; iv++) { - auto idx = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + iv * 8)); - auto pad_idx = _mm256_cvtepu8_epi32(idx); - auto mskgt8 = _mm256_cmpgt_epi32(pad_idx, v7); - auto fp32_dq_v0 = _mm256_permutevar8x32_ps(vLutL, pad_idx); - pad_idx = _mm256_sub_epi32(pad_idx, v8); - auto fp32_dq_v1 = _mm256_permutevar8x32_ps(vLutH, pad_idx); - auto fp32_dq_v = _mm256_blendv_ps(fp32_dq_v0, fp32_dq_v1, _mm256_castsi256_ps(mskgt8)); - if constexpr (MULS_T) { - fp32_dq_v = _mm256_mul_ps(fp32_dq_v, vscales[iv]); - } - if constexpr (std::is_same_v<_DST_T, float>) { - _mm256_storeu_ps(dstptr + iv * 8, fp32_dq_v); - } else if constexpr (std::is_same_v<_DST_T, utils::bf16>) { - auto bf16v = ymm_cvt_fp32_bf16(fp32_dq_v); - _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr + iv * 8), bf16v); - } - } -} +template +static inline BTLA_CODE decompress_kblock_s5_s8_pack1_row(utils::bit4x2* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + int constexpr UnpackLoop = Unroll * NTILE / 32; + int constexpr FullRange = 1 << (5 - 1); + __m256i v_zp_y[UnpackLoop]; + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); -template -static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) { - static_assert(N % 2 == 0); - static_assert(N <= 64); - const auto vbias = _mm256_set1_epi8(8); 
- if constexpr (N == 32) { - auto dst0 = unpack_4bits(srcptr, mask); - if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { - dst0 = _mm256_sub_epi8(dst0, vbias); + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); } - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); - } else if constexpr (N > 32) { - auto dst0 = unpack_4bits(srcptr, mask); - if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { - dst0 = _mm256_sub_epi8(dst0, vbias); + for (int i = 0; i < UnpackLoop; i++) { + v_zp_y[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); } - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); - int8_t temp[32]; - memcpy(temp, srcptr + 16, (N - 32) / 2); - dst0 = unpack_4bits(temp, mask); - if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { - dst0 = _mm256_sub_epi8(dst0, vbias); + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + auto vb1 = unpack_1bits(b1ptr + i * 4, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } } - _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); - memcpy(dstptr + 32, temp, (N - 32)); - } else { - int8_t temp[32]; - memcpy(temp, srcptr, N / 2); - auto dst0 = unpack_4bits(temp, mask); - if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { - dst0 = _mm256_sub_epi8(dst0, vbias); + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits((utils::bit4x2*)(tmpb4ptr + i * 16), vmask); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 4), bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); } - _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); - memcpy(dstptr, temp, N); } + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, + size_t unpack_elt, int8_t* tmp, size_t tmpsize) { + int constexpr VBits = 256; + int constexpr VElt = VBits / 8; + int i = 0; + int 
constexpr FullRange = 1 << (5 - 1); uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - float* LUT; - static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, - "Unsupported F4 type"); - if constexpr (F4_T == BTLA_DTYPE::F4_BNB) { - LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { - LUT = nf4_dequant_fp32_LUT; - } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { - LUT = fp4_e2m1_dequant_fp32_LUT; + auto vbias = _mm256_set1_epi8(FullRange); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_4bits(bit4ptr + i / 2, vmask); + auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vout = _mm256_or_si256(vout, vb1); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); } - auto vLutL = _mm256_loadu_ps(LUT); - auto vLutH = _mm256_loadu_ps(LUT + 8); - if (col == ld_src) { - size_t elesize = static_cast(row) * col; - size_t velt = utils::padto_le(elesize, 32); - size_t i = 0; - assert(tmpsize >= 32); - for (; i < velt; i += 32) { - convert_s4_s8_N_avx2<32, F4_T>(tmp, reinterpret_cast(srcptr + i / 2), vmask); - dequant_f4_N<32, DST_T, F4_T, false>(dstptr + i, tmp, nullptr, vLutL, vLutH); + if (elt_pad < unpack_elt) { + if (unpack_elt >= 32) { + i = unpack_elt - 32; + auto vout = unpack_4bits(bit4ptr + i / 2, vmask); + auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vout = _mm256_or_si256(vout, vb1); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } else { + ref::decompress_s5_s8(bit4ptr + i / 2, bit1ptr + i / 8, dstptr + i, unpack_elt - i, tmp, tmpsize); } - for (; i < elesize; i += 2) { - auto tmp = srcptr[i / 2]; - dstptr[i + 0] = static_cast(ref::f4_unpack(tmp.x)); - dstptr[i + 1] = static_cast(ref::f4_unpack(tmp.y)); + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, + size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 1) { + func = &decompress_kblock_s5_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s5_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s5_s8_pack4_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit4ptr, bit1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit4ptr + head_size * NTILE / 2, bit1ptr + head_size * NTILE / 8, zpptr, 
dstptr + head_size * NTILE, + blocksize, ldzp, n_offset, head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } } - return BTLA_CODE::Success; + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s5_s8(bit4ptr, bit1ptr, dstptr, elesize, tmp, tmpsize); } return BTLA_CODE::Success; } -template -static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, int8_t* tmpbuf, - size_t tmpsize) { +template +static inline BTLA_CODE decompress_kblock_s6_s8_pack4_row(utils::bit4x2* srcptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 4) == 0); + int constexpr PackRow = 4; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + int constexpr FullRange = 1 << (6 - 1); uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - float* LUT = nullptr; - if constexpr (QT_T == BTLA_DTYPE::F4_BNB) { - LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (QT_T == BTLA_DTYPE::F4_NF4) { - LUT = nf4_dequant_fp32_LUT; - } else if constexpr (QT_T == BTLA_DTYPE::F4_E2M1) { - LUT = fp4_e2m1_dequant_fp32_LUT; - } - __m256 vLutL, vLutH; - if (LUT) { - vLutL = _mm256_loadu_ps(LUT); - vLutH = _mm256_loadu_ps(LUT + 8); - } - int constexpr NReg = _NCOL / 8; - assert(col == _NCOL); - assert(ld_src == _NCOL); - assert(ld_dst == _NCOL); - __m256 vscales[NReg]; - __m256i vzps[NReg]; - int constexpr UnrollRow = 4; - assert(kblock % UnrollRow == 0); - int constexpr NTile = 32; - int constexpr Loop32 = _NCOL * UnrollRow / NTile; - assert(tmpsize >= (_NCOL * UnrollRow)); - int row0 = kblock - k_offset % kblock; - row0 = row0 == kblock ? 0 : row0; - row0 = row0 > row ? 
row : row0; - int row1 = row - row0; - int irow = 0; - auto dequantize = [&](_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) { - if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { - dequant_s8_N_avx2<_NCOL, _IS_SYM>(dstptr, srcptr, vscales, vzps); - } else { - dequant_f4_N<_NCOL, _DST_T, QT_T, true>(dstptr, srcptr, vscales, vLutL, vLutH); + auto vbias = _mm256_set1_epi8(FullRange); + + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 8, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); } - }; - if (row0) { - int rowpad4 = utils::padto_le(row0, UnrollRow); - for (int iv = 0; iv < NReg; iv++) { - vscales[iv] = _mm256_loadu_ps(scales + (k_offset + irow) / kblock * NPad + iv * 8); - if constexpr (!_IS_SYM) { - auto tmp = - _mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + (k_offset + irow) / kblock * NPad + iv * 8)); - vzps[iv] = _mm256_cvtepi8_epi32(tmp); + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b2ptr = bit2ptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + auto vb1 = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); } } - for (; irow < rowpad4; irow += UnrollRow) { - for (int iter16 = 0; iter16 < Loop32; iter16++) - convert_s4_s8_N_avx2( - tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask); - for (int iterr = 0; iterr < UnrollRow; iterr++) - dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps); - } - for (; irow < row0; irow++) { - convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2), vmask); - - dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); - } } + return BTLA_CODE::Success; +} - int row1_blk = utils::padto_le(row1, kblock) + row0; - for (; irow < row1_blk; irow += kblock) { - for (int iv = 0; iv < NReg; iv++) { - vscales[iv] = _mm256_loadu_ps(scales + (k_offset + irow) / kblock * NPad + iv * 8); - if constexpr (!_IS_SYM) { - auto tmp = - _mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + (k_offset + irow) / kblock * NPad + iv * 8)); - vzps[iv] = _mm256_cvtepi8_epi32(tmp); - } - } - for (int irr = 0; irr < kblock; irr += UnrollRow) { - for (int iter16 = 0; iter16 < Loop32; iter16++) - convert_s4_s8_N_avx2( - tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + NTile / 2 * iter16), - vmask); - for (int iterr = 0; iterr < UnrollRow; iterr++) - dequantize(dstptr + (irow + irr + iterr) * ld_src, tmpbuf + iterr * _NCOL, vscales, vzps); +template +static inline BTLA_CODE decompress_kblock_s6_s8_pack2_row(utils::bit4x2* srcptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int 
ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0); + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16_v16(tmp + i * 16, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); } - } - if (irow < row) { - for (int iv = 0; iv < NReg; iv++) { - vscales[iv] = _mm256_loadu_ps(scales + (k_offset + irow) / kblock * NPad + iv * 8); - if constexpr (!_IS_SYM) { - auto tmp = - _mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + (k_offset + irow) / kblock * NPad + iv * 8)); - vzps[iv] = _mm256_cvtepi8_epi32(tmp); + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b2ptr = bit2ptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + auto vb1 = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); } } - auto rowre = row - irow; - int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow; - for (; irow < rowpad4; irow += UnrollRow) { - for (int iter16 = 0; iter16 < Loop32; iter16++) - convert_s4_s8_N_avx2( - tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask); - for (int iterr = 0; iterr < UnrollRow; iterr++) - dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps); - } - for (; irow < row; irow++) { - convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2), vmask); - dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb2ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb2ptr, bit2ptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits((utils::bit4x2*)(tmpb4ptr + i * 16), vmask); + auto vb1 = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 8), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = 
_mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); } } return BTLA_CODE::Success; } -template -static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, int8_t* tmp, - size_t tmpsize) { - return BTLA_CODE::NotSupport; -} - -template -inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, BTLA_DTYPE sdtype, - int8_t* zero_points, int k_offset, int n_offset, int blocksize, int ldzp, - int8_t* tmp, size_t tmpsize) { +template +static inline BTLA_CODE decompress_kblock_s6_s8_pack1_row(utils::bit4x2* srcptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - if (zero_points == nullptr) { - for (int ir = 0; ir < row; ir += blocksize) { - int k_remain = utils::remainsize(ir, row, blocksize); - int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; - if constexpr (PackRow == 1) { - __m256 vscale_y[NReg]; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = _mm256_loadu_ps(sptr + i * 8); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * 8); - } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + int constexpr UnpackLoop = Unroll * NTILE / 32; + int constexpr FullRange = 1 << (6 - 1); + __m256i v_zp_y[UnpackLoop]; + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < UnpackLoop; i++) { + v_zp_y[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b2ptr = bit2ptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + auto vb1 = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - 
k_remain_unrll;
+    if (k_tail > 0) {
+      auto tmpb4ptr = tmp;
+      memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2);
+      auto tmpb2ptr = tmp + Unroll * NTILE / 2;
+      memcpy(tmpb2ptr, bit2ptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4);
+      auto tmpout = tmp + Unroll * NTILE;
+      for (int i = 0; i < UnpackLoop; i++) {
+        auto v_s8_y = unpack_4bits((utils::bit4x2*)(tmpb4ptr + i * 16), vmask);
+        auto vb1 = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 8), vshift_y, vmask0, vsfhl_mask_y, vorder_y);
+        vb1 = _mm256_slli_epi32(vb1, 4);
+        v_s8_y = _mm256_or_si256(v_s8_y, vb1);
+        v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]);
+        _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y);
+      }
+      memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE);
+    }
+  }
+  return BTLA_CODE::Success;
+}
+
+static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* dstptr,
+                                         size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
+  int constexpr VBits = 256;
+  int constexpr VElt = VBits / 8;
+  int i = 0;
+  int constexpr FullRange = 1 << (6 - 1);
+  uint32_t mask = 0x0f0f0f0f;
+  auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
+  auto vbias = _mm256_set1_epi8(FullRange);
+
+  uint32_t mask0 = 0x03030303;
+  auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0);
+  auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0);
+  auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
+                                      13, 9, 5, 1, 12, 8, 4, 0);
+  auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
+  int elt_pad = utils::padto_le(unpack_elt, VElt);
+  for (; i < elt_pad; i += VElt) {
+    auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
+    auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
+    vb1 = _mm256_slli_epi32(vb1, 4);
+    vout = _mm256_or_si256(vout, vb1);
+    vout = _mm256_sub_epi8(vout, vbias);
+    _mm256_storeu_si256((__m256i*)(dstptr + i), vout);
+  }
+  if (elt_pad < unpack_elt) {
+    if (unpack_elt >= 32) {
+      i = unpack_elt - 32;
+      auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
+      auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
+      vb1 = _mm256_slli_epi32(vb1, 4);
+      vout = _mm256_or_si256(vout, vb1);
+      vout = _mm256_sub_epi8(vout, vbias);
+      _mm256_storeu_si256((__m256i*)(dstptr + i), vout);
+    } else {
+      ref::decompress_s6_s8(bit4ptr + i / 2, bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize);
+    }
+  }
+  return BTLA_CODE::Success;
+}
+
+template <int PackRow, int NTILE>
+static inline BTLA_CODE decompress_kblock_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* zpptr,
+                                                int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset,
+                                                int row, int col, int8_t* tmp, size_t tmpsize) {
+  if (zpptr) {
+    typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit2x4 * bit2ptr, int8_t * zpptr, int8_t * dstptr,
+                                    int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp,
+                                    size_t tmpsize);
+    decompfunc func = nullptr;
+    if (col == NTILE) {
+      if constexpr (PackRow == 1) {
+        func = &decompress_kblock_s6_s8_pack1_row<NTILE>;
+      }
+      if constexpr (PackRow == 2) {
+        func = &decompress_kblock_s6_s8_pack2_row<NTILE>;
+      }
+      if constexpr (PackRow == 4) {
+        func = &decompress_kblock_s6_s8_pack4_row<NTILE>;
+      }
+      if (func) {
+        int head_end = utils::padto(k_offset, blocksize);
+        head_end = std::min(head_end, k_offset + row);
+        int head_size = head_end - k_offset;
+        if (head_size > 0) {
+          (*func)(bit4ptr, bit2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize);
+        }
+        int
body_size = row - head_size; + if (body_size > 0) { + (*func)(bit4ptr + head_size * NTILE / 2, bit2ptr + head_size * NTILE / 4, zpptr, dstptr + head_size * NTILE, + blocksize, ldzp, n_offset, head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s6_s8(bit4ptr, bit2ptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +template +inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { + int align_col = col / 16 * 16; + int col_tail = col - align_col; + auto ebits = utils::bestla_dtype_get_f8_ebits(src_f8_type); + auto mantissabit = 7 - ebits; + auto sign_revert_and_mask = _mm256_set1_epi32(0x80000000); + auto e_revert_and_mask = _mm256_set1_epi32(0x0000007f); + auto e_revert_shift = _mm256_set1_epi32(1); + e_revert_shift = _mm256_slli_epi32(e_revert_shift, ebits - 1); + e_revert_shift = _mm256_sub_epi32(e_revert_shift, _mm256_set1_epi32(128)); + auto mantissa_revert_and_mask = _mm256_set1_epi32(0x007fffff); + auto packrow2_permute_idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3); + for (int i = 0; i < row; i++) { + int kpos = (k_offset + i) / kblock; + auto sptr = scales + kpos * NPad; + int j = 0; + auto quant = [&]() { + auto sign_revert = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr + i * ld_src + j))); + auto e_revert = sign_revert; + auto mantissa_revert = sign_revert; + sign_revert = _mm256_slli_epi32(sign_revert, 24); + sign_revert = _mm256_and_si256(sign_revert, sign_revert_and_mask); + e_revert = _mm256_and_si256(e_revert, e_revert_and_mask); + e_revert = _mm256_srli_epi32(e_revert, mantissabit); + if constexpr (WITH_SCALE && std::is_same_v<_S_T, utils::f8>) { + auto scale = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(sptr + j / _PACK_ROW))); + if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_epi32(packrow2_permute_idx, scale); + e_revert = _mm256_add_epi32(e_revert, scale); + } + e_revert = _mm256_sub_epi32(e_revert, e_revert_shift); + e_revert = _mm256_slli_epi32(e_revert, 23); + mantissa_revert = _mm256_slli_epi32(mantissa_revert, 23 - mantissabit); + mantissa_revert = _mm256_and_si256(mantissa_revert, mantissa_revert_and_mask); + auto fp_v = _mm256_or_ps(_mm256_castsi256_ps(sign_revert), _mm256_castsi256_ps(e_revert)); + fp_v = _mm256_or_ps(fp_v, _mm256_castsi256_ps(mantissa_revert)); + if constexpr (WITH_SCALE && std::is_same_v<_S_T, float>) { + auto scale = _mm256_loadu_ps(sptr + j / _PACK_ROW); + if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_ps(scale, packrow2_permute_idx); + fp_v = _mm256_mul_ps(fp_v, scale); + } + if constexpr (std::is_same_v<_DST_T, float>) { + _mm256_storeu_ps(dstptr + i * ld_dst + j, fp_v); + } else { + assert(0); + } + }; + for (; j < align_col; j += 8) quant(); + for (; j < col; j++) { + auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type); + if constexpr (WITH_SCALE) { + if constexpr (std::is_same_v<_S_T, utils::f8>) { + dstptr[i * ld_dst + j] = sptr[j / _PACK_ROW].mul(fp_v); + } else if constexpr (std::is_same_v<_S_T, float>) { + dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW]; + } + } else { + dstptr[i * ld_dst + j] = fp_v; + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, 
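The S5/S6 kernels above never store a contiguous 5- or 6-bit stream: the low 4 bits of every element live in a `bit4x2` plane and the remaining high bits in a `bit1x8` (S5) or `bit2x4` (S6) plane. Dequantization recombines them as `(hi << 4) | lo` and then subtracts the bias `1 << (bits - 1)` (or a per-block zero point plus that bias). A minimal scalar sketch of the 6-bit scheme; the helper names are illustrative only, and the real kernels use a shuffle-friendly in-register interleave rather than this naive byte layout:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Scalar model of the S6 storage used by decompress_s6_s8 above:
// plane4[k] holds the low 4 bits of two elements, plane2[k] the high 2 bits of four.
static void pack_s6_planes(const int8_t* src, size_t n, uint8_t* plane4, uint8_t* plane2) {
  std::memset(plane4, 0, n / 2);
  std::memset(plane2, 0, n / 4);
  for (size_t i = 0; i < n; i++) {
    uint8_t u = static_cast<uint8_t>(src[i] + 32);        // bias by FullRange = 1 << (6 - 1)
    plane4[i / 2] |= (u & 0x0f) << ((i % 2) * 4);          // low 4 bits
    plane2[i / 4] |= ((u >> 4) & 0x03) << ((i % 4) * 2);   // high 2 bits
  }
}

static void unpack_s6_planes(const uint8_t* plane4, const uint8_t* plane2, int8_t* dst, size_t n) {
  for (size_t i = 0; i < n; i++) {
    uint8_t lo = (plane4[i / 2] >> ((i % 2) * 4)) & 0x0f;
    uint8_t hi = (plane2[i / 4] >> ((i % 4) * 2)) & 0x03;
    dst[i] = static_cast<int8_t>((hi << 4) | lo) - 32;     // same (hi << 4) | lo, then bias subtraction
  }
}

int main() {
  std::vector<int8_t> w(64), out(64);
  for (int i = 0; i < 64; i++) w[i] = static_cast<int8_t>(i % 63 - 31);  // values in [-31, 31]
  std::vector<uint8_t> p4(w.size() / 2), p2(w.size() / 4);
  pack_s6_planes(w.data(), w.size(), p4.data(), p2.data());
  unpack_s6_planes(p4.data(), p2.data(), out.data(), w.size());
  for (size_t i = 0; i < w.size(); i++) assert(out[i] == w[i]);
  return 0;
}
```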
const int srcstep, float* dstptr, + const int dststep, const int M, const int N) { + int constexpr Vlen = 8; + auto vN = utils::padto_le(N, Vlen); + int j = 0; + for (; j < vN; j += Vlen) { + __m256 valpha; + if constexpr (std::is_same_v) { + valpha = _mm256_loadu_ps(alpha + j); + } else if constexpr (std::is_same_v) { + auto tmp = _mm_loadu_si128(reinterpret_cast(alpha + j)); + valpha = ymm_cvt_bf16_fp32(tmp); + } else if constexpr (std::is_same_v) { + auto ebit = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast(alpha + j))); + ebit = _mm256_add_epi32(_mm256_set1_epi32(127), ebit); + valpha = _mm256_castsi256_ps(_mm256_slli_epi32(ebit, 23)); + } + for (size_t i = 0; i < M; i++) { + auto vsrc = _mm256_loadu_ps(srcptr + i * srcstep + j); + auto vsrc1 = _mm256_loadu_ps(dstptr + i * dststep + j); + auto vdst = _mm256_fmadd_ps(valpha, vsrc, vsrc1); + _mm256_storeu_ps(dstptr + i * dststep + j, vdst); + } + } + for (; j < N; j += 1) { + for (size_t i = 0; i < M; i++) { + if constexpr (!std::is_same_v) { + dstptr[i * dststep + j] += alpha[j] * srcptr[i * srcstep + j]; + } else { + dstptr[i * dststep + j] += alpha[j].mul(srcptr[i * srcstep + j]); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256 vLutL, __m256 vLutH) { + static_assert(N % 8 == 0); + int constexpr VLoop = N / 8; + auto v7 = _mm256_set1_epi32(7); + auto v8 = _mm256_set1_epi32(8); + for (int iv = 0; iv < VLoop; iv++) { + auto idx = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + iv * 8)); + auto pad_idx = _mm256_cvtepu8_epi32(idx); + auto mskgt8 = _mm256_cmpgt_epi32(pad_idx, v7); + auto fp32_dq_v0 = _mm256_permutevar8x32_ps(vLutL, pad_idx); + pad_idx = _mm256_sub_epi32(pad_idx, v8); + auto fp32_dq_v1 = _mm256_permutevar8x32_ps(vLutH, pad_idx); + auto fp32_dq_v = _mm256_blendv_ps(fp32_dq_v0, fp32_dq_v1, _mm256_castsi256_ps(mskgt8)); + if constexpr (MULS_T) { + fp32_dq_v = _mm256_mul_ps(fp32_dq_v, vscales[iv]); + } + if constexpr (std::is_same_v<_DST_T, float>) { + _mm256_storeu_ps(dstptr + iv * 8, fp32_dq_v); + } else if constexpr (std::is_same_v<_DST_T, utils::bf16>) { + auto bf16v = ymm_cvt_fp32_bf16(fp32_dq_v); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr + iv * 8), bf16v); + } + } +} + +template +static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) { + static_assert(N % 2 == 0); + static_assert(N <= 64); + const auto vbias = _mm256_set1_epi8(8); + if constexpr (N == 32) { + auto dst0 = unpack_4bits(srcptr, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); + } else if constexpr (N > 32) { + auto dst0 = unpack_4bits(srcptr, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); + int8_t temp[32]; + memcpy(temp, srcptr + 16, (N - 32) / 2); + dst0 = unpack_4bits(temp, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); + memcpy(dstptr + 32, temp, (N - 32)); + } else { + int8_t temp[32]; + memcpy(temp, srcptr, N / 2); + auto dst0 = unpack_4bits(temp, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); + memcpy(dstptr, temp, N); + } +} + +template +inline BTLA_CODE 
decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + float* LUT; + static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, + "Unsupported F4 type"); + if constexpr (F4_T == BTLA_DTYPE::F4_BNB) { + LUT = fp4_bnb_dequant_fp32_LUT; + } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { + LUT = nf4_dequant_fp32_LUT; + } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { + LUT = fp4_e2m1_dequant_fp32_LUT; + } + auto vLutL = _mm256_loadu_ps(LUT); + auto vLutH = _mm256_loadu_ps(LUT + 8); + if (col == ld_src) { + size_t elesize = static_cast(row) * col; + size_t velt = utils::padto_le(elesize, 32); + size_t i = 0; + assert(tmpsize >= 32); + for (; i < velt; i += 32) { + convert_s4_s8_N_avx2<32, F4_T>(tmp, reinterpret_cast(srcptr + i / 2), vmask); + dequant_f4_N<32, DST_T, F4_T, false>(dstptr + i, tmp, nullptr, vLutL, vLutH); + } + for (; i < elesize; i += 2) { + auto tmp = srcptr[i / 2]; + dstptr[i + 0] = static_cast(ref::f4_unpack(tmp.x)); + dstptr[i + 1] = static_cast(ref::f4_unpack(tmp.y)); + } + return BTLA_CODE::Success; + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, int8_t* tmpbuf, + size_t tmpsize) { + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + float* LUT = nullptr; + if constexpr (QT_T == BTLA_DTYPE::F4_BNB) { + LUT = fp4_bnb_dequant_fp32_LUT; + } else if constexpr (QT_T == BTLA_DTYPE::F4_NF4) { + LUT = nf4_dequant_fp32_LUT; + } else if constexpr (QT_T == BTLA_DTYPE::F4_E2M1) { + LUT = fp4_e2m1_dequant_fp32_LUT; + } + __m256 vLutL, vLutH; + if (LUT) { + vLutL = _mm256_loadu_ps(LUT); + vLutH = _mm256_loadu_ps(LUT + 8); + } + int constexpr NReg = _NCOL / 8; + assert(col == _NCOL); + assert(ld_src == _NCOL); + assert(ld_dst == _NCOL); + __m256 vscales[NReg]; + __m256i vzps[NReg]; + int constexpr UnrollRow = 4; + assert(kblock % UnrollRow == 0); + int constexpr NTile = 32; + int constexpr Loop32 = _NCOL * UnrollRow / NTile; + assert(tmpsize >= (_NCOL * UnrollRow)); + int row0 = kblock - k_offset % kblock; + row0 = row0 == kblock ? 0 : row0; + row0 = row0 > row ? 
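The F4/NF4/FP4-E2M1 paths above are pure table lookups: each 4-bit code indexes a 16-entry fp32 table (`fp4_bnb_dequant_fp32_LUT`, `nf4_dequant_fp32_LUT`, `fp4_e2m1_dequant_fp32_LUT`), and `dequant_f4_N` emulates the gather with two `_mm256_permutevar8x32_ps` lookups plus a blend on `idx > 7`. A scalar sketch of the same idea with a placeholder table (the real LUT contents are defined elsewhere in this header; which nibble maps to which element follows `utils::f4x2`'s bit-field layout, assumed here to be low nibble first):

```cpp
#include <cstdint>
#include <cstdio>

// Placeholder 16-entry table standing in for the real dequant LUTs.
static const float kLut16[16] = {0.f,   0.1f,  0.2f,  0.3f,  0.4f,  0.5f,  0.6f,  0.7f,
                                 -0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f, -0.7f, -1.f};

// One packed byte carries two 4-bit codes; dequant is simply LUT[code] * per-block scale.
static inline void dequant_f4_pair(uint8_t packed, float scale, float* out2) {
  out2[0] = kLut16[packed & 0x0f] * scale;         // assumed low-nibble element
  out2[1] = kLut16[(packed >> 4) & 0x0f] * scale;  // assumed high-nibble element
}

int main() {
  float out[2];
  dequant_f4_pair(0x3A, 0.5f, out);  // codes 10 and 3 under the low-nibble-first assumption
  std::printf("%f %f\n", out[0], out[1]);
  return 0;
}
```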
row : row0; + int row1 = row - row0; + int irow = 0; + auto dequantize = [&](_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) { + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dequant_s8_N_avx2<_NCOL, _IS_SYM>(dstptr, srcptr, vscales, vzps); + } else { + dequant_f4_N<_NCOL, _DST_T, QT_T, true>(dstptr, srcptr, vscales, vLutL, vLutH); + } + }; + if (row0) { + int rowpad4 = utils::padto_le(row0, UnrollRow); + for (int iv = 0; iv < NReg; iv++) { + vscales[iv] = _mm256_loadu_ps(scales + (k_offset + irow) / kblock * NPad + iv * 8); + if constexpr (!_IS_SYM) { + auto tmp = + _mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + (k_offset + irow) / kblock * NPad + iv * 8)); + vzps[iv] = _mm256_cvtepi8_epi32(tmp); + } + } + for (; irow < rowpad4; irow += UnrollRow) { + for (int iter16 = 0; iter16 < Loop32; iter16++) + convert_s4_s8_N_avx2( + tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask); + for (int iterr = 0; iterr < UnrollRow; iterr++) + dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps); + } + for (; irow < row0; irow++) { + convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2), vmask); + + dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); + } + } + + int row1_blk = utils::padto_le(row1, kblock) + row0; + for (; irow < row1_blk; irow += kblock) { + for (int iv = 0; iv < NReg; iv++) { + vscales[iv] = _mm256_loadu_ps(scales + (k_offset + irow) / kblock * NPad + iv * 8); + if constexpr (!_IS_SYM) { + auto tmp = + _mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + (k_offset + irow) / kblock * NPad + iv * 8)); + vzps[iv] = _mm256_cvtepi8_epi32(tmp); + } + } + for (int irr = 0; irr < kblock; irr += UnrollRow) { + for (int iter16 = 0; iter16 < Loop32; iter16++) + convert_s4_s8_N_avx2( + tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + NTile / 2 * iter16), + vmask); + for (int iterr = 0; iterr < UnrollRow; iterr++) + dequantize(dstptr + (irow + irr + iterr) * ld_src, tmpbuf + iterr * _NCOL, vscales, vzps); + } + } + if (irow < row) { + for (int iv = 0; iv < NReg; iv++) { + vscales[iv] = _mm256_loadu_ps(scales + (k_offset + irow) / kblock * NPad + iv * 8); + if constexpr (!_IS_SYM) { + auto tmp = + _mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + (k_offset + irow) / kblock * NPad + iv * 8)); + vzps[iv] = _mm256_cvtepi8_epi32(tmp); + } + } + auto rowre = row - irow; + int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow; + for (; irow < rowpad4; irow += UnrollRow) { + for (int iter16 = 0; iter16 < Loop32; iter16++) + convert_s4_s8_N_avx2( + tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask); + for (int iterr = 0; iterr < UnrollRow; iterr++) + dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps); + } + for (; irow < row; irow++) { + convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2), vmask); + dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, int8_t* tmp, + size_t tmpsize) { + return BTLA_CODE::NotSupport; +} + +template +inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, 
BTLA_DTYPE sdtype, + int8_t* zero_points, int k_offset, int n_offset, int blocksize, int ldzp, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + if (zero_points == nullptr) { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m256 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * 8); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; for (int i = 0; i < NReg; i++) { auto vdeq_y = dequant_s8_fp(b8ptr + i * 8, vscale_y[i]); store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8); } } - } else if constexpr (PackRow == 4) { - const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, - 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); - __m256 vscale_y[PackRow * NReg]; - for (int i = 0; i < NReg; i++) { - __m256 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm256_loadu_ps(sptr + i * 8); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * 8); - } else { - assert(0); - } - auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_y); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_y); - vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); - vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_y); - vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + } else if constexpr (PackRow == 4) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } + } else if constexpr (PackRow == 2) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 
vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } + } else { + assert(0); + } + } + return BTLA_CODE::Success; + } else { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m256 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * 8); + } + __m256i vzp_y[NReg]; + for (int i = 0; i < NReg; i++) vzp_y[i] = load_s8_s32(zero_points + ele_off + i * 8); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8, vscale_y[i], vzp_y[i]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8); + } + } + } else if constexpr (PackRow == 4) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + __m256i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + + auto tmp = load_s8_s32(zero_points + ele_off + i * 8); + auto vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_y); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_y); + vzp_y[i * PackRow + 2] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + vzp_y[i * PackRow + 3] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); 
+ store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } + } else if constexpr (PackRow == 2) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + __m256i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_y); + auto tmp = load_s8_s32(zero_points + ele_off + i * 8); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(tmp, vshuf_index_y); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(tmp, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } + } else { + assert(0); + } + } + return BTLA_CODE::Success; + } +} + +template +inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s8_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s8_fp_row(srcptr + head_size * NTILE, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s4_fp_row(utils::int4x2* srcptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - 
k_offset; + if (head_size > 0) { + decompress_kblock_s4_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s4_fp_row(srcptr + head_size * NTILE / 2, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s5_fp_row(utils::bit4x2* b4ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s5_s8(b4ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s5_fp(utils::bit4x2* b4ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s5_fp_row(b4ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s5_fp_row( + b4ptr + head_size * NTILE / 2, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s6_fp_row(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s6_s8(b4ptr, b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s6_fp(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + 
decompress_kblock_s6_fp_row(b4ptr, b2ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s6_fp_row( + b4ptr + head_size * NTILE / 2, b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s3_fp_row(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s3_fp_row( + b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s2_fp_row(utils::bit2x4* b2ptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s2_fp_row(b2ptr, dstptr, head_size, scales_, 
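All of the `decompress_kblock_*_fp` wrappers in this hunk share the same head/body split: the first call finishes whatever quantization block `k_offset` currently sits in, and the second call starts block-aligned at `head_end`, so the per-block scale/zero-point index `(k_offset + ir) / blocksize` stays exact in both halves. A small standalone illustration of the arithmetic with made-up values (`padto` here assumes the usual round-up-to-multiple behavior of `utils::padto` for positive ints):

```cpp
#include <algorithm>
#include <cstdio>

// Round x up to a multiple of n (assumed behavior of utils::padto for positive ints).
static int padto(int x, int n) { return (x + n - 1) / n * n; }

int main() {
  int k_offset = 100, row = 300, blocksize = 128;                       // made-up example values
  int head_end = std::min(padto(k_offset, blocksize), k_offset + row);  // 128
  int head_size = head_end - k_offset;                                  // 28 rows finish the current block
  int body_size = row - head_size;                                      // 272 rows start block-aligned at k = 128
  std::printf("head: k=%d..%d  body: k=%d..%d\n", k_offset, head_end, head_end, head_end + body_size);
  return 0;
}
```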
sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s2_fp_row(b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, + int8_t* tmp, size_t tmpsize) { + if constexpr (_PACK_ROW == 1) { + if (col == 24) { + return decompress_kblock_bit4_packrow1<_F4_T, true, 24, _ST, _DST_T>( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); + } + if (col == 48) { + return decompress_kblock_bit4_packrow1<_F4_T, true, 48, _ST, _DST_T>( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); + } + } else if constexpr (_PACK_ROW == 2) { + return decompress_kblock_bit4_packrow2<_F4_T, true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + nullptr, k_offset, kblock, NPad, tmp, tmpsize); + } + assert(0); + return BTLA_CODE::NotSupport; +} + +enum class AVX2_REDUCE_TYPE { MAX, MIN, ADD }; +#define AVX2_REDUCE_OP \ + if constexpr (TYPE == AVX2_REDUCE_TYPE::MAX) x = _mm256_max_ps(x, y); \ + if constexpr (TYPE == AVX2_REDUCE_TYPE::MIN) x = _mm256_min_ps(x, y); \ + if constexpr (TYPE == AVX2_REDUCE_TYPE::ADD) x = _mm256_add_ps(x, y); + +template +inline float avx2_reduce_ps(__m256 x) { + __m256 y = _mm256_permute2f128_ps(x, x, 1); + AVX2_REDUCE_OP + y = _mm256_permute_ps(x, 0b01001110); + AVX2_REDUCE_OP + y = _mm256_permute_ps(x, 0b10110001); + AVX2_REDUCE_OP + return _mm256_cvtss_f32(x); +} + +#define AVX2_REDUCE_OP_EPI32(dst, src) \ + if constexpr (TYPE == AVX2_REDUCE_TYPE::MAX) dst = _mm256_max_epi32(dst, src); \ + if constexpr (TYPE == AVX2_REDUCE_TYPE::MIN) dst = _mm256_min_epi32(dst, src); \ + if constexpr (TYPE == AVX2_REDUCE_TYPE::ADD) dst = _mm256_add_epi32(dst, src); + +#ifndef _mm256_cvtsi256_si32 +#define _mm256_cvtsi256_si32(a) (_mm_cvtsi128_si32(_mm256_castsi256_si128(a))) +#endif + +template +inline int avx2_reduce_epi32(__m256i xd) { + auto x = _mm256_castsi256_ps(xd); + __m256 y = _mm256_permute2f128_ps(x, x, 1); + auto yd = _mm256_castps_si256(y); + AVX2_REDUCE_OP_EPI32(xd, yd); + x = _mm256_castsi256_ps(xd); + y = _mm256_permute_ps(x, 0b01001110); + yd = _mm256_castps_si256(y); + AVX2_REDUCE_OP_EPI32(xd, yd); + x = _mm256_castsi256_ps(xd); + y = _mm256_permute_ps(x, 0b10110001); + yd = _mm256_castps_si256(y); + AVX2_REDUCE_OP_EPI32(xd, yd); + return _mm256_cvtsi256_si32(xd); +} + +inline __m128i avx2_cvtepi32_epu8(__m256i x) { + auto out_v = _mm_packus_epi32(_mm256_castsi256_si128(x), _mm256_extractf128_si256(x, 1)); + out_v = _mm_packus_epi16(out_v, out_v); + return out_v; +} + +template +static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, + int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, + float* blkreduce) { + int constexpr VLen = 8; + auto vff = _mm256_set1_epi32(255); + auto v0 = _mm256_set1_epi32(0); + int vblocksize = utils::padto_le(blocksize, VLen); + int colblk = utils::padto_le(col, blocksize); + for (int i = 0; i < row; i++) { + size_t j = 0; + for (; j < colblk; j += blocksize) { + __m256 vmaxval = _mm256_set1_ps(0.f); + __m256 vminval = 
_mm256_set1_ps(0.f); + size_t ij = 0; + for (; ij < vblocksize; ij += VLen) { + __m256 vsrc; + if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); + if constexpr (std::is_same_v) { + auto vtmp = + _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); + vsrc = ymm_cvt_bf16_fp32(vtmp); + } + vmaxval = _mm256_max_ps(vmaxval, vsrc); + vminval = _mm256_min_ps(vminval, vsrc); + } + auto maxval = avx2_reduce_ps(vmaxval); + auto minval = avx2_reduce_ps(vminval); + if (ij < blocksize) { + for (; ij < blocksize; ij++) { + auto srcval = (float)srcptr[(j + ij) + i * ld_src]; + maxval = std::max(maxval, srcval); + minval = std::min(minval, srcval); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); - } + } + float scale = (maxval - minval) / 255; + uint8_t zp = utils::cast((0 - minval) / scale); + scales[j / blocksize + i * ld_scale] = scale; + zps[j / blocksize + i * ld_scale] = zp; + int sum = 0; + float rscale = 1.f / scale; + auto vrscale = _mm256_set1_ps(rscale); + auto vdzp = _mm256_set1_epi32(zp); + ij = 0; + if (blkreduce) { + for (; ij < vblocksize; ij += VLen) { + __m256 vsrc; + if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); + if constexpr (std::is_same_v) { + auto vtmp = + _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); + vsrc = ymm_cvt_bf16_fp32(vtmp); } + vsrc = _mm256_mul_ps(vsrc, vrscale); + auto vdsrc = _mm256_cvtps_epi32(vsrc); + sum += avx2_reduce_epi32(vdsrc); + vdsrc = _mm256_add_epi32(vdsrc, vdzp); + vdsrc = _mm256_min_epi32(vdsrc, vff); + vdsrc = _mm256_max_epi32(vdsrc, v0); + auto vbsrc = avx2_cvtepi32_epu8(vdsrc); + _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc); } - } else if constexpr (PackRow == 2) { - const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, - 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); - __m256 vscale_y[PackRow * NReg]; - for (int i = 0; i < NReg; i++) { - __m256 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm256_loadu_ps(sptr + i * 8); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * 8); + } else { + for (; ij < vblocksize; ij += VLen) { + __m256 vsrc; + if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); + if constexpr (std::is_same_v) { + auto vtmp = + _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); + vsrc = ymm_cvt_bf16_fp32(vtmp); } - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_y); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_y); + vsrc = _mm256_mul_ps(vsrc, vrscale); + auto vdsrc = _mm256_cvtps_epi32(vsrc); + vdsrc = _mm256_add_epi32(vdsrc, vdzp); + vdsrc = _mm256_min_epi32(vdsrc, vff); + vdsrc = _mm256_max_epi32(vdsrc, v0); + auto vbsrc = avx2_cvtepi32_epu8(vdsrc); + _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip 
< PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); - } - } + } + for (; ij < blocksize; ij++) { + auto srcval = (float)srcptr[(j + ij) + i * ld_src]; + srcval = srcval * rscale; + auto srcint = int(roundf(srcval)); + sum += srcint; + srcint += zp; + srcint = std::min(srcint, 0xff); + srcint = std::max(srcint, 0); + dstptr[(j + ij) + i * ld_dst] = static_cast(srcint); + } + if (blkreduce) { + blkreduce[j / blocksize + i * ld_scale] = sum * scale; + } + } + if (j < col) { + float maxval = 0.f; + float minval = 0.f; + for (size_t ij = j; ij < col; ij++) { + maxval = std::max((float)srcptr[ij + i * ld_src], maxval); + minval = std::min((float)srcptr[ij + i * ld_src], minval); + } + float scale = (maxval - minval) / 255; + uint8_t zp = utils::cast((0 - minval) / scale); + float rscale = 1.f / scale; + scales[j / blocksize + i * ld_scale] = scale; + zps[j / blocksize + i * ld_scale] = zp; + int sum = 0; + for (size_t ij = j; ij < col; ij++) { + auto srcint = utils::cast(srcptr[ij + i * ld_src] * rscale); + sum += srcint; + srcint += zp; + srcint = srcint <= 255 ? srcint : 255; + srcint = srcint >= 0 ? srcint : 0; + dstptr[ij + i * ld_dst] = utils::cast(srcint); + } + if (blkreduce) { + blkreduce[j / blocksize + i * ld_scale] = sum * scale; + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, + float* reduce, int ldr) { + int constexpr VLen = 8; + auto vblock2_ = utils::padto_le(blocksize, VLen * 2); + auto vblock_ = utils::padto_le(blocksize, VLen); + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j += blocksize) { + auto tmp = 0.f; + auto vsum = _mm256_set1_ps(0.f); + int jj = 0; + auto vblock2 = j + vblock2_ <= col ? vblock2_ : 0; + auto vblock = j + vblock_ <= col ? 
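`quantize_fp_u8_colblock` above is standard asymmetric per-block u8 quantization: each block's [min, max] range (always widened to include 0, since the min/max accumulators start at 0) is mapped onto [0, 255] with `scale = (max - min) / 255` and a zero point of roughly `round(-min / scale)` (the kernel goes through `utils::cast` for that conversion). A scalar sketch of one block, as an illustration rather than the exact kernel (no SIMD, no `blkreduce` accumulation):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {0.7f, -1.2f, 0.3f, 2.5f, -0.4f, 1.1f};  // one block, made-up values
  // Range always includes 0, matching the kernel's zero-initialized min/max accumulators.
  float maxv = 0.f, minv = 0.f;
  for (float v : x) { maxv = std::max(maxv, v); minv = std::min(minv, v); }
  float scale = (maxv - minv) / 255.f;
  uint8_t zp = static_cast<uint8_t>(std::roundf(-minv / scale));
  std::vector<uint8_t> q(x.size());
  for (size_t i = 0; i < x.size(); i++) {
    int v = static_cast<int>(std::roundf(x[i] / scale)) + zp;  // quantize, shift by zero point
    q[i] = static_cast<uint8_t>(std::clamp(v, 0, 255));        // clamp to u8 like the vector path
  }
  for (size_t i = 0; i < x.size(); i++) {
    float deq = (static_cast<int>(q[i]) - zp) * scale;  // dequantized value, error bounded by scale/2
    std::printf("%f -> %u -> %f\n", x[i], q[i], deq);
  }
  return 0;
}
```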
vblock_ : 0; + for (; jj < vblock2; jj += VLen * 2) { + auto vtmp = _mm256_loadu_ps(srcptr + i * ldsrc + j + jj); + auto vtmp1 = _mm256_loadu_ps(srcptr + i * ldsrc + j + jj + VLen); + auto s0 = avx2_reduce_ps(vtmp); + auto s1 = avx2_reduce_ps(vtmp1); + tmp += s0; + tmp += s1; + } + if (jj + VLen <= vblock) { + for (; jj < vblock; jj += VLen) { + auto vtmp = _mm256_loadu_ps(srcptr + i * ldsrc + j + jj); + auto s0 = avx2_reduce_ps(vtmp); + tmp += s0; } - } else { - assert(0); } + for (; jj < blocksize; jj++) { + tmp += *(srcptr + i * ldsrc + j + jj); + } + reduce[i * ldr + j / blocksize] = tmp; + } + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, + int src_step, int dst_step, bool zeropadding) { + const int npadding = (dst_step - col) * sizeof(float); + constexpr int simd_proc_elt = 8; + auto col_body = col / simd_proc_elt * simd_proc_elt; + for (int i = 0; i < row; i++) { + auto src = const_cast(src_ptr + i * src_step); + auto dst = dst_ptr + i * dst_step; + int j = 0; + for (; j < col_body; j += simd_proc_elt) { + auto bf16_v = _mm_loadu_si128(reinterpret_cast<__m128i*>(src + j)); + auto fp32_v = _mm256_castsi256_ps(_mm256_bslli_epi128(_mm256_cvtepu16_epi32(bf16_v), 2)); + _mm256_storeu_ps(dst + j, fp32_v); + } + for (; j < col; j++) { + *(dst + j) = (src + j)->tofloat(); + } + if (zeropadding && npadding) std::memset(dst + col, 0, npadding); + } + return BTLA_CODE::Success; +} + +static const uint8_t avx2_bf16_convert_maigc_num[32] = { + 0x02, 0x03, 0x06, 0x07, 0x0a, 0x0b, 0x0e, 0x0f, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x02, 0x03, 0x06, 0x07, 0x0a, 0x0b, 0x0e, 0x0f, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + +static inline __m128i cvt_fp32_to_bf16(const __m256 src, __m256i* and_helper, __m256i* add_helper) { + auto shuffle_v = _mm256_loadu_si256(reinterpret_cast(avx2_bf16_convert_maigc_num)); + auto round_bias = _mm256_castps_si256(src); + round_bias = _mm256_and_si256(*and_helper, _mm256_srli_si256(round_bias, 2)); + round_bias = _mm256_add_epi32(round_bias, *add_helper); + auto round_fp32_v = _mm256_add_epi32(_mm256_castps_si256(src), round_bias); + __m256i trunc_elements = _mm256_shuffle_epi8(round_fp32_v, shuffle_v); + __m256i ordered = _mm256_permute4x64_epi64(trunc_elements, 0x58); + return _mm256_castsi256_si128(ordered); +} + +static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, + int srcstride, int dststride, bool zeropadding) { + auto srcptr = reinterpret_cast(raw_srcptr); + auto dstptr = reinterpret_cast(raw_dstptr); + constexpr int simd_proc_elt = 8; + auto bf16_and_helper = _mm256_set1_epi32(0X00000001); + auto bf16_add_helper = _mm256_set1_epi32(0x00007FFF); + auto col_body_loop = col / simd_proc_elt * simd_proc_elt; + int npadding = dststride - col * sizeof(utils::bf16); + for (int i = 0; i < row; i++) { + auto src = srcptr + i * srcstride; + auto dst = dstptr + i * dststride; + int j = 0; + for (; j < col_body_loop; j += simd_proc_elt) { + auto pack_bf16_value = cvt_fp32_to_bf16(_mm256_loadu_ps(reinterpret_cast(src) + j), + &bf16_and_helper, &bf16_add_helper); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + j * sizeof(utils::bf16)), pack_bf16_value); } - return BTLA_CODE::Success; + for (; j < col; j++) { + (reinterpret_cast(dst) + j)->fromfloat(*(reinterpret_cast(src) + j)); + } + if (zeropadding && npadding) { + std::memset(dst + col * sizeof(utils::bf16), 0, npadding); + } + } + 
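`fp32_cvt_bf16_2D_write_back` rounds to nearest-even rather than truncating: the `bf16_and_helper`/`bf16_add_helper` constants add 0x7FFF plus the lowest kept bit before the upper 16 bits are extracted. The scalar equivalent, as a sketch (NaN and overflow-to-inf corner cases are ignored here, just as in the vector path):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even fp32 -> bf16, matching the add-0x7FFF-plus-kept-LSB trick above.
static inline uint16_t fp32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);  // carry propagates into the kept upper 16 bits when needed
  return static_cast<uint16_t>(bits >> 16);
}

static inline float bf16_to_fp32(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const float vals[] = {1.0f, 1.00390625f, 3.141592f, -0.3333333f};
  for (float v : vals) {
    uint16_t h = fp32_to_bf16_rne(v);
    std::printf("%.7f -> 0x%04x -> %.7f\n", v, h, bf16_to_fp32(h));
  }
  return 0;
}
```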
return BTLA_CODE::Success; +} + +static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, const float* biasptr, float epsilon, + int norm_size, float* dstptr, float* mean_out, float* mean_square_out, + bool simplified) { + int constexpr VLen = 8; + int norm_size8 = utils::padto_le(norm_size, VLen); + int h = 0; + __m256 vmean = _mm256_setzero_ps(), vmeansq = _mm256_setzero_ps(); + for (; h < norm_size8; h += VLen) { + auto tmp = _mm256_loadu_ps(srcptr + h); + vmean = _mm256_add_ps(vmean, tmp); + tmp = _mm256_mul_ps(tmp, tmp); + vmeansq = _mm256_add_ps(vmeansq, tmp); + } + float mean = avx2_reduce_ps(vmean); + float mean_square = avx2_reduce_ps(vmeansq); + for (; h < norm_size; h++) { + mean += srcptr[h]; + mean_square += srcptr[h] * srcptr[h]; + } + mean = mean / norm_size; + if (simplified) { + mean_square = std::sqrt(mean_square / norm_size + epsilon); } else { - for (int ir = 0; ir < row; ir += blocksize) { - int k_remain = utils::remainsize(ir, row, blocksize); - int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; - if constexpr (PackRow == 1) { - __m256 vscale_y[NReg]; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = _mm256_loadu_ps(sptr + i * 8); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * 8); - } - __m256i vzp_y[NReg]; - for (int i = 0; i < NReg; i++) vzp_y[i] = load_s8_s32(zero_points + ele_off + i * 8); - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * 8, vscale_y[i], vzp_y[i]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8); - } - } - } else if constexpr (PackRow == 4) { - const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, - 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); - __m256 vscale_y[PackRow * NReg]; - __m256i vzp_y[PackRow * NReg]; - for (int i = 0; i < NReg; i++) { - __m256 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm256_loadu_ps(sptr + i * 8); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * 8); - } else { - assert(0); - } - auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_y); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_y); - vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); - vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_y); - vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon); + } + auto vm = _mm256_set1_ps(mean); + float inv_meansq = 1.f / mean_square; + auto vms = _mm256_set1_ps(inv_meansq); + h = 0; + if (simplified) { + if (scaleptr) { + for (; h < norm_size8; h += VLen) { + auto inp = _mm256_loadu_ps(srcptr + h); + auto scale = _mm256_loadu_ps(scaleptr + h); + inp = _mm256_mul_ps(inp, scale); + inp = _mm256_mul_ps(inp, vms); + _mm256_storeu_ps(dstptr + h, inp); + } + for (; h < norm_size; h++) { + dstptr[h] = srcptr[h] * inv_meansq * scaleptr[h]; + } + } else { + for (; h < norm_size8; h += VLen) { + auto inp = _mm256_loadu_ps(srcptr + h); + inp = _mm256_mul_ps(inp, vms); + _mm256_storeu_ps(dstptr + h, inp); + } + for (; h < norm_size; 
h++) { + dstptr[h] = srcptr[h] * inv_meansq; + } + } - auto tmp = load_s8_s32(zero_points + ele_off + i * 8); - auto vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_y); - vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); - vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); - vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_y); - vzp_y[i * PackRow + 2] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); - vzp_y[i * PackRow + 3] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + } else { + if (scaleptr) { + if (biasptr == nullptr) { + for (; h < norm_size8; h += VLen) { + auto inp = _mm256_loadu_ps(srcptr + h); + auto scale = _mm256_loadu_ps(scaleptr + h); + inp = _mm256_sub_ps(inp, vm); + inp = _mm256_mul_ps(inp, scale); + inp = _mm256_mul_ps(inp, vms); + _mm256_storeu_ps(dstptr + h, inp); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip], - vzp_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); - } - } + for (; h < norm_size; h++) { + dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h]; } - } else if constexpr (PackRow == 2) { - const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, - 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); - __m256 vscale_y[PackRow * NReg]; - __m256i vzp_y[PackRow * NReg]; - for (int i = 0; i < NReg; i++) { - __m256 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm256_loadu_ps(sptr + i * 8); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * 8); - } - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_y); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_y); - auto tmp = load_s8_s32(zero_points + ele_off + i * 8); - vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(tmp, vshuf_index_y); - vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(tmp, vshuf_index_y); + } else { + for (; h < norm_size8; h += VLen) { + auto inp = _mm256_loadu_ps(srcptr + h); + auto scale = _mm256_loadu_ps(scaleptr + h); + inp = _mm256_sub_ps(inp, vm); + inp = _mm256_mul_ps(inp, vms); + inp = _mm256_mul_ps(inp, scale); + auto bias = _mm256_loadu_ps(biasptr + h); + inp = _mm256_add_ps(inp, bias); + _mm256_storeu_ps(dstptr + h, inp); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip], - vzp_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); - } - } + for (; h < norm_size; h++) { + dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h] + biasptr[h]; } - } else { - assert(0); + } + } else { + for (; h < norm_size8; h += VLen) { + auto inp = _mm256_loadu_ps(srcptr + h); + inp = _mm256_sub_ps(inp, vm); + inp = _mm256_mul_ps(inp, vms); + _mm256_storeu_ps(dstptr + h, inp); + } + for (; h < norm_size; h++) { + dstptr[h] = (srcptr[h] - mean) * inv_meansq; } } - return BTLA_CODE::Success; } + + if (mean_out) { + *mean_out = mean; + } + if (mean_square_out) { + *mean_square_out = mean_square; + } + return BTLA_CODE::Success; +} + +template +inline BTLA_CODE 
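For reference, the AVX2 `layernorm` added above reduces a vector sum and sum of squares, then applies either standard LayerNorm or, when `simplified` is set, an RMSNorm-style normalization that skips the mean subtraction (and the bias). A minimal scalar sketch of the same math, using a hypothetical `layernorm_ref` helper:

```cpp
#include <cmath>

// Scalar sketch: standard LayerNorm vs. the "simplified" (RMSNorm-style) path.
void layernorm_ref(const float* src, const float* scale, const float* bias, float epsilon,
                   int n, float* dst, bool simplified) {
  float mean = 0.f, mean_sq = 0.f;
  for (int i = 0; i < n; ++i) {
    mean += src[i];
    mean_sq += src[i] * src[i];
  }
  mean /= n;
  // As in the AVX2 code, "mean_square" ends up holding the denominator.
  float denom = simplified ? std::sqrt(mean_sq / n + epsilon)
                           : std::sqrt(mean_sq / n - mean * mean + epsilon);
  float inv = 1.f / denom;
  for (int i = 0; i < n; ++i) {
    float v = simplified ? src[i] * inv : (src[i] - mean) * inv;
    if (scale) v *= scale[i];
    if (bias && !simplified) v += bias[i];
    dst[i] = v;
  }
}
```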
decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) { + auto head_ignore_num = interleave_n_offset % 128; + const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); + int32_t* bit1_ptr = reinterpret_cast(src2); + for (int i = 0; i < 4; i++) { + auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); + bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); + bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); + bit1x32 = _mm256_and_si256(highMask, bit1x32); + + auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); + auto res = _mm256_add_epi8(bit1x32, bit2x32); + res = _mm256_sub_epi8(res, highMask); + _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); + } + }; + int compress_wei_ptr_offset = 0; + if (head_ignore_num != 0) { + assert(head_ignore_num % 8 == 0); + + auto base_bit2ptr = bit2ptr - head_ignore_num / 4; + auto base_bit1ptr = bit1ptr - head_ignore_num / 8; + auto head_write_num = 128 - head_ignore_num; + bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); + for (int i = 0; i < head_write_num; i++) dstptr[i] = tmp[head_ignore_num + i]; + compress_wei_ptr_offset += head_write_num; + unpack_elt -= head_write_num; + } + auto body_loop = unpack_elt / 128; + auto tail_proc_num = unpack_elt % 128; + + bestla::kernel::jit::DecompressS3::forward_avx2(bit2ptr + compress_wei_ptr_offset / 4, + bit1ptr + compress_wei_ptr_offset / 8, + dstptr + compress_wei_ptr_offset, tmp, body_loop * 128); + compress_wei_ptr_offset += body_loop * 128; + if (tail_proc_num > 0) { + bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8, + tmp); + for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = tmp[i]; + } + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s8_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s8_fp_row(srcptr + head_size * NTILE, dstptr + head_size * NTILE, - body_size, scales_, sdtype, zero_points, head_end, n_offset, - blocksize, ldzp, tmp, tmpsize); +template +static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + _DST_T* dstptr, int interleave_n_offset, int row, int col, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, void* tmp, size_t tmpsize) { + auto unpack_elt = row * col; + 
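`decompress_kblock_s3_s8fp` above rebuilds signed 3-bit weights from two bit planes: a 2-bit plane for the low bits and a 1-bit plane that is moved into bit position 2 before the code is re-centered. Ignoring the 128-element interleaved layout and the head/tail handling, the per-element math reduces to the following sketch (hypothetical `s3_from_planes` helper):

```cpp
#include <cstdint>

// Scalar sketch: reconstruct one signed 3-bit weight from its planes.
// `b2` holds the low two bits (0..3), `b1` the high bit (0 or 1); the stored
// code is unsigned 0..7 and the kernel re-centers it to [-4, 3].
inline int8_t s3_from_planes(uint8_t b2, uint8_t b1) {
  uint8_t code = static_cast<uint8_t>((b1 << 2) | (b2 & 0x03));  // 0..7
  return static_cast<int8_t>(code - 4);                          // -4..3
}
```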
decompress_kblock_s3_s8fp<_S3_T>(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, + reinterpret_cast(tmp), tmpsize); + // TODO(zhe): simd version + for (int i = 0; i < row; i++) { + int kpos = (k_offset + i) / kblock; + auto sptr = scales + kpos * NPad; + for (int j = 0; j < col; j++) { + float tmp = static_cast(dstptr[i * col + j]); + if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); + dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } - return BTLA_CODE::Success; } - return ret; -} -template -inline BTLA_CODE decompress_kblock_s4_fp_row(utils::int4x2* srcptr, DST_T* dstptr, int row, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - auto tmps8ptr = (int8_t*)dstptr; - tmps8ptr += DstSize - S8Size; - auto ret = decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, - row, NTILE, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); - return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, - n_offset, blocksize, ldzp, tmp, tmpsize); + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s4_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s4_fp_row(srcptr + head_size * NTILE / 2, dstptr + head_size * NTILE, - body_size, scales_, sdtype, zero_points, head_end, n_offset, - blocksize, ldzp, tmp, tmpsize); +template +inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int unpack_elt, int8_t* tmp, + size_t tmpsize) { + int constexpr VBits = 256; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + if (std::is_same_v<_DST_T, int8_t>) { + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } else { + _mm256_storeu_si256((__m256i*)tmp, vout); + for (int j = 0; j < VElt; j++) { + dstptr[i + j] = tmp[j]; + } } - return BTLA_CODE::Success; } - return ret; -} - -template -inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, - void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, - int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t 
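The scalar tail of `decompress_kblock_bit3_packrow_fp` above (still marked for a future SIMD version) applies one scale and optional zero point per k-block and per `_PACK_ROW` group of columns; the 2-bit variant below follows the same pattern. A standalone sketch of that loop, assuming plain float scales and the hypothetical name `apply_kblock_scales_ref`:

```cpp
#include <cstdint>

// Scalar sketch of the per-k-block scale / zero-point application: every
// pack_row adjacent columns share one scale (and optional zero point), and the
// block index advances every `kblock` rows.
inline void apply_kblock_scales_ref(float* dst, int row, int col, const float* scales,
                                    const int8_t* zero_points, int k_offset, int kblock,
                                    int NPad, int pack_row) {
  for (int i = 0; i < row; ++i) {
    int kpos = (k_offset + i) / kblock;
    const float* sptr = scales + kpos * NPad;
    for (int j = 0; j < col; ++j) {
      float v = dst[i * col + j];
      if (zero_points) v -= static_cast<float>(zero_points[kpos * NPad + j / pack_row]);
      dst[i * col + j] = v * sptr[j / pack_row];
    }
  }
}
```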
tmpsize) { - int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - auto tmps8ptr = (int8_t*)dstptr; - tmps8ptr += DstSize - S8Size; - auto ret = decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, - k_offset, row, NTILE, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); - return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, - n_offset, blocksize, ldzp, tmp, tmpsize); + ref::decompress_kblock_s2_s8fp(bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize); + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, - void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, - int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s3_fp_row(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s3_fp_row( - b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, - sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); +template +static inline BTLA_CODE decompress_kblock_bit2_packrow_fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int row, int col, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, void* tmp, size_t tmpsize) { + auto unpack_elt = row * col; + decompress_kblock_s2_s8fp<_S2_T>(bit2ptr, dstptr, unpack_elt, reinterpret_cast(tmp), tmpsize); + // TODO(zhe): simd version + for (int i = 0; i < row; i++) { + int kpos = (k_offset + i) / kblock; + auto sptr = scales + kpos * NPad; + for (int j = 0; j < col; j++) { + float tmp = static_cast(dstptr[i * col + j]); + if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); + dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } - return BTLA_CODE::Success; } - return ret; + + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s2_fp_row(utils::bit2x4* b2ptr, DST_T* dstptr, int row, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - auto tmps8ptr = (int8_t*)dstptr; - tmps8ptr += DstSize - S8Size; - auto ret = decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, - row, NTILE, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); - return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, - n_offset, blocksize, ldzp, tmp, tmpsize); +inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) { + const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2; + static const auto mask_exp = _mm256_set1_epi32(0x7f800000); + static const auto mask_not_exp = _mm256_set1_epi32(~0x7f800000); + + const auto y_exp = 
_mm256_and_si256(_mm256_castps_si256(y), mask_exp); + const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp); + + const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23)); + return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp))); } -template -inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s2_fp_row(b2ptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s2_fp_row(b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, - body_size, scales_, sdtype, zero_points, head_end, n_offset, - blocksize, ldzp, tmp, tmpsize); - } - return BTLA_CODE::Success; +inline __m256 exp_ps_0_1(const __m256 x) { + static const auto c0 = _mm256_set1_ps(0.240226507f); + static const auto c1 = _mm256_set1_ps(0.452920674f); + static const auto c2 = _mm256_set1_ps(0.713483036f); + static const float v_log2e = std::log2(std::exp(1.f)); + static const auto log2e = _mm256_set1_ps(v_log2e); + static const auto half = _mm256_set1_ps(.5f); + + static const auto upper_bound = _mm256_set1_ps(88.722838); // log(max_positive_float) + static const auto lower_bound = _mm256_set1_ps(-87.336549); // log(min_positive_float) + __m256 x1 = _mm256_min_ps(x, upper_bound); + x1 = _mm256_max_ps(x1, lower_bound); + + x1 = _mm256_fmadd_ps(x1, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f); + const auto z = _mm256_floor_ps(x1); + const auto f = _mm256_sub_ps(x1, z); // auto f = x1 - z; + + return poly_scale_2nd_ps(_mm256_cvtps_epi32(z), f, c0, c1, c2); +} + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" // https://stackoverflow.com/a/49216021 +#endif +// Interleave 8 xmm vectors of words inplace +static inline std::array<__m128i, 8> tr_x8_word(std::array<__m128i, 8>& src) { // NOLINT [runtime/references] + std::array<__m128i, 8> dst; + + for (int i = 0; i < 8; i += 2) { + dst[i + 0] = _mm_unpacklo_epi16(src[i + 0], src[i + 1]); + dst[i + 1] = _mm_unpackhi_epi16(src[i + 0], src[i + 1]); } - return ret; + for (int i = 0; i < 8; i += 4) { + src[i + 0] = _mm_unpacklo_epi32(dst[i + 0], dst[i + 2]); + src[i + 1] = _mm_unpackhi_epi32(dst[i + 0], dst[i + 2]); + src[i + 2] = _mm_unpacklo_epi32(dst[i + 1], dst[i + 3]); + src[i + 3] = _mm_unpackhi_epi32(dst[i + 1], dst[i + 3]); + } + dst[0] = _mm_unpacklo_epi64(src[0], src[4]); + dst[1] = _mm_unpackhi_epi64(src[0], src[4]); + dst[2] = _mm_unpacklo_epi64(src[1], src[5]); + dst[3] = _mm_unpackhi_epi64(src[1], src[5]); + dst[4] = _mm_unpacklo_epi64(src[2], src[6]); + dst[5] = _mm_unpackhi_epi64(src[2], src[6]); + dst[6] = _mm_unpacklo_epi64(src[3], src[7]); + dst[7] = _mm_unpackhi_epi64(src[3], src[7]); + return dst; } -template -static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, - int8_t* tmp, size_t tmpsize) { - if constexpr (_PACK_ROW == 1) { - if (col == 24) { - return 
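`exp_ps_0_1` evaluates exp(x) as 2^z * 2^(f - 1/2), with z = floor(x*log2(e) + 1/2) and f the fractional remainder; the quadratic in c0..c2 approximates 2^(f - 1/2) on [0, 1), and `poly_scale_2nd_ps` applies the 2^z factor by adjusting the IEEE-754 exponent field. A scalar sketch of the same trick (hypothetical `exp_approx_ref`, adding to the exponent directly instead of the masked update the vector code performs):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch: exp(x) = 2^z * 2^(f - 0.5), with the fractional factor
// approximated by a 2nd-order polynomial and the 2^z factor applied by
// bumping the float exponent field.
inline float exp_approx_ref(float x) {
  x = std::max(std::min(x, 88.722838f), -87.336549f);  // clamp like the kernel
  float x1 = x * 1.442695041f + 0.5f;                  // x * log2(e) + 0.5
  float z = std::floor(x1);
  float f = x1 - z;                                    // f in [0, 1)
  float y = (0.240226507f * f + 0.452920674f) * f + 0.713483036f;  // ~2^(f - 0.5)
  int32_t bits;
  std::memcpy(&bits, &y, sizeof(bits));
  bits += static_cast<int32_t>(z) << 23;               // multiply by 2^z
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}
```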
decompress_kblock_bit4_packrow1<_F4_T, true, 24, _ST, _DST_T>( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); - } - if (col == 48) { - return decompress_kblock_bit4_packrow1<_F4_T, true, 48, _ST, _DST_T>( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); - } - } else if constexpr (_PACK_ROW == 2) { - return decompress_kblock_bit4_packrow2<_F4_T, true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - nullptr, k_offset, kblock, NPad, tmp, tmpsize); +template +inline std::array<__m128i, 8> load_fp32_fp16_tr_x8_word(const float* a, size_t lda) { + static_assert(tail > 0 && tail <= 8, "Unexpected tail value."); + std::array<__m128i, 8> dst; + for (int i = 0; i < tail; ++i) { + dst[i] = _mm256_cvtps_ph(_mm256_loadu_ps(a + i * lda), _MM_FROUND_TO_NEAREST_INT); } - assert(0); - return BTLA_CODE::NotSupport; + for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128(); + return tr_x8_word(dst); } +constexpr decltype(load_fp32_fp16_tr_x8_word<1>)* load_fp32_fp16_tr_x8_word_tbl[9]{ + load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<2>, + load_fp32_fp16_tr_x8_word<3>, load_fp32_fp16_tr_x8_word<4>, load_fp32_fp16_tr_x8_word<5>, + load_fp32_fp16_tr_x8_word<6>, load_fp32_fp16_tr_x8_word<7>, load_fp32_fp16_tr_x8_word<8>}; -enum class AVX2_REDUCE_TYPE { MAX, MIN, ADD }; -#define AVX2_REDUCE_OP \ - if constexpr (TYPE == AVX2_REDUCE_TYPE::MAX) x = _mm256_max_ps(x, y); \ - if constexpr (TYPE == AVX2_REDUCE_TYPE::MIN) x = _mm256_min_ps(x, y); \ - if constexpr (TYPE == AVX2_REDUCE_TYPE::ADD) x = _mm256_add_ps(x, y); - -template -inline float avx2_reduce_ps(__m256 x) { - __m256 y = _mm256_permute2f128_ps(x, x, 1); - AVX2_REDUCE_OP - y = _mm256_permute_ps(x, 0b01001110); - AVX2_REDUCE_OP - y = _mm256_permute_ps(x, 0b10110001); - AVX2_REDUCE_OP - return _mm256_cvtss_f32(x); +template +inline std::array<__m128i, 8> load_maskz_fp32_fp16_tr_x8_word(const float* a, size_t lda, __m256i mask) { + static_assert(tail > 0 && tail <= 8, "Unexpected tail value."); + std::array<__m128i, 8> dst; + for (int i = 0; i < tail; ++i) { + dst[i] = _mm256_cvtps_ph(_mm256_maskload_ps(a + i * lda, mask), _MM_FROUND_TO_NEAREST_INT); + } + for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128(); + return tr_x8_word(dst); } +constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_x8_word_tbl[9]{ + load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<2>, + load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>, + load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>}; -#define AVX2_REDUCE_OP_EPI32(dst, src) \ - if constexpr (TYPE == AVX2_REDUCE_TYPE::MAX) dst = _mm256_max_epi32(dst, src); \ - if constexpr (TYPE == AVX2_REDUCE_TYPE::MIN) dst = _mm256_min_epi32(dst, src); \ - if constexpr (TYPE == AVX2_REDUCE_TYPE::ADD) dst = _mm256_add_epi32(dst, src); - -#ifndef _mm256_cvtsi256_si32 -#define _mm256_cvtsi256_si32(a) (_mm_cvtsi128_si32(_mm256_castsi256_si128(a))) +#ifdef __GNUC__ +#pragma GCC diagnostic pop #endif -template -inline int avx2_reduce_epi32(__m256i xd) { - auto x = _mm256_castsi256_ps(xd); - __m256 y = _mm256_permute2f128_ps(x, x, 1); - auto yd = _mm256_castps_si256(y); - AVX2_REDUCE_OP_EPI32(xd, yd); - x = _mm256_castsi256_ps(xd); - y = _mm256_permute_ps(x, 0b01001110); - yd = 
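`tr_x8_word` runs three unpack rounds (16-, 32-, then 64-bit granularity), which appears to amount to an 8x8 transpose of 16-bit lanes; `load_fp32_fp16_tr_x8_word` converts up to eight rows of eight floats to FP16 (zero-filling missing tail rows) before transposing, so each output register holds one column. A scalar reference of the assumed end result:

```cpp
#include <cstdint>

// Scalar reference for what the three unpack rounds achieve: an 8x8 transpose
// of 16-bit elements (output row i holds original column i).
inline void transpose_8x8_u16_ref(const uint16_t src[8][8], uint16_t dst[8][8]) {
  for (int r = 0; r < 8; ++r)
    for (int c = 0; c < 8; ++c)
      dst[c][r] = src[r][c];
}
```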
_mm256_castps_si256(y); - AVX2_REDUCE_OP_EPI32(xd, yd); - x = _mm256_castsi256_ps(xd); - y = _mm256_permute_ps(x, 0b10110001); - yd = _mm256_castps_si256(y); - AVX2_REDUCE_OP_EPI32(xd, yd); - return _mm256_cvtsi256_si32(xd); -} - -inline __m128i avx2_cvtepi32_epu8(__m256i x) { - auto out_v = _mm_packus_epi32(_mm256_castsi256_si128(x), _mm256_extractf128_si256(x, 1)); - out_v = _mm_packus_epi16(out_v, out_v); - return out_v; -} - -template -static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, - int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, - float* blkreduce) { - int constexpr VLen = 8; - auto vff = _mm256_set1_epi32(255); - auto v0 = _mm256_set1_epi32(0); - int vblocksize = utils::padto_le(blocksize, VLen); - int colblk = utils::padto_le(col, blocksize); - for (int i = 0; i < row; i++) { - size_t j = 0; - for (; j < colblk; j += blocksize) { - __m256 vmaxval = _mm256_set1_ps(0.f); - __m256 vminval = _mm256_set1_ps(0.f); - size_t ij = 0; - for (; ij < vblocksize; ij += VLen) { - __m256 vsrc; - if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); - if constexpr (std::is_same_v) { - auto vtmp = - _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); - vsrc = ymm_cvt_bf16_fp32(vtmp); - } - vmaxval = _mm256_max_ps(vmaxval, vsrc); - vminval = _mm256_min_ps(vminval, vsrc); - } - auto maxval = avx2_reduce_ps(vmaxval); - auto minval = avx2_reduce_ps(vminval); - if (ij < blocksize) { - for (; ij < blocksize; ij++) { - auto srcval = (float)srcptr[(j + ij) + i * ld_src]; - maxval = std::max(maxval, srcval); - minval = std::min(minval, srcval); - } +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc, __m256* vsca) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va = _mm256_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + ftmp = _mm256_mul_ps(ftmp, vsca[i]); + vacc[i] = _mm256_fmadd_ps(va, ftmp, vacc[i]); } - float scale = (maxval - minval) / 255; - uint8_t zp = utils::cast((0 - minval) / scale); - scales[j / blocksize + i * ld_scale] = scale; - zps[j / blocksize + i * ld_scale] = zp; - int sum = 0; - float rscale = 1.f / scale; - auto vrscale = _mm256_set1_ps(rscale); - auto vdzp = _mm256_set1_epi32(zp); - ij = 0; - if (blkreduce) { - for (; ij < vblocksize; ij += VLen) { - __m256 vsrc; - if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); - if constexpr (std::is_same_v) { - auto vtmp = - _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); - vsrc = ymm_cvt_bf16_fp32(vtmp); + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + ftmp = _mm256_mul_ps(ftmp, vsca[i]); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm256_set1_ps(*(Aptr + ikk + im * lda)); } - vsrc = _mm256_mul_ps(vsrc, vrscale); - auto vdsrc = _mm256_cvtps_epi32(vsrc); - sum += avx2_reduce_epi32(vdsrc); - vdsrc = _mm256_add_epi32(vdsrc, vdzp); - vdsrc = _mm256_min_epi32(vdsrc, vff); - vdsrc = _mm256_max_epi32(vdsrc, v0); - auto vbsrc = avx2_cvtepi32_epu8(vdsrc); - _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc); + vacc[im * NReg + i] = _mm256_fmadd_ps(va[im], ftmp, vacc[im * NReg + i]); 
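`accumulate_fp32_s8_fp32` fuses the dequantization of an int8 tile with the FMA against broadcast activations. A scalar sketch of the same computation, assuming the temporary buffer holds `Unroll` rows of NTILE int8 codes and using the hypothetical name `accumulate_ref`:

```cpp
#include <cstdint>

// Scalar sketch of the fused dequant + FMA: B holds `unroll` rows of `n` int8
// codes, each column carries its own scale, and every activation A[m][k] is
// broadcast against one row of B.
inline void accumulate_ref(const float* A, int lda, const int8_t* B, int unroll, int n,
                           const float* b_scale, float* acc, int mtile) {
  for (int k = 0; k < unroll; ++k)
    for (int m = 0; m < mtile; ++m)
      for (int j = 0; j < n; ++j)
        acc[m * n + j] += A[m * lda + k] * (static_cast<float>(B[k * n + j]) * b_scale[j]);
}
```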
} - } else { - for (; ij < vblocksize; ij += VLen) { - __m256 vsrc; - if constexpr (std::is_same_v) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]); - if constexpr (std::is_same_v) { - auto vtmp = - _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src]))); - vsrc = ymm_cvt_bf16_fp32(vtmp); + } + } + } +} + +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc_loc) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va = _mm256_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + vacc_loc[i] = _mm256_fmadd_ps(va, ftmp, vacc_loc[i]); + } + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm256_set1_ps(*(Aptr + ikk + im * lda)); } - vsrc = _mm256_mul_ps(vsrc, vrscale); - auto vdsrc = _mm256_cvtps_epi32(vsrc); - vdsrc = _mm256_add_epi32(vdsrc, vdzp); - vdsrc = _mm256_min_epi32(vdsrc, vff); - vdsrc = _mm256_max_epi32(vdsrc, v0); - auto vbsrc = avx2_cvtepi32_epu8(vdsrc); - _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc); + vacc_loc[im * NReg + i] = _mm256_fmadd_ps(va[im], ftmp, vacc_loc[im * NReg + i]); } } - for (; ij < blocksize; ij++) { - auto srcval = (float)srcptr[(j + ij) + i * ld_src]; - srcval = srcval * rscale; - auto srcint = int(roundf(srcval)); - sum += srcint; - srcint += zp; - srcint = std::min(srcint, 0xff); - srcint = std::max(srcint, 0); - dstptr[(j + ij) + i * ld_dst] = static_cast(srcint); - } - if (blkreduce) { - blkreduce[j / blocksize + i * ld_scale] = sum * scale; + } + } +} + +template +static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& b4ptr = B.b4ptr; + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(8); + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); } - } - if (j < col) { - float maxval = 0.f; - float minval = 0.f; - for (size_t ij = j; ij < col; ij++) { - maxval = std::max((float)srcptr[ij + i * ld_src], maxval); - minval = std::min((float)srcptr[ij + i * ld_src], minval); + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - float scale = (maxval - minval) / 255; - uint8_t zp = utils::cast((0 - minval) / scale); - float rscale = 1.f / scale; - scales[j / blocksize + i * ld_scale] = scale; - zps[j / blocksize + i * ld_scale] = zp; - int sum = 0; - for (size_t ij = j; ij < col; ij++) { - auto srcint = utils::cast(srcptr[ij + i * ld_src] * rscale); - sum 
+= srcint; - srcint += zp; - srcint = srcint <= 255 ? srcint : 255; - srcint = srcint >= 0 ? srcint : 0; - dstptr[ij + i * ld_dst] = utils::cast(srcint); + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); } - if (blkreduce) { - blkreduce[j / blocksize + i * ld_scale] = sum * scale; + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); } } } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } return BTLA_CODE::Success; } -template -static inline BTLA_CODE col_block_reduce_sum(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, - float* reduce, int ldr) { - int constexpr VLen = 8; - auto vblock2_ = utils::padto_le(blocksize, VLen * 2); - auto vblock_ = utils::padto_le(blocksize, VLen); - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += blocksize) { - auto tmp = 0.f; - auto vsum = _mm256_set1_ps(0.f); - int jj = 0; - auto vblock2 = j + vblock2_ <= col ? vblock2_ : 0; - auto vblock = j + vblock_ <= col ? vblock_ : 0; - for (; jj < vblock2; jj += VLen * 2) { - auto vtmp = _mm256_loadu_ps(srcptr + i * ldsrc + j + jj); - auto vtmp1 = _mm256_loadu_ps(srcptr + i * ldsrc + j + jj + VLen); - auto s0 = avx2_reduce_ps(vtmp); - auto s1 = avx2_reduce_ps(vtmp1); - tmp += s0; - tmp += s1; +template +static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + auto vbias = _mm256_set1_epi8(2); + + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m256 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm256_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); } - if (jj + VLen <= vblock) { - for (; jj < vblock; jj += VLen) { - auto vtmp = _mm256_loadu_ps(srcptr + i * ldsrc + j + jj); - auto s0 = avx2_reduce_ps(vtmp); - tmp += s0; + for (int i = 0; i < NReg; i++) { + bzp[i] 
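Putting the pieces together, `gemv_4bit_fp32_fp32` above walks the weight blocks, re-centers each unsigned 4-bit code by 8 (plus the optional asymmetric zero point), and lets the scaled accumulate helper apply the per-block, per-column scale. A scalar model under simplifying assumptions (MTILE == 1, `ldzp == N`, hypothetical `gemv_q4_ref`):

```cpp
#include <cstdint>

// Scalar model of the fp32-activation 4-bit GEMV: codes are re-centered by 8
// plus an optional zero point, scaled per block/column, and accumulated.
inline void gemv_q4_ref(const float* A, const uint8_t* q /* K x N codes, 0..15 */,
                        const float* scales /* (K/blocksize) x N */, const int8_t* zps,
                        int N, int K, int blocksize, float* C) {
  for (int n = 0; n < N; ++n) C[n] = 0.f;
  for (int k = 0; k < K; ++k) {
    int blk = k / blocksize;
    for (int n = 0; n < N; ++n) {
      int w = static_cast<int>(q[k * N + n]) - 8 - (zps ? zps[blk * N + n] : 0);
      C[n] += A[k] * (static_cast<float>(w) * scales[blk * N + n]);
    }
  }
}
```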
= _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - for (; jj < blocksize; jj++) { - tmp += *(srcptr + i * ldsrc + j + jj); + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - reduce[i * ldr + j / blocksize] = tmp; } - } - return BTLA_CODE::Success; -} -static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, - int src_step, int dst_step, bool zeropadding) { - const int npadding = (dst_step - col) * sizeof(float); - constexpr int simd_proc_elt = 8; - auto col_body = col / simd_proc_elt * simd_proc_elt; - for (int i = 0; i < row; i++) { - auto src = const_cast(src_ptr + i * src_step); - auto dst = dst_ptr + i * dst_step; - int j = 0; - for (; j < col_body; j += simd_proc_elt) { - auto bf16_v = _mm_loadu_si128(reinterpret_cast<__m128i*>(src + j)); - auto fp32_v = _mm256_castsi256_ps(_mm256_bslli_epi128(_mm256_cvtepu16_epi32(bf16_v), 2)); - _mm256_storeu_ps(dst + j, fp32_v); + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); } - for (; j < col; j++) { - *(dst + j) = (src + j)->tofloat(); + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } } - if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } - return BTLA_CODE::Success; -} - -static const uint8_t avx2_bf16_convert_maigc_num[32] = { - 0x02, 0x03, 0x06, 0x07, 0x0a, 0x0b, 0x0e, 0x0f, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, - 0x02, 0x03, 0x06, 0x07, 0x0a, 0x0b, 0x0e, 0x0f, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; - -static inline __m128i cvt_fp32_to_bf16(const __m256 src, __m256i* and_helper, __m256i* add_helper) { - auto shuffle_v = _mm256_loadu_si256(reinterpret_cast(avx2_bf16_convert_maigc_num)); - auto round_bias = _mm256_castps_si256(src); - round_bias = _mm256_and_si256(*and_helper, _mm256_srli_si256(round_bias, 2)); - round_bias = _mm256_add_epi32(round_bias, *add_helper); - auto round_fp32_v = _mm256_add_epi32(_mm256_castps_si256(src), round_bias); - __m256i trunc_elements = _mm256_shuffle_epi8(round_fp32_v, shuffle_v); - __m256i ordered = _mm256_permute4x64_epi64(trunc_elements, 0x58); - return _mm256_castsi256_si128(ordered); -} -static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, - int srcstride, int dststride, bool zeropadding) { - auto srcptr = reinterpret_cast(raw_srcptr); - auto dstptr = reinterpret_cast(raw_dstptr); - constexpr int simd_proc_elt = 8; - auto bf16_and_helper = _mm256_set1_epi32(0X00000001); - auto bf16_add_helper = _mm256_set1_epi32(0x00007FFF); - auto col_body_loop = col / simd_proc_elt * simd_proc_elt; - int npadding = dststride - col * sizeof(utils::bf16); - for (int i = 0; 
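Unlike the 4-bit kernel, the 2-bit kernel above (and the 3-, 5- and 6-bit kernels that follow) accumulates un-scaled, re-centered codes into `acc_loc` and multiplies by the block scale only once per block, which keeps the scale multiply out of the inner k loop. A scalar sketch of that deferred-scale pattern (hypothetical `gemv_deferred_scale_ref`):

```cpp
#include <cstdint>

// Scalar sketch of the deferred-scale pattern: accumulate activation *
// centered_code inside the block, then apply the per-block, per-column scale
// once per block rather than once per element.
inline void gemv_deferred_scale_ref(const float* A, const int8_t* w /* K x N, centered */,
                                    const float* scales /* (K/blocksize) x N */,
                                    int N, int K, int blocksize, float* C) {
  for (int n = 0; n < N; ++n) C[n] = 0.f;
  for (int b = 0; b < K / blocksize; ++b)
    for (int n = 0; n < N; ++n) {
      float acc_loc = 0.f;
      for (int kk = 0; kk < blocksize; ++kk)
        acc_loc += A[b * blocksize + kk] * static_cast<float>(w[(b * blocksize + kk) * N + n]);
      C[n] += acc_loc * scales[b * N + n];
    }
}
```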
i < row; i++) { - auto src = srcptr + i * srcstride; - auto dst = dstptr + i * dststride; - int j = 0; - for (; j < col_body_loop; j += simd_proc_elt) { - auto pack_bf16_value = cvt_fp32_to_bf16(_mm256_loadu_ps(reinterpret_cast(src) + j), - &bf16_and_helper, &bf16_add_helper); - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + j * sizeof(utils::bf16)), pack_bf16_value); - } - for (; j < col; j++) { - (reinterpret_cast(dst) + j)->fromfloat(*(reinterpret_cast(src) + j)); - } - if (zeropadding && npadding) { - std::memset(dst + col * sizeof(utils::bf16), 0, npadding); + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); } } return BTLA_CODE::Success; } -static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, const float* biasptr, float epsilon, - int norm_size, float* dstptr, float* mean_out, float* mean_square_out, - bool simplified) { - int constexpr VLen = 8; - int norm_size8 = utils::padto_le(norm_size, VLen); - int h = 0; - __m256 vmean = _mm256_setzero_ps(), vmeansq = _mm256_setzero_ps(); - for (; h < norm_size8; h += VLen) { - auto tmp = _mm256_loadu_ps(srcptr + h); - vmean = _mm256_add_ps(vmean, tmp); - tmp = _mm256_mul_ps(tmp, tmp); - vmeansq = _mm256_add_ps(vmeansq, tmp); - } - float mean = avx2_reduce_ps(vmean); - float mean_square = avx2_reduce_ps(vmeansq); - for (; h < norm_size; h++) { - mean += srcptr[h]; - mean_square += srcptr[h] * srcptr[h]; - } - mean = mean / norm_size; - if (simplified) { - mean_square = std::sqrt(mean_square / norm_size + epsilon); - } else { - mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon); - } - auto vm = _mm256_set1_ps(mean); - float inv_meansq = 1.f / mean_square; - auto vms = _mm256_set1_ps(inv_meansq); - h = 0; - if (simplified) { - if (scaleptr) { - for (; h < norm_size8; h += VLen) { - auto inp = _mm256_loadu_ps(srcptr + h); - auto scale = _mm256_loadu_ps(scaleptr + h); - inp = _mm256_mul_ps(inp, scale); - inp = _mm256_mul_ps(inp, vms); - _mm256_storeu_ps(dstptr + h, inp); - } - for (; h < norm_size; h++) { - dstptr[h] = srcptr[h] * inv_meansq * scaleptr[h]; - } - } else { - for (; h < norm_size8; h += VLen) { - auto inp = _mm256_loadu_ps(srcptr + h); - inp = _mm256_mul_ps(inp, vms); - _mm256_storeu_ps(dstptr + h, inp); - } - for (; h < norm_size; h++) { - dstptr[h] = srcptr[h] * inv_meansq; - } +template +static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + auto b1ptr = (utils::bit1x8*)B.b1ptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + auto vbias = _mm256_set1_epi8(4); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const 
__m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m256 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm256_setzero_ps(); } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); - } else { - if (scaleptr) { - if (biasptr == nullptr) { - for (; h < norm_size8; h += VLen) { - auto inp = _mm256_loadu_ps(srcptr + h); - auto scale = _mm256_loadu_ps(scaleptr + h); - inp = _mm256_sub_ps(inp, vm); - inp = _mm256_mul_ps(inp, scale); - inp = _mm256_mul_ps(inp, vms); - _mm256_storeu_ps(dstptr + h, inp); - } - for (; h < norm_size; h++) { - dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h]; - } - } else { - for (; h < norm_size8; h += VLen) { - auto inp = _mm256_loadu_ps(srcptr + h); - auto scale = _mm256_loadu_ps(scaleptr + h); - inp = _mm256_sub_ps(inp, vm); - inp = _mm256_mul_ps(inp, vms); - inp = _mm256_mul_ps(inp, scale); - auto bias = _mm256_loadu_ps(biasptr + h); - inp = _mm256_add_ps(inp, bias); - _mm256_storeu_ps(dstptr + h, inp); - } - for (; h < norm_size; h++) { - dstptr[h] = (srcptr[h] - mean) * inv_meansq * scaleptr[h] + biasptr[h]; + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + b1ptr += 8 * Unroll / 8; } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } + } else { - for (; h < norm_size8; h += VLen) { - auto inp = _mm256_loadu_ps(srcptr + h); - inp = _mm256_sub_ps(inp, vm); - inp = _mm256_mul_ps(inp, vms); - _mm256_storeu_ps(dstptr + h, inp); + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + b1ptr += 8 * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - for (; h < norm_size; h++) { - dstptr[h] = (srcptr[h] - mean) * inv_meansq; + } + + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); } } } - if (mean_out) { - *mean_out = mean; - } - if (mean_square_out) { - *mean_square_out = mean_square; + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } } return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, - int 
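Inside `gemv_3bit_fp32_fp32`, each weight code is rebuilt by OR-ing the 2-bit plane with the 1-bit plane (already positioned at bit 2 by `unpack_1bits`) and re-centering by 4 plus the optional zero point. The per-element math, as a hedged scalar sketch:

```cpp
#include <cstdint>

// Scalar sketch of the 3-bit reconstruction: low two bits from the bit2 plane,
// third bit from the bit1 plane, re-centered by 4 plus the zero point.
inline float dequant_s3_ref(uint8_t b2, uint8_t b1, int8_t zp, float scale) {
  int code = (b2 & 0x03) | (b1 ? 0x04 : 0x00);  // 0..7
  return static_cast<float>(code - 4 - zp) * scale;
}
```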
interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) { - auto head_ignore_num = interleave_n_offset % 128; - const __m256i lowMask = _mm256_set1_epi8(0x03); +template +static inline BTLA_CODE gemv_5bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = (utils::bit4x2*)B.b4ptr; + auto b1ptr = (utils::bit1x8*)B.b1ptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + const __m256i highMask = _mm256_set1_epi8(0x04); const __m256i bit1Mask = _mm256_set1_epi32(0x0F); const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; - auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { - __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); - int32_t* bit1_ptr = reinterpret_cast(src2); - for (int i = 0; i < 4; i++) { - auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); - bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); - bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); - bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); - bit1x32 = _mm256_and_si256(highMask, bit1x32); + __m256 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm256_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); - auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); - auto res = _mm256_add_epi8(bit1x32, bit2x32); - res = _mm256_sub_epi8(res, highMask); - _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b4ptr += 8 * Unroll / 2; + b1ptr += 8 * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b4ptr += 8 * Unroll / 2; + b1ptr += 8 * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } } - }; - int compress_wei_ptr_offset = 0; - if (head_ignore_num != 0) { - 
assert(head_ignore_num % 8 == 0); - auto base_bit2ptr = bit2ptr - head_ignore_num / 4; - auto base_bit1ptr = bit1ptr - head_ignore_num / 8; - auto head_write_num = 128 - head_ignore_num; - bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); - for (int i = 0; i < head_write_num; i++) dstptr[i] = tmp[head_ignore_num + i]; - compress_wei_ptr_offset += head_write_num; - unpack_elt -= head_write_num; + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } + } } - auto body_loop = unpack_elt / 128; - auto tail_proc_num = unpack_elt % 128; - bestla::kernel::jit::DecompressS3::forward_avx2(bit2ptr + compress_wei_ptr_offset / 4, - bit1ptr + compress_wei_ptr_offset / 8, - dstptr + compress_wei_ptr_offset, tmp, body_loop * 128); - compress_wei_ptr_offset += body_loop * 128; - if (tail_proc_num > 0) { - bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8, - tmp); - for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = tmp[i]; + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } } return BTLA_CODE::Success; } -template -static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, - _DST_T* dstptr, int interleave_n_offset, int row, int col, - _ST* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, void* tmp, size_t tmpsize) { - auto unpack_elt = row * col; - decompress_kblock_s3_s8fp<_S3_T>(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, - reinterpret_cast(tmp), tmpsize); - // TODO(zhe): simd version - for (int i = 0; i < row; i++) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j++) { - float tmp = static_cast(dstptr[i * col + j]); - if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); - dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); - } +template +static inline BTLA_CODE gemv_6bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = (utils::bit4x2*)B.b4ptr; + auto b2ptr = (utils::bit2x4*)B.b2ptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); } - return BTLA_CODE::Success; -} + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); -template -inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int unpack_elt, int8_t* tmp, - size_t tmpsize) { - int constexpr VBits = 256; - int constexpr VElt = VBits / 8; - int i = 0; - uint64_t mask0 = 0x0303030303030303; - auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 
12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); - int elt_pad = utils::padto_le(unpack_elt, VElt); - for (; i < elt_pad; i += VElt) { - auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - if (std::is_same_v<_DST_T, int8_t>) { - _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m256 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm256_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b4ptr += 8 * Unroll / 2; + b2ptr += 8 * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + } else { - _mm256_storeu_si256((__m256i*)tmp, vout); - for (int j = 0; j < VElt; j++) { - dstptr[i + j] = tmp[j]; + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b4ptr += 8 * Unroll / 2; + b2ptr += 8 * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } } - } - ref::decompress_kblock_s2_s8fp(bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize); - return BTLA_CODE::Success; -} -template -static inline BTLA_CODE decompress_kblock_bit2_packrow_fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int row, int col, - _ST* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, void* tmp, size_t tmpsize) { - auto unpack_elt = row * col; - decompress_kblock_s2_s8fp<_S2_T>(bit2ptr, dstptr, unpack_elt, reinterpret_cast(tmp), tmpsize); - // TODO(zhe): simd version - for (int i = 0; i < row; i++) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j++) { - float tmp = static_cast(dstptr[i * col + j]); - if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); - dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } } } + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } return BTLA_CODE::Success; } -inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const 
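The 5- and 6-bit kernels follow the same plane-combining idea: a 4-bit plane provides the low nibble, and a 1-bit or 2-bit plane is shifted up to bit 4 before re-centering around 2^(bits-1) (16 and 32, respectively, matching `FullRange`). A scalar sketch of the assumed decoding:

```cpp
#include <cstdint>

// Scalar sketch of the 5- and 6-bit reconstructions: the 4-bit plane holds the
// low nibble, the 1-bit (or 2-bit) plane supplies the high bits, and the code
// is re-centered around 2^(bits - 1).
inline int decode_s5_ref(uint8_t b4, uint8_t b1) { return (((b1 & 0x1) << 4) | (b4 & 0xF)) - 16; }
inline int decode_s6_ref(uint8_t b4, uint8_t b2) { return (((b2 & 0x3) << 4) | (b4 & 0xF)) - 32; }
```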
__m256 c2) { - const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2; - static const auto mask_exp = _mm256_set1_epi32(0x7f800000); - static const auto mask_not_exp = _mm256_set1_epi32(~0x7f800000); - - const auto y_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_exp); - const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp); - - const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23)); - return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp))); +static inline __m256i _mm256_dpbusd_avx2_epi32(__m256i& c, const __m256i& a, const __m256i& b) { + const __m256i dot2 = _mm256_maddubs_epi16(a, b); + const __m256i ones = _mm256_set1_epi16(1); + const __m256i sum4 = _mm256_madd_epi16(ones, dot2); + return _mm256_add_epi32(c, sum4); } -inline __m256 exp_ps_0_1(const __m256 x) { - static const auto c0 = _mm256_set1_ps(0.240226507f); - static const auto c1 = _mm256_set1_ps(0.452920674f); - static const auto c2 = _mm256_set1_ps(0.713483036f); - static const float v_log2e = std::log2(std::exp(1.f)); - static const auto log2e = _mm256_set1_ps(v_log2e); - static const auto half = _mm256_set1_ps(.5f); - - static const auto upper_bound = _mm256_set1_ps(88.722838); // log(max_positive_float) - static const auto lower_bound = _mm256_set1_ps(-87.336549); // log(min_positive_float) - __m256 x1 = _mm256_min_ps(x, upper_bound); - x1 = _mm256_max_ps(x1, lower_bound); +template +static inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m256i* iacc, + __m256* facc) { + __m256 v_a_scale[MTILE]; + for (int im = 0; im < MTILE; im++) { + v_a_scale[im] = _mm256_set1_ps(*(asptr + im * ldzp)); + } - x1 = _mm256_fmadd_ps(x1, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f); - const auto z = _mm256_floor_ps(x1); - const auto f = _mm256_sub_ps(x1, z); // auto f = x1 - z; + for (int i = 0; i < NReg; i++) { + __m256 v_b_scale = load_T_fp32(bsptr + i * 8); + for (int im = 0; im < MTILE; im++) { + auto vtmp = _mm256_mul_ps(v_a_scale[im], v_b_scale); + auto tmp = _mm256_cvtepi32_ps(iacc[im * NReg + i]); + facc[im * NReg + i] = _mm256_fmadd_ps(tmp, vtmp, facc[im * NReg + i]); + } + } +} - return poly_scale_2nd_ps(_mm256_cvtps_epi32(z), f, c0, c1, c2); +template +static inline void gemv_remove_zp(const uint8_t* azptr, int ldzp, __m256i* iacc, __m256i* bacc) { + if constexpr (MReg == 1) { + auto zp = int(azptr[0]); + __m256i v_a_zp = _mm256_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm256_mullo_epi32(v_a_zp, bacc[in]); + iacc[in] = _mm256_sub_epi32(iacc[in], vtmp); + } + } else { + __m256i v_a_zp[MReg]; + for (int im = 0; im < MReg; im++) { + auto zp = int(azptr[im * ldzp]); + v_a_zp[im] = _mm256_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm256_mullo_epi32(v_a_zp[im], bacc[in]); + iacc[im * NReg + in] = _mm256_sub_epi32(iacc[im * NReg + in], vtmp); + } + } + } } -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-attributes" // https://stackoverflow.com/a/49216021 -#endif -// Interleave 8 xmm vectors of words inplace -static inline std::array<__m128i, 8> tr_x8_word(std::array<__m128i, 8>& src) { // NOLINT [runtime/references] - std::array<__m128i, 8> dst; +template +static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; + auto& 
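`_mm256_dpbusd_avx2_epi32` emulates the AVX-VNNI dot-product instruction with `_mm256_maddubs_epi16` plus `_mm256_madd_epi16`, and `gemv_remove_zp` relies on the identity sum_k (a_k - zp_a) * w_k = sum_k a_k * w_k - zp_a * sum_k w_k, with the per-column weight sums collected in `bacc` by dotting the weights against a vector of ones. A scalar model of one accumulator lane (hypothetical `dpbusd_lane_ref`):

```cpp
#include <cstdint>

// Scalar model of one 32-bit lane of the dpbusd emulation: add the dot product
// of four unsigned-8 values with four signed-8 values. The intermediate
// _mm256_maddubs_epi16 step saturates at int16, which stays safe here because
// the re-centered low-bit weights keep each pair sum well inside that range.
inline int32_t dpbusd_lane_ref(int32_t c, const uint8_t a[4], const int8_t b[4]) {
  for (int i = 0; i < 4; ++i) c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
  return c;
}
```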
b4ptr = B.b4ptr; + auto& asptr = A.sptr; + auto& azptr = A.zpptr; - for (int i = 0; i < 8; i += 2) { - dst[i + 0] = _mm_unpacklo_epi16(src[i + 0], src[i + 1]); - dst[i + 1] = _mm_unpackhi_epi16(src[i + 0], src[i + 1]); - } - for (int i = 0; i < 8; i += 4) { - src[i + 0] = _mm_unpacklo_epi32(dst[i + 0], dst[i + 2]); - src[i + 1] = _mm_unpackhi_epi32(dst[i + 0], dst[i + 2]); - src[i + 2] = _mm_unpacklo_epi32(dst[i + 1], dst[i + 3]); - src[i + 3] = _mm_unpackhi_epi32(dst[i + 1], dst[i + 3]); + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); } - dst[0] = _mm_unpacklo_epi64(src[0], src[4]); - dst[1] = _mm_unpackhi_epi64(src[0], src[4]); - dst[2] = _mm_unpacklo_epi64(src[1], src[5]); - dst[3] = _mm_unpackhi_epi64(src[1], src[5]); - dst[4] = _mm_unpacklo_epi64(src[2], src[6]); - dst[5] = _mm_unpackhi_epi64(src[2], src[6]); - dst[6] = _mm_unpacklo_epi64(src[3], src[7]); - dst[7] = _mm_unpackhi_epi64(src[3], src[7]); - return dst; -} + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(8); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = 
kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } -template -inline std::array<__m128i, 8> load_fp32_fp16_tr_x8_word(const float* a, size_t lda) { - static_assert(tail > 0 && tail <= 8, "Unexpected tail value."); - std::array<__m128i, 8> dst; - for (int i = 0; i < tail; ++i) { - dst[i] = _mm256_cvtps_ph(_mm256_loadu_ps(a + i * lda), _MM_FROUND_TO_NEAREST_INT); + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } - for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128(); - return tr_x8_word(dst); -} -constexpr decltype(load_fp32_fp16_tr_x8_word<1>)* load_fp32_fp16_tr_x8_word_tbl[9]{ - load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<2>, - load_fp32_fp16_tr_x8_word<3>, load_fp32_fp16_tr_x8_word<4>, load_fp32_fp16_tr_x8_word<5>, - load_fp32_fp16_tr_x8_word<6>, load_fp32_fp16_tr_x8_word<7>, load_fp32_fp16_tr_x8_word<8>}; -template -inline std::array<__m128i, 8> load_maskz_fp32_fp16_tr_x8_word(const float* a, size_t lda, __m256i mask) { - static_assert(tail > 0 && tail <= 8, "Unexpected tail value."); - std::array<__m128i, 8> dst; - for (int i = 0; i < tail; ++i) { - dst[i] = _mm256_cvtps_ph(_mm256_maskload_ps(a + i * lda, mask), _MM_FROUND_TO_NEAREST_INT); + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } } - for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128(); - return tr_x8_word(dst); + return BTLA_CODE::Success; } -constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_x8_word_tbl[9]{ - load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<2>, - load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>, - load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>}; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); -template -static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc, __m256* vsca) { - if constexpr (MTILE == 1) { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m256 va = _mm256_set1_ps(*(Aptr + ikk)); - for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); - ftmp = _mm256_mul_ps(ftmp, vsca[i]); - vacc[i] = _mm256_fmadd_ps(va, ftmp, vacc[i]); - } + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 
0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(4); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); } - } else { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m256 va[MTILE]; + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); - ftmp = _mm256_mul_ps(ftmp, vsca[i]); - for (int im = 0; im < MTILE; im++) { - if (i == 0) { - va[im] = _mm256_set1_ps(*(Aptr + ikk + im * lda)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; } - vacc[im * NReg + i] = _mm256_fmadd_ps(va[im], ftmp, vacc[im * NReg + i]); } } - } - } -} + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); -template -static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc_loc) { - if constexpr (MTILE == 1) { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m256 va = _mm256_set1_ps(*(Aptr + ikk)); - for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); - vacc_loc[i] = 
_mm256_fmadd_ps(va, ftmp, vacc_loc[i]); - } - } - } else { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m256 va[MTILE]; - for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); - for (int im = 0; im < MTILE; im++) { - if (i == 0) { - va[im] = _mm256_set1_ps(*(Aptr + ikk + im * lda)); + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; } - vacc_loc[im * NReg + i] = _mm256_fmadd_ps(va[im], ftmp, vacc_loc[im * NReg + i]); } } } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } } + return BTLA_CODE::Success; } template -static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_5bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto& b4ptr = B.b4ptr; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int blks = k / blocksize; int constexpr NReg = NTILE / 8; int constexpr MReg = MTILE; - // Initialize accumulator with zeros __m256 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } + + int constexpr FullRange = 1 << (5 - 1); uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - auto vbias = _mm256_set1_epi8(8); + auto vbias = _mm256_set1_epi8(FullRange); + + const __m256i onesu8 = _mm256_set1_epi8(1); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - auto bsptr = B.sptr + ib * B.ldzp; - __m256 v_b_scale[NReg]; + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } for (int i = 0; i < NReg; i++) { - v_b_scale[i] = load_T_fp32(bsptr + i * 8); + bacc[i] = _mm256_setzero_si256(); } - - int constexpr Unroll = 4; - assert((blocksize % 4) == 0); - assert(tmpsize >= NTILE * Unroll); - if (B.zpptr) { __m256i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; - - for (int i = 0; i < Unroll; i++) { - memcpy(tmp + i * NTILE, bzptr, NTILE); - } for (int i = 0; i < NReg; i++) { - bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); bzp[i] = _mm256_add_epi8(bzp[i], vbias); } 
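// --- Illustrative sketch (placeholder names, not BesTLA APIs) -----------------
// The block above folds the per-block weight zero point and the unsigned-storage
// bias (FullRange = 1 << (bits - 1), e.g. 16 for the 5-bit kernel) into one epi8
// vector, so the inner loop can recenter a decoded weight with a single
// _mm256_sub_epi8. A minimal scalar model of that recentering:
#include <cstdint>
inline int8_t recenter_weight(uint8_t stored_code, int8_t block_zp, int bits) {
  const int full_range = 1 << (bits - 1);  // storage bias added at pack time
  // stored_code is the unsigned code reassembled from the bit planes (0 .. 2^bits-1);
  // subtracting (block_zp + full_range) gives the signed, zero-point-corrected weight.
  return static_cast<int8_t>(int(stored_code) - (int(block_zp) + full_range));
}
// ------------------------------------------------------------------------------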
- for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); - vb = _mm256_sub_epi8(vb, bzp[i]); - _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); } - } else { - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); - vb = _mm256_sub_epi8(vb, vbias); - _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); } } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } for (int j = 0; j < MReg; j++) { @@ -2450,79 +3821,124 @@ static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils } template -static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const 
utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_6bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = (utils::bit2x4*)B.b2ptr; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b2ptr = reinterpret_cast(B.b2ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; int constexpr MReg = MTILE; - // Initialize accumulator with zeros __m256 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + + const __m256i onesu8 = _mm256_set1_epi8(1); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); - auto vbias = _mm256_set1_epi8(2); - - int constexpr KTILE = 1; + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - auto bsptr = B.sptr + ib * B.ldzp; - - __m256 acc_loc[NReg * MReg]; + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; for (int i = 0; i < NReg * MReg; i++) { - acc_loc[i] = _mm256_setzero_ps(); + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); } - int constexpr Unroll = 4; - assert((blocksize % 4) == 0); - assert(tmpsize >= NTILE * Unroll); - if (B.zpptr) { __m256i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; - for (int i = 0; i < Unroll; i++) { - memcpy(tmp + i * NTILE, bzptr, NTILE); - } for (int i = 0; i < NReg; i++) { - bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - vb = _mm256_sub_epi8(vb, bzp[i]); - _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); - b2ptr += 8 * Unroll / 4; + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b4ptr += 8 * KTILE / 2; + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + vb = _mm256_or_si256(vb, vb1); + vb = 
_mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b4ptr += 8 * KTILE / 2; + b2ptr += 8 * KTILE / 4; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - } else { - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - vb = _mm256_sub_epi8(vb, vbias); - _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); - b2ptr += 8 * Unroll / 4; + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b4ptr += 8 * KTILE / 2; + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b4ptr += 8 * KTILE / 2; + b2ptr += 8 * KTILE / 4; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } } - __m256 v_b_scale[NReg]; - for (int i = 0; i < NReg; i++) { - v_b_scale[i] = load_T_fp32(bsptr + i * 8); - } - for (int im = 0; im < MTILE; im++) { - for (int in = 0; in < NReg; in++) { - acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); - } - } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } for (int j = 0; j < MReg; j++) { @@ -2534,91 +3950,101 @@ static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils } template -static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = (utils::bit2x4*)B.b2ptr; - auto b1ptr = (utils::bit1x8*)B.b1ptr; + auto b2ptr = reinterpret_cast(B.b2ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; int constexpr MReg = MTILE; - // Initialize accumulator with zeros __m256 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); - auto vbias = 
_mm256_set1_epi8(4); - - const __m256i highMask = _mm256_set1_epi8(0x04); - const __m256i bit1Mask = _mm256_set1_epi32(0x0F); - const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); - const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); - int constexpr KTILE = 1; + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(2); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - auto bsptr = B.sptr + ib * B.ldzp; - - __m256 acc_loc[NReg * MReg]; + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; for (int i = 0; i < NReg * MReg; i++) { - acc_loc[i] = _mm256_setzero_ps(); + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); } - int constexpr Unroll = 4; - assert((blocksize % 4) == 0); - assert(tmpsize >= NTILE * Unroll); - if (B.zpptr) { __m256i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; - for (int i = 0; i < Unroll; i++) { - memcpy(tmp + i * NTILE, bzptr, NTILE); - } for (int i = 0; i < NReg; i++) { - bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vb = _mm256_or_si256(vb, vb1); - vb = _mm256_sub_epi8(vb, bzp[i]); - _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); - b2ptr += 8 * Unroll / 4; - b1ptr += 8 * Unroll / 8; + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - } else { - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vb = _mm256_or_si256(vb, vb1); - vb = _mm256_sub_epi8(vb, vbias); - _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); - b2ptr += 8 * Unroll / 4; - b1ptr += 8 * Unroll / 8; + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = 
_mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } } - __m256 v_b_scale[NReg]; - for (int i = 0; i < NReg; i++) { - v_b_scale[i] = load_T_fp32(bsptr + i * 8); - } - for (int im = 0; im < MTILE; im++) { - for (int in = 0; in < NReg; in++) { - acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); - } - } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } for (int j = 0; j < MReg; j++) { @@ -2629,52 +4055,13 @@ static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils return BTLA_CODE::Success; } -static inline __m256i _mm256_dpbusd_avx2_epi32(__m256i& c, const __m256i& a, const __m256i& b) { - const __m256i dot2 = _mm256_maddubs_epi16(a, b); - const __m256i ones = _mm256_set1_epi16(1); - const __m256i sum4 = _mm256_madd_epi16(ones, dot2); - return _mm256_add_epi32(c, sum4); -} - -template -static inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m256i* iacc, - __m256* facc) { - __m256 v_a_scale[MTILE]; - for (int im = 0; im < MTILE; im++) { - v_a_scale[im] = _mm256_set1_ps(*(asptr + im * ldzp)); - } - - for (int i = 0; i < NReg; i++) { - __m256 v_b_scale = load_T_fp32(bsptr + i * 8); - for (int im = 0; im < MTILE; im++) { - auto vtmp = _mm256_mul_ps(v_a_scale[im], v_b_scale); - auto tmp = _mm256_cvtepi32_ps(iacc[im * NReg + i]); - facc[im * NReg + i] = _mm256_fmadd_ps(tmp, vtmp, facc[im * NReg + i]); - } - } -} +namespace vnni { -template -static inline void gemv_remove_zp(const uint8_t* azptr, int ldzp, __m256i* iacc, __m256i* bacc) { - if constexpr (MReg == 1) { - auto zp = int(azptr[0]); - __m256i v_a_zp = _mm256_set1_epi32(zp); - for (int in = 0; in < NReg; in++) { - auto vtmp = _mm256_mullo_epi32(v_a_zp, bacc[in]); - iacc[in] = _mm256_sub_epi32(iacc[in], vtmp); - } - } else { - __m256i v_a_zp[MReg]; - for (int im = 0; im < MReg; im++) { - auto zp = int(azptr[im * ldzp]); - v_a_zp[im] = _mm256_set1_epi32(zp); - for (int in = 0; in < NReg; in++) { - auto vtmp = _mm256_mullo_epi32(v_a_zp[im], bacc[in]); - iacc[im * NReg + in] = _mm256_sub_epi32(iacc[im * NReg + in], vtmp); - } - } - } -} +#if CompileAVXVNNI() +#ifdef __GNUC__ +#pragma GCC push_options +#pragma GCC target("avxvnni") +#endif template static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, @@ -2718,12 +4105,11 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int ik = 0; ik < blocksize; ik += 4) { if constexpr (MTILE == 1) { __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); - for (int i = 0; i < NReg; i++) { auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); vb = _mm256_sub_epi8(vb, bzp[i]); 
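// --- Illustrative sketch (placeholder names, not BesTLA APIs) -----------------
// This hunk switches the inner product from the AVX2 helper (_mm256_maddubs_epi16
// followed by _mm256_madd_epi16) to the native AVX-VNNI instruction. The two are
// equivalent except that the AVX2 path saturates its intermediate 16-bit pair
// sums, while dpbusd accumulates the u8*s8 products directly into 32-bit lanes.
#include <immintrin.h>

static inline __m256i dot_u8s8_avx2(__m256i acc, __m256i a_u8, __m256i b_s8) {
  const __m256i dot2 = _mm256_maddubs_epi16(a_u8, b_s8);               // pairwise u8*s8 -> s16 (saturating)
  const __m256i sum4 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot2);  // adjacent s16 pairs -> s32
  return _mm256_add_epi32(acc, sum4);
}

#if defined(__AVXVNNI__)
static inline __m256i dot_u8s8_vnni(__m256i acc, __m256i a_u8, __m256i b_s8) {
  return _mm256_dpbusd_avx_epi32(acc, a_u8, b_s8);  // 4-way u8*s8 dot product added to acc
}
#endif
// ------------------------------------------------------------------------------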
- bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); } } else { __m256i va[MReg]; @@ -2733,9 +4119,9 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg; i++) { auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); vb = _mm256_sub_epi8(vb, bzp[i]); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { - iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } } } @@ -2747,8 +4133,8 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg; i++) { auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); vb = _mm256_sub_epi8(vb, vbias); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); } } else { __m256i va[MReg]; @@ -2758,16 +4144,16 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg; i++) { auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); vb = _mm256_sub_epi8(vb, vbias); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { - iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } } } } } - gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } @@ -2780,10 +4166,11 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = reinterpret_cast(B.b2ptr); - auto b1ptr = reinterpret_cast(B.b1ptr); + auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; int blks = k / blocksize; int constexpr NReg = NTILE / 8; @@ -2792,255 +4179,337 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); - const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(4); + const __m256i vbias = _mm256_set1_epi8(8); + uint32_t mask = 0x0f0f0f0f; + auto vmask = 
_mm256_set1_epi32(*reinterpret_cast(&mask)); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + } + } + } + + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, + float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, + size_t tmpsize) { + auto a8ptr = A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + auto asptr = A.sptr; + auto azptr = A.zpptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + // Initialize accumulator with zeros + __m256 acc[NReg]; + int constexpr EltPadding = 128; + static_assert(NTILE % 8 == 0); + int constexpr KTILE = 4; + int constexpr UnpackElt = EltPadding / 8 / KTILE; + int constexpr TotalElt = UnpackElt * NTILE * KTILE; + int constexpr Loop128 = TotalElt / 128; + int8_t UnpackBuf[TotalElt]; + for (int i = 0; i < NReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i lowMask = _mm256_set1_epi8(0x03); const __m256i highMask = _mm256_set1_epi8(0x04); const __m256i bit1Mask = _mm256_set1_epi32(0x0F); const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); - int constexpr KTILE = 4; + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + __m256i 
bit2_data = _mm256_loadu_si256((const __m256i*)src1); + int32_t* bit1_ptr = reinterpret_cast(src2); + for (int i = 0; i < 4; i++) { + auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); + bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); + bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); + bit1x32 = _mm256_and_si256(highMask, bit1x32); + + auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); + auto res = _mm256_add_epi8(bit1x32, bit2x32); + res = _mm256_sub_epi8(res, highMask); + _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); + } + }; + assert(azptr); for (int ib = 0; ib < blks; ib += 1) { - __m256i iacc[NReg * MReg]; + __m256i iacc[NReg]; __m256i bacc[NReg]; - for (int i = 0; i < NReg * MReg; i++) { - iacc[i] = _mm256_setzero_si256(); - } for (int i = 0; i < NReg; i++) { + iacc[i] = _mm256_setzero_si256(); bacc[i] = _mm256_setzero_si256(); } if (B.zpptr) { __m256i bzp[NReg]; - auto bzptr = B.zpptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * ld_scaleb; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); - bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += KTILE) { - if constexpr (MTILE == 1) { - __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vb = _mm256_or_si256(vb, vb1); - vb = _mm256_sub_epi8(vb, bzp[i]); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); - b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; - } - } else { - __m256i va[MReg]; - for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); - } - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vb = _mm256_or_si256(vb, vb1); + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + for (int i = 0; i < NReg; i++) { + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); vb = _mm256_sub_epi8(vb, bzp[i]); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - for (int j = 0; j < MReg; j++) { - iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); - } - b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); } } + a8ptr += KTILE * UnpackElt; } } else { - for (int ik = 0; ik < blocksize; ik += KTILE) { - if constexpr (MTILE == 1) { - __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vb = _mm256_or_si256(vb, vb1); - vb = _mm256_sub_epi8(vb, vbias); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); - - b2ptr 
+= 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; - } - } else { - __m256i va[MReg]; - for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); - } + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); - vb = _mm256_or_si256(vb, vb1); - vb = _mm256_sub_epi8(vb, vbias); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - for (int j = 0; j < MReg; j++) { - iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); - } - b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); } } + a8ptr += KTILE * UnpackElt; } } - gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); - gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); - } - - for (int j = 0; j < MReg; j++) { + const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib)); + auto zp = int(azptr[ib]); + const __m256i v_a_zp = _mm256_set1_epi32(zp); + auto bsptr = B.sptr + ib * ld_scaleb; for (int i = 0; i < NReg; i++) { - _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + bacc[i] = _mm256_mullo_epi32(v_a_zp, bacc[i]); + iacc[i] = _mm256_sub_epi32(iacc[i], bacc[i]); + __m256 v_b_scale; + if constexpr (std::is_same_v) { + v_b_scale = _mm256_loadu_ps(bsptr + i * 8); + } else if constexpr (std::is_same_v) { + auto tmp = _mm_loadu_si128((const __m128i*)(bsptr + i * 8)); + v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp); + } + v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale); + auto tmp = _mm256_cvtepi32_ps(iacc[i]); + acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]); } } + + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8, acc[i]); + } return BTLA_CODE::Success; } -template -static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, - int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, + float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, + size_t tmpsize) { + auto a8ptr = A.aptr; auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + auto asptr = A.sptr; + auto azptr = A.zpptr; int blks = k / blocksize; int constexpr NReg = NTILE / 8; - int constexpr MReg = MTILE; - __m256 acc[NReg * MReg]; - for (int i = 0; i < NReg * MReg; i++) { + // Initialize accumulator with zeros + __m256 acc[NReg]; + int constexpr EltPadding = 128; + static_assert(NTILE % 8 == 0); + int constexpr KTILE = 4; + int constexpr UnpackElt = EltPadding / 8 / KTILE; + int constexpr TotalElt = UnpackElt * NTILE * KTILE; + int constexpr Loop128 = TotalElt / 128; + int8_t UnpackBuf[TotalElt]; + for (int i = 0; i < NReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - 
auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); - const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(2); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - int constexpr KTILE = 4; - for (int ib = 0; ib < blks; ib += 1) { - __m256i iacc[NReg * MReg]; - __m256i bacc[NReg]; - for (int i = 0; i < NReg * MReg; i++) { - iacc[i] = _mm256_setzero_si256(); + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); + int32_t* bit1_ptr = reinterpret_cast(src2); + for (int i = 0; i < 4; i++) { + auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); + bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); + bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); + bit1x32 = _mm256_and_si256(highMask, bit1x32); + + auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); + auto res = _mm256_add_epi8(bit1x32, bit2x32); + res = _mm256_slli_epi32(res, 5); + _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); } + }; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg]; for (int i = 0; i < NReg; i++) { - bacc[i] = _mm256_setzero_si256(); + iacc[i] = _mm256_setzero_si256(); } if (B.zpptr) { __m256i bzp[NReg]; - auto bzptr = B.zpptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * ld_scaleb; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); - bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += KTILE) { - if constexpr (MTILE == 1) { - __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - vb = _mm256_sub_epi8(vb, bzp[i]); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); - b2ptr += 8 * KTILE / 4; - } - } else { - __m256i va[MReg]; - for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); - } + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + auto vabsa = _mm256_sign_epi8(va, va); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); vb = _mm256_sub_epi8(vb, bzp[i]); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - for (int j = 0; j < MReg; j++) { - 
iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); - } - b2ptr += 8 * KTILE / 4; + vb = _mm256_sign_epi8(vb, va); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); } } + a8ptr += KTILE * UnpackElt; } } else { - for (int ik = 0; ik < blocksize; ik += KTILE) { - if constexpr (MTILE == 1) { - __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - vb = _mm256_sub_epi8(vb, vbias); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); - b2ptr += 8 * KTILE / 4; - } - } else { - __m256i va[MReg]; - for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); - } + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + auto vabsa = _mm256_sign_epi8(va, va); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - vb = _mm256_sub_epi8(vb, vbias); - bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); - for (int j = 0; j < MReg; j++) { - iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); - } - b2ptr += 8 * KTILE / 4; + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + vb = _mm256_sign_epi8(vb, va); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); } } + a8ptr += KTILE * UnpackElt; } } - gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); - gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); - } - - for (int j = 0; j < MReg; j++) { + const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib)); + auto bsptr = B.sptr + ib * ld_scaleb; for (int i = 0; i < NReg; i++) { - _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + __m256 v_b_scale; + if constexpr (std::is_same_v) { + v_b_scale = _mm256_loadu_ps(bsptr + i * 8); + } else if constexpr (std::is_same_v) { + auto tmp = _mm_loadu_si128((const __m128i*)(bsptr + i * 8)); + v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp); + } + v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale); + auto tmp = _mm256_cvtepi32_ps(iacc[i]); + acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]); } } + + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8, acc[i]); + } return BTLA_CODE::Success; } -namespace vnni { - -#if CompileAVXVNNI() -#ifdef __GNUC__ -#pragma GCC push_options -#pragma GCC target("avxvnni") -#endif - template -static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto& a8ptr = A.aptr; - auto& b4ptr = B.b4ptr; - auto& asptr = A.sptr; - auto& azptr = A.zpptr; + auto b2ptr = reinterpret_cast(B.b2ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; int constexpr MReg = MTILE; - // Initialize accumulator with zeros __m256 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + uint64_t mask0 = 
0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(8); + const __m256i vbias = _mm256_set1_epi8(2); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m256i iacc[NReg * MReg]; __m256i bacc[NReg]; @@ -3057,52 +4526,56 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { if constexpr (MTILE == 1) { - __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; } } else { __m256i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } + b2ptr += 8 * KTILE / 4; } } } } else { - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { if constexpr (MTILE == 1) { - __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; } } else { __m256i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; 
j++) { iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } + b2ptr += 8 * KTILE / 4; } } } @@ -3121,11 +4594,9 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto& a8ptr = A.aptr; - auto& b4ptr = B.b4ptr; - auto& asptr = A.sptr; + auto b2ptr = reinterpret_cast(B.b2ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; @@ -3134,56 +4605,64 @@ static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - const __m256i vbias = _mm256_set1_epi8(8); - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(2); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m256i iacc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { iacc[i] = _mm256_setzero_si256(); } + if (B.zpptr) { __m256i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); - bzp[i] = _mm256_add_epi8(bzp[i], vbias); + bzp[i] = _mm256_add_epi8(vbias, bzp[i]); } - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { __m256i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sub_epi8(vb, bzp[i]); for (int j = 0; j < MReg; j++) { auto vsb = _mm256_sign_epi8(vb, va[j]); auto vabsa = _mm256_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } + b2ptr += 8 * KTILE / 4; } } } else { - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { __m256i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sub_epi8(vb, vbias); for (int j = 0; j < MReg; j++) { auto vsb = _mm256_sign_epi8(vb, va[j]); auto vabsa = _mm256_sign_epi8(va[j], va[j]); 
iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } + b2ptr += 8 * KTILE / 4; } } } - gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } @@ -3195,257 +4674,229 @@ static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const ut return BTLA_CODE::Success; } -template -static inline BTLA_CODE gemv_3bit_u8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, - float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, - size_t tmpsize) { - auto a8ptr = A.aptr; +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { auto b2ptr = reinterpret_cast(B.b2ptr); auto b1ptr = reinterpret_cast(B.b1ptr); - auto asptr = A.sptr; - auto azptr = A.zpptr; int blks = k / blocksize; int constexpr NReg = NTILE / 8; - // Initialize accumulator with zeros - __m256 acc[NReg]; - int constexpr EltPadding = 128; - static_assert(NTILE % 8 == 0); - int constexpr KTILE = 4; - int constexpr UnpackElt = EltPadding / 8 / KTILE; - int constexpr TotalElt = UnpackElt * NTILE * KTILE; - int constexpr Loop128 = TotalElt / 128; - int8_t UnpackBuf[TotalElt]; - for (int i = 0; i < NReg; i++) { + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i vbias = _mm256_set1_epi8(4); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + const __m256i highMask = _mm256_set1_epi8(0x04); const __m256i bit1Mask = _mm256_set1_epi32(0x0F); const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); - const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, - 4, 4, 4, 0, 0, 0, 0); - auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { - __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); - int32_t* bit1_ptr = reinterpret_cast(src2); - for (int i = 0; i < 4; i++) { - auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); - bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); - bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); - bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); - bit1x32 = _mm256_and_si256(highMask, bit1x32); - - auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); - auto res = _mm256_add_epi8(bit1x32, bit2x32); - res = _mm256_sub_epi8(res, highMask); - _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); - } - }; - assert(azptr); + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - __m256i iacc[NReg]; + __m256i iacc[NReg * MReg]; __m256i bacc[NReg]; - for (int i = 0; i < NReg; i++) { 
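// --- Illustrative sketch (placeholder names, not BesTLA APIs) -----------------
// The 3-bit kernels here keep each weight split across two planes: its low two
// bits in a bit2x4 plane and its top bit in a bit1x8 plane; unpack_2bits /
// unpack_1bits rebuild the code before the bias and zero point are subtracted.
// The same split-plane scheme gives 5-bit (4+1 planes) and 6-bit (4+2 planes).
// Scalar model of one 3-bit weight:
#include <cstdint>

inline uint8_t split_low2(uint8_t code3) { return code3 & 0x3; }         // -> 2-bit plane
inline uint8_t split_top1(uint8_t code3) { return (code3 >> 2) & 0x1; }  // -> 1-bit plane

inline int8_t decode3(uint8_t low2, uint8_t top1, int8_t block_zp) {
  const int full_range = 1 << (3 - 1);       // matches vbias = 4 in the kernels above
  int code = (int(top1) << 2) | int(low2);   // reassembled 0..7 code
  return static_cast<int8_t>(code - full_range - int(block_zp));  // signed weight
}
// ------------------------------------------------------------------------------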
+ for (int i = 0; i < NReg * MReg; i++) { iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { bacc[i] = _mm256_setzero_si256(); } if (B.zpptr) { __m256i bzp[NReg]; - auto bzptr = B.zpptr + ib * ld_scaleb; + auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { - for (int il = 0; il < Loop128; il++) { - bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); - b2ptr += 128 / 4; - b1ptr += 128 / 8; - } - for (int iu = 0; iu < UnpackElt; iu++) { - auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; } } - a8ptr += KTILE * UnpackElt; } } else { - for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { - for (int il = 0; il < Loop128; il++) { - bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); - b2ptr += 128 / 4; - b1ptr += 128 / 8; - } - for (int iu = 0; iu < UnpackElt; iu++) { - auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; 
j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; } } - a8ptr += KTILE * UnpackElt; } } - const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib)); - auto zp = int(azptr[ib]); - const __m256i v_a_zp = _mm256_set1_epi32(zp); - auto bsptr = B.sptr + ib * ld_scaleb; - for (int i = 0; i < NReg; i++) { - bacc[i] = _mm256_mullo_epi32(v_a_zp, bacc[i]); - iacc[i] = _mm256_sub_epi32(iacc[i], bacc[i]); - __m256 v_b_scale; - if constexpr (std::is_same_v) { - v_b_scale = _mm256_loadu_ps(bsptr + i * 8); - } else if constexpr (std::is_same_v) { - auto tmp = _mm_loadu_si128((const __m128i*)(bsptr + i * 8)); - v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp); - } - v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale); - auto tmp = _mm256_cvtepi32_ps(iacc[i]); - acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]); - } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } - for (int i = 0; i < NReg; i++) { - _mm256_storeu_ps(C + i * 8, acc[i]); + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } } return BTLA_CODE::Success; } -template -static inline BTLA_CODE gemv_3bit_s8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, - float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, - size_t tmpsize) { - auto a8ptr = A.aptr; - auto b2ptr = reinterpret_cast(B.b2ptr); - auto b1ptr = reinterpret_cast(B.b1ptr); - auto asptr = A.sptr; - auto azptr = A.zpptr; - - int blks = k / blocksize; - int constexpr NReg = NTILE / 8; - // Initialize accumulator with zeros - __m256 acc[NReg]; - int constexpr EltPadding = 128; - static_assert(NTILE % 8 == 0); - int constexpr KTILE = 4; - int constexpr UnpackElt = EltPadding / 8 / KTILE; - int constexpr TotalElt = UnpackElt * NTILE * KTILE; - int constexpr Loop128 = TotalElt / 128; - int8_t UnpackBuf[TotalElt]; - for (int i = 0; i < NReg; i++) { +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - const __m256i lowMask = _mm256_set1_epi8(0x03); + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(4); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + const __m256i highMask = _mm256_set1_epi8(0x04); const __m256i bit1Mask = _mm256_set1_epi32(0x0F); const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) 
+ (1 << 16) + (1 << 9) + (1 << 2)); - const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, - 4, 4, 4, 0, 0, 0, 0); - auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { - __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); - int32_t* bit1_ptr = reinterpret_cast(src2); - for (int i = 0; i < 4; i++) { - auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); - bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); - bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); - bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); - bit1x32 = _mm256_and_si256(highMask, bit1x32); - - auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); - auto res = _mm256_add_epi8(bit1x32, bit2x32); - res = _mm256_slli_epi32(res, 5); - _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); - } - }; + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - __m256i iacc[NReg]; - for (int i = 0; i < NReg; i++) { + __m256i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { iacc[i] = _mm256_setzero_si256(); } if (B.zpptr) { __m256i bzp[NReg]; - auto bzptr = B.zpptr + ib * ld_scaleb; + auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { - for (int il = 0; il < Loop128; il++) { - bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); - b2ptr += 128 / 4; - b1ptr += 128 / 8; + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } - for (int iu = 0; iu < UnpackElt; iu++) { - auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); - auto vabsa = _mm256_sign_epi8(va, va); - for (int i = 0; i < NReg; i++) { - auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); - vb = _mm256_sub_epi8(vb, bzp[i]); - vb = _mm256_sign_epi8(vb, va); - iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; } - a8ptr += KTILE * UnpackElt; } } else { - for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { - for (int il = 0; il < Loop128; il++) { - bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); - b2ptr += 128 / 4; - b1ptr += 128 / 8; + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } - for (int iu = 0; iu < UnpackElt; iu++) { - auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); - auto vabsa = _mm256_sign_epi8(va, va); - for (int i = 0; i < NReg; i++) { - auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); - vb = _mm256_sign_epi8(vb, va); - iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); + for (int i = 0; i < NReg; i++) { 
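The `gemv_*_s8s8_fp32` kernels feed signed activations through the same unsigned-by-signed dot product by pairing `_mm256_sign_epi8(va, va)` with `_mm256_sign_epi8(vb, va)`, i.e. moving the sign of `a` onto `b`. A scalar sketch of why that is equivalent follows; names are illustrative. The vector form could wrap only if a weight byte were -128, which the sub-byte weight ranges here cannot produce.

```cpp
// Scalar sketch of the sign/abs trick used by the s8s8 GEMV kernels:
// a u8 x s8 dot product (dpbusd-style) can still consume signed activations by
// computing a*b = |a| * (sgn(a) applied to b).
#include <cstdint>
#include <cstdio>
#include <cstdlib>

int main() {
  int8_t a[4] = {-7, 0, 100, -128};  // signed activations
  int8_t b[4] = {5, -9, -3, 2};      // decoded signed weights

  int trick = 0, reference = 0;
  for (int k = 0; k < 4; ++k) {
    uint8_t abs_a = uint8_t(std::abs(int(a[k])));             // |a| fits in u8 (128 for -128)
    int signed_b = a[k] > 0 ? b[k] : (a[k] < 0 ? -b[k] : 0);  // sgn(a) applied to b
    trick += int(abs_a) * signed_b;                           // u8 x s8 accumulate
    reference += int(a[k]) * int(b[k]);
  }
  printf("trick=%d reference=%d\n", trick, reference);
  return trick == reference ? 0 : 1;
}
```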
+ auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; } - a8ptr += KTILE * UnpackElt; } } - const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib)); - auto bsptr = B.sptr + ib * ld_scaleb; - for (int i = 0; i < NReg; i++) { - __m256 v_b_scale; - if constexpr (std::is_same_v) { - v_b_scale = _mm256_loadu_ps(bsptr + i * 8); - } else if constexpr (std::is_same_v) { - auto tmp = _mm_loadu_si128((const __m128i*)(bsptr + i * 8)); - v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp); - } - v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale); - auto tmp = _mm256_cvtepi32_ps(iacc[i]); - acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]); - } + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } - for (int i = 0; i < NReg; i++) { - _mm256_storeu_ps(C + i * 8, acc[i]); + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } } return BTLA_CODE::Success; } template -static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_5bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = reinterpret_cast(B.b2ptr); + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; @@ -3454,16 +4905,18 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(2); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m256i iacc[NReg * MReg]; @@ -3485,11 +4938,15 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = 
unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); - b2ptr += 8 * KTILE / 4; + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; } } else { __m256i va[MReg]; @@ -3497,13 +4954,17 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } - b2ptr += 8 * KTILE / 4; + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; } } } @@ -3512,11 +4973,16 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); - b2ptr += 8 * KTILE / 4; + + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; } } else { __m256i va[MReg]; @@ -3524,13 +4990,17 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } - b2ptr += 8 * KTILE / 4; + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; } } } @@ -3549,9 +5019,10 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_5bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = reinterpret_cast(B.b2ptr); + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; @@ -3560,29 +5031,30 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = 
_mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(2); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m256i iacc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { iacc[i] = _mm256_setzero_si256(); } - if (B.zpptr) { __m256i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); - bzp[i] = _mm256_add_epi8(vbias, bzp[i]); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); } for (int ik = 0; ik < blocksize; ik += KTILE) { __m256i va[MReg]; @@ -3590,14 +5062,18 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); for (int j = 0; j < MReg; j++) { auto vsb = _mm256_sign_epi8(vb, va[j]); auto vabsa = _mm256_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } - b2ptr += 8 * KTILE / 4; + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; } } } else { @@ -3607,17 +5083,22 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb1 = _mm256_slli_epi32(vb1, 2); + vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, vbias); for (int j = 0; j < MReg; j++) { auto vsb = _mm256_sign_epi8(vb, va[j]); auto vabsa = _mm256_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } - b2ptr += 8 * KTILE / 4; + b4ptr += 8 * KTILE / 2; + b1ptr += 8 * KTILE / 8; } } } + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } @@ -3630,10 +5111,10 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_6bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* 
C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = reinterpret_cast(B.b4ptr); auto b2ptr = reinterpret_cast(B.b2ptr); - auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; @@ -3642,21 +5123,22 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); + const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(4); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - const __m256i highMask = _mm256_set1_epi8(0x04); - const __m256i bit1Mask = _mm256_set1_epi32(0x0F); - const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); - const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m256i iacc[NReg * MReg]; @@ -3678,14 +5160,15 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b4ptr += 8 * KTILE / 2; b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; } } else { __m256i va[MReg]; @@ -3693,16 +5176,17 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); 
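The 3-, 5- and 6-bit paths all store a weight across two bit planes: a dense low plane (2 or 4 bits) plus a 1- or 2-bit high plane that is shifted and ORed back on, followed by subtracting a bias of 2^(bits-1) (the `FullRange` constant). A scalar sketch of the round trip, assuming a plain one-element-per-index layout rather than the interleaved layout the SIMD shuffles expect:

```cpp
// Scalar sketch of the two-plane storage used for 3/5/6-bit weights:
//   3-bit = 2-bit low plane + 1-bit high plane (bias 4)
//   5-bit = 4-bit low plane + 1-bit high plane (bias 16)
//   6-bit = 4-bit low plane + 2-bit high plane (bias 32)
// The SIMD kernels unpack each plane, OR the shifted high plane onto the low
// plane and subtract the bias.
#include <cstdint>
#include <cstdio>

// Decode one element from separate low/high planes.
int8_t decode(uint8_t lo, uint8_t hi, int low_bits, int total_bits) {
  int unsigned_val = int(lo) | (int(hi) << low_bits);    // recombine the planes
  return int8_t(unsigned_val - (1 << (total_bits - 1))); // remove the bias
}

int main() {
  // Example: a 5-bit weight w in [-16, 15] stored as 4-bit low + 1-bit high.
  for (int w = -16; w <= 15; ++w) {
    int biased = w + 16;             // [0, 31], what the packer stores
    uint8_t lo = biased & 0x0f;      // 4-bit plane
    uint8_t hi = (biased >> 4) & 1;  // 1-bit plane
    if (decode(lo, hi, 4, 5) != w) {
      printf("mismatch at %d\n", w);
      return 1;
    }
  }
  printf("5-bit round trip ok\n");
  return 0;
}
```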
} + b4ptr += 8 * KTILE / 2; b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; } } } @@ -3711,15 +5195,15 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); - + b4ptr += 8 * KTILE / 2; b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; } } else { __m256i va[MReg]; @@ -3727,16 +5211,17 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, vbias); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); } + b4ptr += 8 * KTILE / 2; b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; } } } @@ -3755,10 +5240,10 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_6bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = reinterpret_cast(B.b4ptr); auto b2ptr = reinterpret_cast(B.b2ptr); - auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr NReg = NTILE / 8; @@ -3767,21 +5252,20 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm256_setzero_ps(); } - uint64_t mask0 = 0x0303030303030303; - auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); - auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(FullRange); const __m256i onesu8 = _mm256_set1_epi8(1); - const __m256i vbias = _mm256_set1_epi8(4); const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - const __m256i highMask = _mm256_set1_epi8(0x04); - const __m256i bit1Mask = _mm256_set1_epi32(0x0F); - const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); - const __m256i bit1Shift_2 = 
_mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + uint32_t mask0 = 0x03030303; + auto vmask0 = _mm256_set1_epi32(*(int32_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m256i iacc[NReg * MReg]; @@ -3801,8 +5285,9 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, bzp[i]); for (int j = 0; j < MReg; j++) { @@ -3810,8 +5295,8 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut auto vabsa = _mm256_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } + b4ptr += 8 * KTILE / 2; b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; } } } else { @@ -3821,8 +5306,9 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm256_slli_epi32(vb1, 4); vb = _mm256_or_si256(vb, vb1); vb = _mm256_sub_epi8(vb, vbias); for (int j = 0; j < MReg; j++) { @@ -3830,8 +5316,8 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut auto vabsa = _mm256_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); } + b4ptr += 8 * KTILE / 2; b2ptr += 8 * KTILE / 4; - b1ptr += 8 * KTILE / 8; } } } diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 41c82bbcf..95423f27b 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -3113,457 +3113,1604 @@ static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::b return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, BTLA_DTYPE sdtype, - int8_t* zero_points, int k_offset, int n_offset, int blocksize, int ldzp, - int8_t* tmp, size_t tmpsize) { +static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, + size_t unpack_elt, int8_t* tmp, size_t tmpsize) { + int constexpr VBits = 512; + int constexpr VElt = VBits / 8; + int i = 0; + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = 
unpack_4bits(bit4ptr + i / 2, vmask); + auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vout = _mm512_or_si512(vout, vb1); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= VElt) { + i = unpack_elt - VElt; + auto vout = unpack_4bits(bit4ptr + i / 2, vmask); + auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vout = _mm512_or_si512(vout, vb1); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } else { + ref::decompress_s5_s8(bit4ptr + i / 2, bit1ptr + i / 8, dstptr + i, unpack_elt - i, tmp, tmpsize); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_s8_pack1_row(utils::bit4x2* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { int constexpr VLen = 16; int constexpr NReg = NTILE / VLen; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - const auto vshuf_index_low = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); - const auto vshuf_index_high = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8); - if (zero_points == nullptr) { - for (int ir = 0; ir < row; ir += blocksize) { - int k_remain = utils::remainsize(ir, row, blocksize); - int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + auto vb1 = unpack_1bits(b1ptr + i * 8, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits((utils::bit2x4*)(tmpb4ptr + i * 32), vmask); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 8), zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + v_s8_y = 
_mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_s8_pack2_row(utils::bit4x2* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + + const auto vindex = _mm512_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, + 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16(tmp + i * 32, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + auto vb1 = unpack_1bits(b1ptr + i * 8, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits((utils::bit2x4*)(tmpb4ptr + i * 32), vmask); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 8), zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_s8_pack4_row(utils::bit4x2* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 4; + __m512i v_zp_y[NReg]; + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = 
_mm512_set1_epi8(FullRange); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 16, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + auto vb1 = unpack_1bits(b1ptr + i * 8, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, + size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { if constexpr (PackRow == 1) { - __m512 vscale_y[NReg]; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = _mm512_loadu_ps(sptr + i * VLen); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * VLen); - } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen, vscale_y[i]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen); - } - } - } else if constexpr (PackRow == 4) { - __m512 vscale_y[PackRow * NReg]; - for (int i = 0; i < NReg; i++) { - __m512 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm512_loadu_ps(sptr + i * VLen); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * VLen); - } else { - assert(0); - } - auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); - vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + func = &decompress_kblock_s5_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s5_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s5_s8_pack4_row; + } + if (func) { + 
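The `decompress_kblock_*_s8` dispatchers around this point split the requested rows at the next `blocksize` boundary, handling a short head first so that the remaining body starts block-aligned and every block maps to exactly one zero-point row, as sketched below. The local `padto` helper in the sketch stands in for `utils::padto`.

```cpp
// Scalar sketch of the head/body split used by the kblock decompress dispatchers.
#include <algorithm>
#include <cstdio>

int padto(int v, int align) { return (v + align - 1) / align * align; }

void split(int k_offset, int row, int blocksize) {
  int head_end = std::min(padto(k_offset, blocksize), k_offset + row);
  int head_size = head_end - k_offset;  // rows before the first aligned block
  int body_size = row - head_size;      // rows starting at head_end (block-aligned)
  printf("k_offset=%d row=%d -> head=%d body=%d (body starts at %d)\n",
         k_offset, row, head_size, body_size, head_end);
}

int main() {
  split(0, 256, 128);   // already aligned: head=0, body=256
  split(96, 256, 128);  // head=32 reaches the boundary at 128, body=224
  split(96, 16, 128);   // shorter than the boundary: head=16, body=0
  return 0;
}
```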
int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit4ptr, bit1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); - } - } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit4ptr + head_size * NTILE / 2, bit1ptr + head_size * NTILE / 8, zpptr, dstptr + head_size * NTILE, + blocksize, ldzp, n_offset, head_end, body_size, tmp, tmpsize); } - } else if constexpr (PackRow == 2) { - __m512 vscale_y[PackRow * NReg]; - for (int i = 0; i < NReg; i++) { - __m512 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm512_loadu_ps(sptr + i * VLen); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * VLen); - } - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s5_s8(bit4ptr, bit1ptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* dstptr, + size_t unpack_elt, int8_t* tmp, size_t tmpsize) { + int constexpr VBits = 512; + int constexpr VElt = VBits / 8; + int i = 0; + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_4bits(bit4ptr + i / 2, vmask); + auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + vout = _mm512_or_si512(vout, vb1); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= VElt) { + i = unpack_elt - VElt; + auto vout = unpack_4bits(bit4ptr + i / 2, vmask); + auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + vout = _mm512_or_si512(vout, vb1); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } else { + ref::decompress_s6_s8(bit4ptr + i / 2, bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize); + } + } + return BTLA_CODE::Success; +} + +template +static 
inline BTLA_CODE decompress_kblock_s6_s8_pack1_row(utils::bit4x2* srcptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b2ptr = bit2ptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + auto vb1 = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb2ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb2ptr, bit2ptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits((utils::bit2x4*)(tmpb4ptr + i * 32), vmask); + auto vb1 = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 16), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s6_s8_pack2_row(utils::bit4x2* srcptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + 
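Taken together, the `decompress_kblock_s6_s8_pack*_row` kernels in this hunk recombine the 4-bit and 2-bit planes, then remove both the range bias and the per-block zero point with a single `_mm512_sub_epi8` (the zero points are pre-biased by `vbias`). A scalar reference of that computation, assuming a simplified row-major layout rather than the NTILE-wide, PackRow-interleaved tiles the real kernels use:

```cpp
// Scalar reference for blockwise 6-bit -> s8 decompression with zero points:
//   s8[k][n] = (low4[k][n] | high2[k][n] << 4) - 32 - zp[k / blocksize][n]
#include <cstdint>
#include <vector>

std::vector<int8_t> decompress_s6_blockwise(const std::vector<uint8_t>& low4,   // 4-bit plane, one value per element
                                            const std::vector<uint8_t>& high2,  // 2-bit plane, one value per element
                                            const std::vector<int8_t>& zp,      // [k / blocksize][n]
                                            int rows, int cols, int blocksize) {
  std::vector<int8_t> out(rows * cols);
  for (int k = 0; k < rows; ++k) {
    for (int n = 0; n < cols; ++n) {
      int idx = k * cols + n;
      int q = int(low4[idx]) | (int(high2[idx]) << 4);  // recombine the planes, [0, 63]
      int zero = zp[(k / blocksize) * cols + n];        // per-block, per-column zero point
      out[idx] = int8_t(q - 32 - zero);                 // bias + zp removed in one subtraction
    }
  }
  return out;
}

int main() {
  // Two blocks of two rows, one column: weights +5 and -9 with zero points 1 and -2.
  std::vector<uint8_t> low4 = {5, 5, 7, 7};   // (w + 32) & 0xf
  std::vector<uint8_t> high2 = {2, 2, 1, 1};  // (w + 32) >> 4
  std::vector<int8_t> zp = {1, -2};
  auto out = decompress_s6_blockwise(low4, high2, zp, 4, 1, 2);
  return (out[0] == 4 && out[2] == -7) ? 0 : 1;
}
```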
+ uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + const auto vindex = _mm512_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, + 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16(tmp + i * 32, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b2ptr = bit2ptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + auto vb1 = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb4ptr = tmp; + memcpy(tmpb4ptr, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpb2ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb2ptr, bit2ptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits((utils::bit2x4*)(tmpb4ptr + i * 32), vmask); + auto vb1 = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 16), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s6_s8_pack4_row(utils::bit4x2* srcptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 4; + __m512i v_zp_y[NReg]; + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 
3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 16, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + auto b2ptr = bit2ptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + auto vb1 = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit2x4 * bit2ptr, int8_t * zpptr, int8_t * dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, + size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 1) { + func = &decompress_kblock_s6_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s6_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s6_s8_pack4_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit4ptr, bit2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); - } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit4ptr + head_size * NTILE / 2, bit2ptr + head_size * NTILE / 4, zpptr, dstptr + head_size * NTILE, + blocksize, ldzp, n_offset, head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s6_s8(bit4ptr, bit2ptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +template +inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, BTLA_DTYPE sdtype, + int8_t* zero_points, int k_offset, int n_offset, int blocksize, int ldzp, + int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + const auto DstSize = 
row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + const auto vshuf_index_low = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); + const auto vshuf_index_high = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8); + if (zero_points == nullptr) { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m512 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * VLen); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen, vscale_y[i]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen); + } + } + } else if constexpr (PackRow == 4) { + __m512 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else if constexpr (PackRow == 2) { + __m512 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else { + assert(0); + } + } + return BTLA_CODE::Success; + } else { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) 
/ blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m512 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * VLen); + } + __m512i vzp_y[NReg]; + for (int i = 0; i < NReg; i++) vzp_y[i] = load_s8_s32(zero_points + ele_off + i * VLen); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen, vscale_y[i], vzp_y[i]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen); + } + } + } else if constexpr (PackRow == 4) { + __m512 vscale_y[PackRow * NReg]; + __m512i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + + auto tmp = load_s8_s32(zero_points + ele_off + i * VLen); + auto vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 2] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 3] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else if constexpr (PackRow == 2) { + __m512 vscale_y[PackRow * NReg]; + __m512i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + auto tmp = load_s8_s32(zero_points + ele_off + i * VLen); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(tmp, vshuf_index_high, 
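// Reviewer note (hedged): the PackRow == 2 / PackRow == 4 branches appear to replicate each
// per-column scale and zero point 2x / 4x across lanes via broadcast_ps_1_2 / broadcast_epi32_1_2,
// which matches a packed-weight layout where PackRow consecutive K elements of one column sit
// contiguously inside the 16-wide N tile, so one loaded scale must cover PackRow adjacent lanes.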
vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else { + assert(0); + } + } + return BTLA_CODE::Success; + } +} + +template +inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s8_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s8_fp_row(srcptr + head_size * NTILE, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} +template +inline BTLA_CODE decompress_kblock_s4_fp_row(utils::int4x2* srcptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s4_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s4_fp_row(srcptr + head_size * NTILE / 2, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s2_fp_row(utils::bit2x4* b2ptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = 
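// Reviewer note: the decompress_kblock_s*_fp_row helpers stage the widened int8 weights in the
// tail of the destination buffer (tmps8ptr = dstptr + DstSize - S8Size), so no extra scratch
// allocation is needed: the low-bit planes are first expanded to int8 there (zero points, when
// present, are removed in that step), and the shared decompress_kblock_s8_fp_row path then
// applies the per-block scales while writing DST_T values from the front of the same buffer.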
decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s2_fp_row(b2ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s2_fp_row(b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s3_fp_row(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s3_fp_row( + b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s5_fp_row(utils::bit4x2* b4ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s5_s8(b4ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, 
n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s5_fp(utils::bit4x2* b4ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s5_fp_row(b4ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s5_fp_row( + b4ptr + head_size * NTILE / 2, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s6_fp_row(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s6_s8(b4ptr, b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s6_fp(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s6_fp_row(b4ptr, b2ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s6_fp_row( + b4ptr + head_size * NTILE / 2, b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +static inline __m512 load_T_fp32(const T* srcptr) { + __m512 vtmp; + if constexpr (std::is_same_v) { + vtmp = _mm512_loadu_ps(srcptr); + } else if constexpr (std::is_same_v) { + vtmp = load_bf16_fp32(srcptr); + } else { + assert(0); + } + return vtmp; +} + +static inline __m512 load_s8_fp32(int8_t* srcptr) { + auto src_y = load_s8_s32(srcptr); + auto dst_y = _mm512_cvtepi32_ps(src_y); + return dst_y; +} + +static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) { + __m512i zero = _mm512_setzero_si512(); + __mmask64 blt0 = _mm512_movepi8_mask(b); + return 
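// Reviewer note: this helper emulates a per-byte sign operation with AVX-512 primitives: bytes of
// `a` are negated wherever the corresponding byte of `b` is negative (_mm512_movepi8_mask extracts
// the sign bits, _mm512_mask_sub_epi8 computes 0 - a for those lanes). Unlike SSSE3 _mm_sign_epi8
// it does not zero lanes where b == 0, which appears sufficient for the s8s8 GEMV kernels below:
// a zero activation byte already contributes zero to the _mm512_dpbusd_epi32 dot product, and
// dpbusd requires an unsigned first operand, hence the |a| / sign(b, a) pairing there.
// The stray ';' after the return is a harmless empty statement.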
_mm512_mask_sub_epi8(a, blt0, zero, a); + ; +} + +template +static inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m512i* iacc, + __m512* facc) { + __m512 v_a_scale[MTILE]; + for (int im = 0; im < MTILE; im++) { + v_a_scale[im] = _mm512_set1_ps(*(asptr + im * ldzp)); + } + + for (int i = 0; i < NReg; i++) { + __m512 v_b_scale = load_T_fp32(bsptr + i * 16); + for (int im = 0; im < MTILE; im++) { + auto vtmp = _mm512_mul_ps(v_a_scale[im], v_b_scale); + auto tmp = _mm512_cvtepi32_ps(iacc[im * NReg + i]); + facc[im * NReg + i] = _mm512_fmadd_ps(tmp, vtmp, facc[im * NReg + i]); + } + } +} + +template +static inline void gemv_remove_zp(const uint8_t* azptr, int ldzp, __m512i* iacc, __m512i* bacc) { + if constexpr (MReg == 1) { + auto zp = int(azptr[0]); + __m512i v_a_zp = _mm512_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm512_mullo_epi32(v_a_zp, bacc[in]); + iacc[in] = _mm512_sub_epi32(iacc[in], vtmp); + } + } else { + __m512i v_a_zp[MReg]; + for (int im = 0; im < MReg; im++) { + auto zp = int(azptr[im * ldzp]); + v_a_zp[im] = _mm512_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm512_mullo_epi32(v_a_zp[im], bacc[in]); + iacc[im * NReg + in] = _mm512_sub_epi32(iacc[im * NReg + in], vtmp); + } + } + } +} + +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m512* vacc, __m512* vsca) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va = _mm512_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + ftmp = _mm512_mul_ps(ftmp, vsca[i]); + vacc[i] = _mm512_fmadd_ps(va, ftmp, vacc[i]); + } + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + ftmp = _mm512_mul_ps(ftmp, vsca[i]); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm512_set1_ps(*(Aptr + ikk + im * lda)); } + vacc[im * NReg + i] = _mm512_fmadd_ps(va[im], ftmp, vacc[im * NReg + i]); } - } else { - assert(0); } } - return BTLA_CODE::Success; + } +} + +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m512* vacc_loc) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va = _mm512_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + vacc_loc[i] = _mm512_fmadd_ps(va, ftmp, vacc_loc[i]); + } + } } else { - for (int ir = 0; ir < row; ir += blocksize) { - int k_remain = utils::remainsize(ir, row, blocksize); - int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; - if constexpr (PackRow == 1) { - __m512 vscale_y[NReg]; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = _mm512_loadu_ps(sptr + i * VLen); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * VLen); - } - __m512i vzp_y[NReg]; - for (int i = 0; i < NReg; i++) vzp_y[i] = load_s8_s32(zero_points + ele_off + i * VLen); - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen, vscale_y[i], vzp_y[i]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen); + for (int ikk 
= 0; ikk < Unroll; ikk++) { + __m512 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm512_set1_ps(*(Aptr + ikk + im * lda)); } + vacc_loc[im * NReg + i] = _mm512_fmadd_ps(va[im], ftmp, vacc_loc[im * NReg + i]); } - } else if constexpr (PackRow == 4) { - __m512 vscale_y[PackRow * NReg]; - __m512i vzp_y[PackRow * NReg]; + } + } + } +} + +template +static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& b4ptr = B.b4ptr; + int blks = k / blocksize; + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(8); + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { for (int i = 0; i < NReg; i++) { - __m512 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm512_loadu_ps(sptr + i * VLen); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * VLen); - } else { - assert(0); - } - auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); - vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); + } - auto tmp = load_s8_s32(zero_points + ele_off + i * VLen); - auto vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); - vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); - vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); - vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); - vzp_y[i * PackRow + 2] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); - vzp_y[i * PackRow + 3] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + 
(ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip], - vzp_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); - } - } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m512 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm512_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; } - } else if constexpr (PackRow == 2) { - __m512 vscale_y[PackRow * NReg]; - __m512i vzp_y[PackRow * NReg]; + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + } + + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } 
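// Reviewer note: the gemv_*bit_fp32_fp32 kernels all follow the same per-block pattern: unpack
// Unroll (= 4) rows of packed weights into int8 in `tmp`, subtract the bias (and the per-block
// zero point when B.zpptr is set) while still in the int8 domain, accumulate the fp32
// activation x weight products into acc_loc, and only then fold the per-block weight scale into
// the running accumulator with a single FMA per register, as in the loop above.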
+ } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + auto b1ptr = (utils::bit1x8*)B.b1ptr; + + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m512 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm512_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { for (int i = 0; i < NReg; i++) { - __m512 vraw; - if (sdtype == BTLA_DTYPE::F32) { - auto sptr = (float*)scales_ + ele_off; - vraw = _mm512_loadu_ps(sptr + i * VLen); - } else if (sdtype == BTLA_DTYPE::BF16) { - auto sptr = (utils::bf16*)scales_ + ele_off; - vraw = load_bf16_fp32(sptr + i * VLen); - } - vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); - auto tmp = load_s8_s32(zero_points + ele_off + i * VLen); - vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); - vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + b1ptr += VLen * Unroll / 8; } - for (int ib = 0; ib < k_remain; ib += PackRow) { - auto b8ptr = srcptr + (ir + ib) * NTILE; - for (int i = 0; i < NReg; i++) { - for (int ip = 0; ip < PackRow; ip++) { - auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip], - vzp_y[i * PackRow + ip]); - store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); - } - } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for 
(int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + b1ptr += VLen * Unroll / 8; } - } else { - assert(0); + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } } - return BTLA_CODE::Success; - } -} -template -inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s8_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s8_fp_row(srcptr + head_size * NTILE, dstptr + head_size * NTILE, - body_size, scales_, sdtype, zero_points, head_end, n_offset, - blocksize, ldzp, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } } - return BTLA_CODE::Success; } - return ret; -} -template -inline BTLA_CODE decompress_kblock_s4_fp_row(utils::int4x2* srcptr, DST_T* dstptr, int row, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - auto tmps8ptr = (int8_t*)dstptr; - tmps8ptr += DstSize - S8Size; - auto ret = decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, - row, NTILE, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); - return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, - n_offset, blocksize, ldzp, tmp, tmpsize); -} -template -inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s4_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s4_fp_row(srcptr + head_size * NTILE / 2, dstptr + head_size * NTILE, - body_size, scales_, sdtype, zero_points, head_end, n_offset, - blocksize, ldzp, tmp, tmpsize); + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); } - return BTLA_CODE::Success; } - return ret; + return BTLA_CODE::Success; } -template -inline BTLA_CODE 
decompress_kblock_s2_fp_row(utils::bit2x4* b2ptr, DST_T* dstptr, int row, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - auto tmps8ptr = (int8_t*)dstptr; - tmps8ptr += DstSize - S8Size; - auto ret = decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, - row, NTILE, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); - return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, - n_offset, blocksize, ldzp, tmp, tmpsize); -} +template +static inline BTLA_CODE gemv_5bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = (utils::bit4x2*)B.b4ptr; + auto b1ptr = (utils::bit1x8*)B.b1ptr; -template -inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, - BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, - int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s2_fp_row(b2ptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); - } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s2_fp_row(b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, - body_size, scales_, sdtype, zero_points, head_end, n_offset, - blocksize, ldzp, tmp, tmpsize); - } - return BTLA_CODE::Success; + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); } - return ret; -} + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); -template -inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, - void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, - int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { - int constexpr NReg = NTILE / 8; - const auto DstSize = row * NTILE * sizeof(DST_T); - const auto S8Size = row * NTILE * sizeof(int8_t); - auto tmps8ptr = (int8_t*)dstptr; - tmps8ptr += DstSize - S8Size; - auto ret = decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, - k_offset, row, NTILE, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); - return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, - n_offset, blocksize, ldzp, tmp, tmpsize); -} + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; -template -inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, - void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, - int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t 
tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if (col == NTILE) { - int head_end = utils::padto(k_offset, blocksize); - head_end = std::min(head_end, k_offset + row); - int head_size = head_end - k_offset; - if (head_size > 0) { - decompress_kblock_s3_fp_row(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, - k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + __m512 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm512_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b4ptr += VLen * Unroll / 2; + b1ptr += VLen * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b4ptr += VLen * Unroll / 2; + b1ptr += VLen * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } } - int body_size = row - head_size; - if (body_size > 0) { - decompress_kblock_s3_fp_row( - b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, - sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } } - return BTLA_CODE::Success; } - return ret; -} -template -static inline __m512 load_T_fp32(const T* srcptr) { - __m512 vtmp; - if constexpr (std::is_same_v) { - vtmp = _mm512_loadu_ps(srcptr); - } else if constexpr (std::is_same_v) { - vtmp = load_bf16_fp32(srcptr); - } else { - assert(0); + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } } - return vtmp; -} - -static inline __m512 load_s8_fp32(int8_t* srcptr) { - auto src_y = load_s8_s32(srcptr); - auto dst_y = _mm512_cvtepi32_ps(src_y); - return dst_y; + return BTLA_CODE::Success; } -static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) { - __m512i zero = _mm512_setzero_si512(); - __mmask64 blt0 = _mm512_movepi8_mask(b); - return _mm512_mask_sub_epi8(a, blt0, zero, a); - ; -} +template +static inline BTLA_CODE gemv_6bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = (utils::bit4x2*)B.b4ptr; + auto b2ptr = (utils::bit2x4*)B.b2ptr; -template -static 
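// Reviewer note: the odd bit widths are stored as separate bit planes that are recombined during
// unpacking: 3-bit = 2-bit plane | 1-bit plane (unpacked at bit 2), 5-bit = 4-bit plane |
// (unpacked 1-bit plane << 2), 6-bit = 4-bit plane | (unpacked 2-bit plane << 4). After
// recombination the kernels subtract FullRange = 1 << (bits - 1) to recenter the code to a
// signed int8 value, which appears to mirror the S*_CLIP quantization convention.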
inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m512i* iacc, - __m512* facc) { - __m512 v_a_scale[MTILE]; - for (int im = 0; im < MTILE; im++) { - v_a_scale[im] = _mm512_set1_ps(*(asptr + im * ldzp)); + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); } + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); - for (int i = 0; i < NReg; i++) { - __m512 v_b_scale = load_T_fp32(bsptr + i * 16); - for (int im = 0; im < MTILE; im++) { - auto vtmp = _mm512_mul_ps(v_a_scale[im], v_b_scale); - auto tmp = _mm512_cvtepi32_ps(iacc[im * NReg + i]); - facc[im * NReg + i] = _mm512_fmadd_ps(tmp, vtmp, facc[im * NReg + i]); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m512 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm512_setzero_ps(); } - } -} + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); -template -static inline void gemv_remove_zp(const uint8_t* azptr, int ldzp, __m512i* iacc, __m512i* bacc) { - if constexpr (MReg == 1) { - auto zp = int(azptr[0]); - __m512i v_a_zp = _mm512_set1_epi32(zp); - for (int in = 0; in < NReg; in++) { - auto vtmp = _mm512_mullo_epi32(v_a_zp, bacc[in]); - iacc[in] = _mm512_sub_epi32(iacc[in], vtmp); + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b4ptr += VLen * Unroll / 2; + b2ptr += VLen * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b4ptr += VLen * Unroll / 2; + b2ptr += VLen * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } } - } else { - __m512i v_a_zp[MReg]; - for (int im = 0; im < MReg; im++) { - auto zp = int(azptr[im * ldzp]); - 
v_a_zp[im] = _mm512_set1_epi32(zp); + + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + for (int im = 0; im < MTILE; im++) { for (int in = 0; in < NReg; in++) { - auto vtmp = _mm512_mullo_epi32(v_a_zp[im], bacc[in]); - iacc[im * NReg + in] = _mm512_sub_epi32(iacc[im * NReg + in], vtmp); + acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); } } } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; } -template -static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m512* vacc, __m512* vsca) { - if constexpr (MTILE == 1) { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m512 va = _mm512_set1_ps(*(Aptr + ikk)); - for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); - ftmp = _mm512_mul_ps(ftmp, vsca[i]); - vacc[i] = _mm512_fmadd_ps(va, ftmp, vacc[i]); - } +namespace vnni { + +#if CompileAVX512VNNI() +#ifdef __GNUC__ +#pragma GCC push_options +#pragma GCC target("avx512vnni") +#endif + +template +static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + auto& azptr = A.zpptr; + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + const __m512i onesu8 = _mm512_set1_epi8(1); + const __m512i vbias = _mm512_set1_epi8(8); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + __m512i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); } - } else { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m512 va[MTILE]; + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm512_setzero_si512(); + } + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); - ftmp = _mm512_mul_ps(ftmp, vsca[i]); - for (int im = 0; im < MTILE; im++) { - if (i == 0) { - va[im] = _mm512_set1_ps(*(Aptr + ikk + im * lda)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * VLen, vindex); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = 
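// Reviewer note: in the u8s8 kernels _mm512_dpbusd_epi32 accumulates two quantities per block:
// iacc collects the u8 activation x s8 weight dot products, while bacc collects the plain column
// sums of the weights (dpbusd against an all-ones u8 vector). gemv_remove_zp later subtracts
// a_zero_point * column_sum from iacc, and gemv_dequant_s32fp32 applies the combined
// activation/weight scales, so the activation zero point never has to be removed per element.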
unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } } - vacc[im * NReg + i] = _mm512_fmadd_ps(va[im], ftmp, vacc[im * NReg + i]); } } } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } -} -template -static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m512* vacc_loc) { - if constexpr (MTILE == 1) { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m512 va = _mm512_set1_ps(*(Aptr + ikk)); - for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); - vacc_loc[i] = _mm512_fmadd_ps(va, ftmp, vacc_loc[i]); - } - } - } else { - for (int ikk = 0; ikk < Unroll; ikk++) { - __m512 va[MTILE]; - for (int i = 0; i < NReg; i++) { - auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); - for (int im = 0; im < MTILE; im++) { - if (i == 0) { - va[im] = _mm512_set1_ps(*(Aptr + ikk + im * lda)); - } - vacc_loc[im * NReg + i] = _mm512_fmadd_ps(va[im], ftmp, vacc_loc[im * NReg + i]); - } - } + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); } } + return BTLA_CODE::Success; } template -static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + int blks = k / blocksize; int constexpr VLen = 16; int constexpr NReg = NTILE / VLen; @@ -3573,50 +4720,58 @@ static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + const __m512i vbias = _mm512_set1_epi8(8); uint32_t mask = 0x0f0f0f0f; auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); - auto vbias = _mm512_set1_epi8(8); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); for (int ib = 0; ib < blks; ib += 1) { - auto bsptr = B.sptr + ib * B.ldzp; - __m512 v_b_scale[NReg]; - for (int i = 0; i < NReg; i++) { - 
v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + __m512i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); } - - int constexpr Unroll = 4; - assert((blocksize % 4) == 0); - assert(tmpsize >= NTILE * Unroll); - if (B.zpptr) { __m512i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; - - for (int i = 0; i < Unroll; i++) { - memcpy(tmp + i * NTILE, bzptr, NTILE); - } for (int i = 0; i < NReg; i++) { - bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * VLen, vindex); bzp[i] = _mm512_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int ik = 0; ik < blocksize; ik += 4) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } for (int i = 0; i < NReg; i++) { auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); vb = _mm512_sub_epi8(vb, bzp[i]); - _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); } - } else { - for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int ik = 0; ik < blocksize; ik += 4) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } for (int i = 0; i < NReg; i++) { auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); vb = _mm512_sub_epi8(vb, vbias); - _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); } } + + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } for (int j = 0; j < MReg; j++) { @@ -3628,9 +4783,9 @@ static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils } template -static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = (utils::bit2x4*)B.b2ptr; + auto b2ptr = reinterpret_cast(B.b2ptr); int constexpr VLen = 16; int blks = k / blocksize; int constexpr NReg = NTILE / VLen; @@ -3639,6 +4794,11 @@ static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + + const auto onesu8 = _mm512_set1_epi8(1); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); uint64_t mask0 = 0x0303030303030303; auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); auto vbias = _mm512_set1_epi8(2); @@ -3647,60 +4807,80 @@ static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); - - int constexpr KTILE = 1; + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - auto bsptr = B.sptr + ib * B.ldzp; - - __m512 acc_loc[NReg * MReg]; + __m512i iacc[NReg * MReg]; + __m512i bacc[NReg]; for (int i = 0; i < NReg * MReg; i++) { - acc_loc[i] = _mm512_setzero_ps(); + iacc[i] = _mm512_setzero_si512(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm512_setzero_si512(); } - int constexpr Unroll = 4; - assert((blocksize % 4) == 0); - assert(tmpsize >= NTILE * Unroll); - if (B.zpptr) { __m512i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; - for (int i = 0; i < Unroll; i++) { - memcpy(tmp + i * NTILE, bzptr, NTILE); - } for (int i = 0; i < NReg; i++) { - bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); bzp[i] = _mm512_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - vb = _mm512_sub_epi8(vb, bzp[i]); - _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); - b2ptr += VLen * Unroll / 4; + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += VLen * KTILE / 4; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - } else { - for (int ik = 0; ik < blocksize; ik += Unroll) { - for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - vb = _mm512_sub_epi8(vb, vbias); - _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); - b2ptr += VLen * Unroll / 4; + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += VLen * KTILE / 4; + } } - accumulate_fp32_s8_fp32(A + ib * blocksize + 
ik, lda, tmp, acc_loc); } } - __m512 v_b_scale[NReg]; - for (int i = 0; i < NReg; i++) { - v_b_scale[i] = load_T_fp32(bsptr + i * VLen); - } - for (int im = 0; im < MTILE; im++) { - for (int in = 0; in < NReg; in++) { - acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); - } - } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } for (int j = 0; j < MReg; j++) { @@ -3712,11 +4892,9 @@ static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils } template -static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = (utils::bit2x4*)B.b2ptr; - auto b1ptr = (utils::bit1x8*)B.b1ptr; - + auto b2ptr = reinterpret_cast(B.b2ptr); int constexpr VLen = 16; int blks = k / blocksize; int constexpr NReg = NTILE / VLen; @@ -3725,76 +4903,68 @@ static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + + const auto onesu8 = _mm512_set1_epi8(1); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); uint64_t mask0 = 0x0303030303030303; auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); - auto vbias = _mm512_set1_epi8(4); + auto vbias = _mm512_set1_epi8(2); auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); - - auto zmm_0x04 = _mm512_set1_epi8(0x04); - auto zmm_0x00 = _mm512_set1_epi8(0x00); - int constexpr KTILE = 1; + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { - auto bsptr = B.sptr + ib * B.ldzp; - - __m512 acc_loc[NReg * MReg]; + __m512i iacc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { - acc_loc[i] = _mm512_setzero_ps(); + iacc[i] = _mm512_setzero_si512(); } - int constexpr Unroll = 4; - assert((blocksize % 4) == 0); - assert(tmpsize >= NTILE * Unroll); if (B.zpptr) { __m512i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; - for (int i = 0; i < Unroll; i++) { - memcpy(tmp + i * NTILE, bzptr, NTILE); - } for (int i = 0; i < NReg; i++) { - bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); - bzp[i] = _mm512_add_epi8(bzp[i], vbias); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); + bzp[i] = _mm512_add_epi8(vbias, bzp[i]); } - for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } for (int i = 0; i < NReg; i++) { auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); - vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); - _mm512_storeu_si512((__m512i*)(tmp + 64 * i), 
vb); - b2ptr += VLen * Unroll / 4; - b1ptr += VLen * Unroll / 8; + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += VLen * KTILE / 4; } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); } - } else { - for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } for (int i = 0; i < NReg; i++) { auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); - vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); - _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); - b2ptr += VLen * Unroll / 4; - b1ptr += VLen * Unroll / 8; + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += VLen * KTILE / 4; } - accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); - } - } - - __m512 v_b_scale[NReg]; - for (int i = 0; i < NReg; i++) { - v_b_scale[i] = load_T_fp32(bsptr + i * VLen); - } - for (int im = 0; im < MTILE; im++) { - for (int in = 0; in < NReg; in++) { - acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); } } + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } for (int j = 0; j < MReg; j++) { @@ -3805,38 +4975,36 @@ static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils return BTLA_CODE::Success; } -namespace vnni { - -#if CompileAVX512VNNI() -#ifdef __GNUC__ -#pragma GCC push_options -#pragma GCC target("avx512vnni") -#endif - template -static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto& a8ptr = A.aptr; - auto& b4ptr = B.b4ptr; - auto& asptr = A.sptr; - auto& azptr = A.zpptr; - int constexpr VLen = 16; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int blks = k / blocksize; + int constexpr VLen = 16; int constexpr NReg = NTILE / VLen; int constexpr MReg = MTILE; - // Initialize accumulator with zeros __m512 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); - const __m512i onesu8 = _mm512_set1_epi8(1); - const __m512i vbias = _mm512_set1_epi8(8); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); const auto vindex = 
_mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - + const auto onesu8 = _mm512_set1_epi8(1); + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m512i iacc[NReg * MReg]; __m512i bacc[NReg]; @@ -3850,59 +5018,76 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut __m512i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { - bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * VLen, vindex); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); bzp[i] = _mm512_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { if constexpr (MTILE == 1) { - __m512i va = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; } } else { __m512i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; } } } } else { - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { if constexpr (MTILE == 1) { - __m512i va = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; } } else { __m512i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); vb = 
_mm512_sub_epi8(vb, vbias); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; } } } } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } @@ -3916,27 +5101,34 @@ static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto& a8ptr = A.aptr; - auto& b4ptr = B.b4ptr; - auto& asptr = A.sptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr VLen = 16; int constexpr NReg = NTILE / VLen; int constexpr MReg = MTILE; - // Initialize accumulator with zeros __m512 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } - const __m512i vbias = _mm512_set1_epi8(8); - uint32_t mask = 0x0f0f0f0f; - auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m512i iacc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { @@ -3946,38 +5138,46 @@ static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const ut __m512i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { - bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * VLen, vindex); + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); bzp[i] = _mm512_add_epi8(bzp[i], vbias); } - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { __m512i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); for (int j = 0; j < MReg; j++) { auto vsb = _mm512_sign_epi8(vb, va[j]); auto vabsa = _mm512_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); } + b2ptr += VLen 
* KTILE / 4; + b1ptr += VLen * KTILE / 8; } } } else { - for (int ik = 0; ik < blocksize; ik += 4) { + for (int ik = 0; ik < blocksize; ik += KTILE) { __m512i va[MReg]; for (int i = 0; i < MReg; i++) { - va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); for (int j = 0; j < MReg; j++) { auto vsb = _mm512_sign_epi8(vb, va[j]); auto vabsa = _mm512_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; } } } @@ -3994,30 +5194,30 @@ static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_5bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = reinterpret_cast(B.b2ptr); - int constexpr VLen = 16; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int blks = k / blocksize; + int constexpr VLen = 16; int constexpr NReg = NTILE / VLen; int constexpr MReg = MTILE; __m512 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); - const auto onesu8 = _mm512_set1_epi8(1); + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - uint64_t mask0 = 0x0303030303030303; - auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); - auto vbias = _mm512_set1_epi8(2); - auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, - 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + const auto onesu8 = _mm512_set1_epi8(1); int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m512i iacc[NReg * MReg]; @@ -4039,11 +5239,15 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, 
vb); - b2ptr += VLen * KTILE / 4; + b4ptr += VLen * KTILE / 2; + b1ptr += VLen * KTILE / 8; } } else { __m512i va[MReg]; @@ -4051,13 +5255,17 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); } - b2ptr += VLen * KTILE / 4; + b4ptr += VLen * KTILE / 2; + b1ptr += VLen * KTILE / 8; } } } @@ -4066,11 +5274,15 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); - b2ptr += VLen * KTILE / 4; + b4ptr += VLen * KTILE / 2; + b1ptr += VLen * KTILE / 8; } } else { __m512i va[MReg]; @@ -4078,13 +5290,17 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); } - b2ptr += VLen * KTILE / 4; + b4ptr += VLen * KTILE / 2; + b1ptr += VLen * KTILE / 8; } } } @@ -4103,43 +5319,41 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_5bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { - auto b2ptr = reinterpret_cast(B.b2ptr); - int constexpr VLen = 16; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int blks = k / blocksize; + int constexpr VLen = 16; int constexpr NReg = NTILE / VLen; int constexpr MReg = MTILE; __m512 acc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + int constexpr FullRange = 1 << (5 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); - const auto onesu8 = _mm512_set1_epi8(1); + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 
8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); - uint64_t mask0 = 0x0303030303030303; - auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); - auto vbias = _mm512_set1_epi8(2); - auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); - auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, - 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, - 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); - auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); int constexpr KTILE = 4; for (int ib = 0; ib < blks; ib += 1) { __m512i iacc[NReg * MReg]; for (int i = 0; i < NReg * MReg; i++) { iacc[i] = _mm512_setzero_si512(); } - if (B.zpptr) { __m512i bzp[NReg]; auto bzptr = B.zpptr + ib * B.ldzp; for (int i = 0; i < NReg; i++) { bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); - bzp[i] = _mm512_add_epi8(vbias, bzp[i]); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); } for (int ik = 0; ik < blocksize; ik += KTILE) { __m512i va[MReg]; @@ -4147,14 +5361,18 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); for (int j = 0; j < MReg; j++) { auto vsb = _mm512_sign_epi8(vb, va[j]); auto vabsa = _mm512_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); } - b2ptr += VLen * KTILE / 4; + b4ptr += VLen * KTILE / 2; + b1ptr += VLen * KTILE / 8; } } } else { @@ -4164,17 +5382,22 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb1 = _mm512_slli_epi32(vb1, 2); + vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); for (int j = 0; j < MReg; j++) { auto vsb = _mm512_sign_epi8(vb, va[j]); auto vabsa = _mm512_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); } - b2ptr += VLen * KTILE / 4; + b4ptr += VLen * KTILE / 2; + b1ptr += VLen * KTILE / 8; } } } + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); } @@ -4187,10 +5410,10 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_6bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = reinterpret_cast(B.b4ptr); auto b2ptr = reinterpret_cast(B.b2ptr); - auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr VLen = 16; @@ -4200,17 +5423,18 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + int constexpr FullRange = 1 
<< (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + uint64_t mask0 = 0x0303030303030303; auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); - auto vbias = _mm512_set1_epi8(4); auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); - - auto zmm_0x04 = _mm512_set1_epi8(0x04); - auto zmm_0x00 = _mm512_set1_epi8(0x00); const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); @@ -4236,14 +5460,15 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b4ptr += VLen * KTILE / 2; b2ptr += VLen * KTILE / 4; - b1ptr += VLen * KTILE / 8; } } else { __m512i va[MReg]; @@ -4251,16 +5476,17 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); } + b4ptr += VLen * KTILE / 2; b2ptr += VLen * KTILE / 4; - b1ptr += VLen * KTILE / 8; } } } @@ -4269,14 +5495,15 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut if constexpr (MTILE == 1) { __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b4ptr += VLen * KTILE / 2; b2ptr += VLen * KTILE / 4; - b1ptr += VLen * KTILE / 8; } } else { __m512i va[MReg]; @@ -4284,16 +5511,17 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut va[i] = 
_mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); for (int j = 0; j < MReg; j++) { iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); } + b4ptr += VLen * KTILE / 2; b2ptr += VLen * KTILE / 4; - b1ptr += VLen * KTILE / 8; } } } @@ -4312,10 +5540,10 @@ static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const ut } template -static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, +static inline BTLA_CODE gemv_6bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b4ptr = reinterpret_cast(B.b4ptr); auto b2ptr = reinterpret_cast(B.b2ptr); - auto b1ptr = reinterpret_cast(B.b1ptr); int blks = k / blocksize; int constexpr VLen = 16; @@ -4325,17 +5553,18 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut for (int i = 0; i < NReg * MReg; i++) { acc[i] = _mm512_setzero_ps(); } + int constexpr FullRange = 1 << (6 - 1); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(FullRange); + uint64_t mask0 = 0x0303030303030303; auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); - auto vbias = _mm512_set1_epi8(4); auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); - - auto zmm_0x04 = _mm512_set1_epi8(0x04); - auto zmm_0x00 = _mm512_set1_epi8(0x00); const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); @@ -4358,8 +5587,9 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, bzp[i]); for (int j = 0; j < MReg; j++) { @@ -4367,8 +5597,8 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut auto vabsa = _mm512_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); } + b4ptr += VLen * KTILE / 2; b2ptr += VLen * KTILE / 4; - b1ptr += VLen * KTILE / 8; } } } else { @@ -4378,8 +5608,9 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut va[i] = 
_mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); } for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); - auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + auto vb = unpack_4bits(b4ptr, vmask); + auto vb1 = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb1 = _mm512_slli_epi32(vb1, 4); vb = _mm512_or_si512(vb, vb1); vb = _mm512_sub_epi8(vb, vbias); for (int j = 0; j < MReg; j++) { @@ -4387,8 +5618,8 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut auto vabsa = _mm512_sign_epi8(va[j], va[j]); iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); } + b4ptr += VLen * KTILE / 2; b2ptr += VLen * KTILE / 4; - b1ptr += VLen * KTILE / 8; } } } diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index fb3fb1f65..b133340f1 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -152,28 +152,25 @@ static inline BTLA_CODE transpose2d(const _T* srcptr, _T* dstptr, int row, int c return BTLA_CODE::Success; } -static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { - for (int j = 0; j < row; j++) { - for (int ii = 0; ii < col; ii += 2) { - utils::int4x2 tmp; - tmp.x = utils::int4x2::convert(srcptr[j * ld_src + ii + 0]) + 8; - tmp.y = utils::int4x2::convert(srcptr[j * ld_src + ii + 1]) + 8; - dstptr[j * ld_dst / 2 + ii / 2] = tmp; - } +static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstptr, size_t size) { + int8_t constexpr FullRange = 1 << (4 - 1); + assert(size % 2 == 0); + for (int ii = 0; ii < size; ii += 2) { + utils::int4x2 tmp; + tmp.x = srcptr[ii + 0] + FullRange; + tmp.y = srcptr[ii + 1] + FullRange; + dstptr[ii / 2] = tmp; } return BTLA_CODE::Success; } -static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { - for (int j = 0; j < row; j++) { - for (int ii = 0; ii < col; ii += 2) { - utils::f4x2 tmp; - tmp.x = srcptr[j * ld_src + ii + 0]; - tmp.y = srcptr[j * ld_src + ii + 1]; - dstptr[j * ld_dst / 2 + ii / 2] = tmp; - } +static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, size_t size) { + for (int ii = 0; ii < size; ii += 2) { + assert(size % 2 == 0); + utils::f4x2 tmp; + tmp.x = srcptr[ii + 0]; + tmp.y = srcptr[ii + 1]; + dstptr[ii / 2] = tmp; } return BTLA_CODE::Success; } @@ -226,6 +223,63 @@ static inline BTLA_CODE compress_3bit_align128(const int8_t* srcptr, bestla::uti return BTLA_CODE::Success; } +static inline BTLA_CODE compress_6bit(const int8_t* srcptr, bestla::utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, + size_t size) { + assert(size % 4 == 0); + int8_t constexpr FullRange = 1 << (6 - 1); + for (int j = 0; j < size; j += 4) { + auto tmp = srcptr[j + 0] + FullRange; + bit4ptr[j / 2 + 0].x = tmp & 0xf; + bit2ptr[j / 4].a = tmp >> 4; + tmp = srcptr[j + 1] + FullRange; + bit4ptr[j / 2 + 0].y = tmp & 0xf; + bit2ptr[j / 4].b = tmp >> 4; + tmp = srcptr[j + 2] + FullRange; + bit4ptr[j / 2 + 1].x = tmp & 0xf; + bit2ptr[j / 4].c = tmp >> 4; + tmp = srcptr[j + 3] + FullRange; + bit4ptr[j / 2 + 1].y = tmp & 0xf; + bit2ptr[j / 4].d = tmp >> 4; + } + + return BTLA_CODE::Success; +} + +static inline BTLA_CODE compress_5bit(const int8_t* srcptr, bestla::utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, + size_t size) { + assert(size % 8 == 0); + int8_t constexpr FullRange = 1 << (5 - 1); + for (int j = 0; j < size; j 
+= 8) { + auto tmp = srcptr[j + 0] + FullRange; + bit4ptr[j / 2 + 0].x = tmp & 0xf; + bit1ptr[j / 8].a = tmp >> 4; + tmp = srcptr[j + 1] + FullRange; + bit4ptr[j / 2 + 0].y = tmp & 0xf; + bit1ptr[j / 8].b = tmp >> 4; + tmp = srcptr[j + 2] + FullRange; + bit4ptr[j / 2 + 1].x = tmp & 0xf; + bit1ptr[j / 8].c = tmp >> 4; + tmp = srcptr[j + 3] + FullRange; + bit4ptr[j / 2 + 1].y = tmp & 0xf; + bit1ptr[j / 8].d = tmp >> 4; + + tmp = srcptr[j + 4] + FullRange; + bit4ptr[j / 2 + 2].x = tmp & 0xf; + bit1ptr[j / 8].e = tmp >> 4; + tmp = srcptr[j + 5] + FullRange; + bit4ptr[j / 2 + 2].y = tmp & 0xf; + bit1ptr[j / 8].f = tmp >> 4; + tmp = srcptr[j + 6] + FullRange; + bit4ptr[j / 2 + 3].x = tmp & 0xf; + bit1ptr[j / 8].g = tmp >> 4; + tmp = srcptr[j + 7] + FullRange; + bit4ptr[j / 2 + 3].y = tmp & 0xf; + bit1ptr[j / 8].h = tmp >> 4; + } + + return BTLA_CODE::Success; +} + static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, size_t size) { assert(size % 8 == 0); @@ -331,6 +385,42 @@ static inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { dstptr[7] = static_cast(tmp); } +static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* dstptr, int unpack_elt, + int8_t* tmp, size_t tmpsize) { + int constexpr FullRange = 1 << (6 - 1); + for (size_t i = 0; i < unpack_elt; i += 4) { + auto bit2 = bit2ptr[i / 4]; + auto tmp = bit4ptr[i / 2]; + dstptr[i + 0] = (tmp.x | (bit2.a << 4)) - FullRange; + dstptr[i + 1] = (tmp.y | (bit2.b << 4)) - FullRange; + tmp = bit4ptr[i / 2 + 1]; + dstptr[i + 2] = (tmp.x | (bit2.c << 4)) - FullRange; + dstptr[i + 3] = (tmp.y | (bit2.d << 4)) - FullRange; + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, int unpack_elt, + int8_t* tmp, size_t tmpsize) { + int constexpr FullRange = 1 << (5 - 1); + for (size_t i = 0; i < unpack_elt; i += 8) { + auto bit1 = bit1ptr[i / 8]; + auto tmp = bit4ptr[i / 2]; + dstptr[i + 0] = (tmp.x | (bit1.a << 4)) - FullRange; + dstptr[i + 1] = (tmp.y | (bit1.b << 4)) - FullRange; + tmp = bit4ptr[i / 2 + 1]; + dstptr[i + 2] = (tmp.x | (bit1.c << 4)) - FullRange; + dstptr[i + 3] = (tmp.y | (bit1.d << 4)) - FullRange; + tmp = bit4ptr[i / 2 + 2]; + dstptr[i + 4] = (tmp.x | (bit1.e << 4)) - FullRange; + dstptr[i + 5] = (tmp.y | (bit1.f << 4)) - FullRange; + tmp = bit4ptr[i / 2 + 3]; + dstptr[i + 6] = (tmp.x | (bit1.g << 4)) - FullRange; + dstptr[i + 7] = (tmp.y | (bit1.h << 4)) - FullRange; + } + return BTLA_CODE::Success; +} + static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t unpackelt, int8_t* tmp, size_t tmpsize) { for (int j = 0; j < unpackelt; j += 2) { @@ -371,6 +461,145 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* srcptr, int8_t* dstptr, return BTLA_CODE::Success; } +template +static inline BTLA_CODE decompress_kblock_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + int constexpr FullRange = 1 << (6 - 1); + static_assert(NTILE % 4 == 0); + assert(((col * PackRow) % 4) == 0); + if (zpptr) { + if constexpr (PackRow == 4) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 1) { + auto zp = zptr[j] + FullRange; + auto bit2 = bit2ptr[(i * col + j * PackRow) / 4]; + auto tmp = 
bit4ptr[(i * col + j * PackRow) / 2]; + dstptr[i * col + j * PackRow + 0] = (tmp.x | (bit2.a << 4)) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.y | (bit2.b << 4)) - zp; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.x | (bit2.c << 4)) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.y | (bit2.d << 4)) - zp; + } + } + } else if constexpr (PackRow == 1) { + for (int i = 0; i < row; i += 1) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 4) { + auto bit2 = bit2ptr[(i * col + j * PackRow) / 4]; + auto tmp = bit4ptr[(i * col + j * PackRow) / 2]; + dstptr[i * col + j * PackRow + 0] = (tmp.x | (bit2.a << 4)) - FullRange - zptr[j + 0]; + dstptr[i * col + j * PackRow + 1] = (tmp.y | (bit2.b << 4)) - FullRange - zptr[j + 1]; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.x | (bit2.c << 4)) - FullRange - zptr[j + 2]; + dstptr[i * col + j * PackRow + 3] = (tmp.y | (bit2.d << 4)) - FullRange - zptr[j + 3]; + } + } + } else if constexpr (PackRow == 2) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 2) { + auto bit2 = bit2ptr[(i * col + j * PackRow) / 4]; + auto tmp = bit4ptr[(i * col + j * PackRow) / 2]; + auto zp = zptr[j] + FullRange; + dstptr[i * col + j * PackRow + 0] = (tmp.x | (bit2.a << 4)) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.y | (bit2.b << 4)) - zp; + zp = zptr[j + 1] + FullRange; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.x | (bit2.c << 4)) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.y | (bit2.d << 4)) - zp; + } + } + } else { + static_assert(PackRow == 1 || PackRow == 2 || PackRow == 4); + } + } else { + return decompress_s6_s8(bit4ptr, bit2ptr, dstptr, size_t(row) * col, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + int constexpr FullRange = 1 << (5 - 1); + static_assert(NTILE % 8 == 0); + assert(((col * PackRow) % 8) == 0); + if (zpptr) { + if constexpr (PackRow == 4) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 2) { + auto zp = zptr[j] + FullRange; + auto bit1 = bit1ptr[(i * col + j * PackRow) / 8]; + auto tmp = bit4ptr[(i * col + j * PackRow) / 2]; + dstptr[i * col + j * PackRow + 0] = (tmp.x | (bit1.a << 4)) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.y | (bit1.b << 4)) - zp; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.x | (bit1.c << 4)) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.y | (bit1.d << 4)) - zp; + zp = zptr[j + 1] + FullRange; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 2]; + dstptr[i * col + j * PackRow + 4] = (tmp.x | (bit1.e << 4)) - zp; + dstptr[i * col + j * PackRow + 5] = (tmp.y | (bit1.f << 4)) - zp; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 3]; + dstptr[i * col + j * PackRow + 6] = (tmp.x | (bit1.g << 4)) - zp; + dstptr[i * col + j * PackRow + 7] = (tmp.y | (bit1.h << 4)) - zp; + } + } + } else if constexpr (PackRow == 1) { + for (int i = 0; i < row; i += 1) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < 
col; j += 8) { + auto bit1 = bit1ptr[(i * col + j * PackRow) / 8]; + auto tmp = bit4ptr[(i * col + j * PackRow) / 2]; + dstptr[i * col + j * PackRow + 0] = (tmp.x | (bit1.a << 4)) - FullRange - zptr[j + 0]; + dstptr[i * col + j * PackRow + 1] = (tmp.y | (bit1.b << 4)) - FullRange - zptr[j + 1]; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.x | (bit1.c << 4)) - FullRange - zptr[j + 2]; + dstptr[i * col + j * PackRow + 3] = (tmp.y | (bit1.d << 4)) - FullRange - zptr[j + 3]; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 2]; + dstptr[i * col + j * PackRow + 4] = (tmp.x | (bit1.e << 4)) - FullRange - zptr[j + 4]; + dstptr[i * col + j * PackRow + 5] = (tmp.y | (bit1.f << 4)) - FullRange - zptr[j + 5]; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 3]; + dstptr[i * col + j * PackRow + 6] = (tmp.x | (bit1.g << 4)) - FullRange - zptr[j + 6]; + dstptr[i * col + j * PackRow + 7] = (tmp.y | (bit1.h << 4)) - FullRange - zptr[j + 7]; + } + } + } else if constexpr (PackRow == 2) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 4) { + auto bit1 = bit1ptr[(i * col + j * PackRow) / 8]; + auto tmp = bit4ptr[(i * col + j * PackRow) / 2]; + auto zp = zptr[j] + FullRange; + dstptr[i * col + j * PackRow + 0] = (tmp.x | (bit1.a << 4)) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.y | (bit1.b << 4)) - zp; + zp = zptr[j + 1] + FullRange; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.x | (bit1.c << 4)) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.y | (bit1.d << 4)) - zp; + zp = zptr[j + 2] + FullRange; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 2]; + dstptr[i * col + j * PackRow + 4] = (tmp.x | (bit1.e << 4)) - zp; + dstptr[i * col + j * PackRow + 5] = (tmp.y | (bit1.f << 4)) - zp; + zp = zptr[j + 3] + FullRange; + tmp = bit4ptr[(i * col + j * PackRow) / 2 + 3]; + dstptr[i * col + j * PackRow + 6] = (tmp.x | (bit1.g << 4)) - zp; + dstptr[i * col + j * PackRow + 7] = (tmp.y | (bit1.h << 4)) - zp; + } + } + } else { + static_assert(PackRow == 1 || PackRow == 2 || PackRow == 4); + } + } else { + return decompress_s5_s8(bit4ptr, bit1ptr, dstptr, size_t(row) * col, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, @@ -605,6 +834,42 @@ inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, return BTLA_CODE::Success; } +template +static inline BTLA_CODE decompress_kblock_s6_fp(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, DST_T* dstptr, int row, + int col, void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, + int k_offset, int n_offset, int blocksize, int ldzp, int8_t* tmp, + size_t tmpsize) { + assert(tmpsize >= PackRow * NTILE); + assert(NTILE == col); + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + decompress_kblock_s6_s8(b4ptr, b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, row, + col, tmp, tmpsize); + decompress_kblock_s8_fp(tmps8ptr, dstptr, row, col, scales_, sdtype, nullptr, k_offset, n_offset, + blocksize, ldzp, tmp, tmpsize); + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s5_fp(utils::bit4x2* b4ptr, utils::bit1x8* 
b1ptr, DST_T* dstptr, int row, + int col, void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, + int k_offset, int n_offset, int blocksize, int ldzp, int8_t* tmp, + size_t tmpsize) { + assert(tmpsize >= PackRow * NTILE); + assert(NTILE == col); + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + decompress_kblock_s5_s8(b4ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, row, + col, tmp, tmpsize); + decompress_kblock_s8_fp(tmps8ptr, dstptr, row, col, scales_, sdtype, nullptr, k_offset, n_offset, + blocksize, ldzp, tmp, tmpsize); + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, @@ -1178,6 +1443,8 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8 case BTLA_DTYPE::S2_CLIP: case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S5_CLIP: + case BTLA_DTYPE::S6_CLIP: if (zero_points == nullptr) { sNauto_calc_store_scale_and_quantv_sym(blocksize); } else { @@ -2380,6 +2647,250 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut } return BTLA_CODE::Success; } + +template +static inline BTLA_CODE gemv_6bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr KTILE = 1; + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + int8_t UnpackBuf[NTILE * Unroll]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += Unroll) { + decompress_kblock_s6_s8<1, NTILE>(b4ptr, b2ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + Unroll, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < Unroll; ikt++) { + auto bval = (UnpackBuf[in + ikt * NTILE]) * bsptr[in]; + auto aval = A[ikt + im * lda]; + accf[im * NTILE + in] += aval * bval; + } + } + } + b4ptr += Unroll * NTILE / 2; + b2ptr += Unroll * NTILE / 4; + A += Unroll; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_6bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = A.aptr; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s6_s8<4, NTILE>(b4ptr, b2ptr, B.zpptr ? 
bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + KTILE, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + auto azp = A.zpptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda] - azp) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b4ptr += KTILE * NTILE / 2; + b2ptr += KTILE * NTILE / 4; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_6bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = (int8_t*)A.aptr; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s6_s8<4, NTILE>(b4ptr, b2ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + KTILE, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda]) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b4ptr += KTILE * NTILE / 2; + b2ptr += KTILE * NTILE / 4; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_5bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int constexpr KTILE = 1; + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + int8_t UnpackBuf[NTILE * Unroll]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += Unroll) { + decompress_kblock_s5_s8<1, NTILE>(b4ptr, b1ptr, B.zpptr ? 
bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + Unroll, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < Unroll; ikt++) { + auto bval = (UnpackBuf[in + ikt * NTILE]) * bsptr[in]; + auto aval = A[ikt + im * lda]; + accf[im * NTILE + in] += aval * bval; + } + } + } + b4ptr += Unroll * NTILE / 2; + b1ptr += Unroll * NTILE / 8; + A += Unroll; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_5bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = A.aptr; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s5_s8<4, NTILE>(b4ptr, b1ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + KTILE, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + auto azp = A.zpptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda] - azp) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b4ptr += KTILE * NTILE / 2; + b1ptr += KTILE * NTILE / 8; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_5bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = (int8_t*)A.aptr; + auto b4ptr = reinterpret_cast(B.b4ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s5_s8<4, NTILE>(b4ptr, b1ptr, B.zpptr ? 
bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + KTILE, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda]) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b4ptr += KTILE * NTILE / 2; + b1ptr += KTILE * NTILE / 8; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} } // namespace ref } // namespace kernel } // namespace bestla diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index beb9ba733..11f191914 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -260,18 +260,34 @@ class Dq8GetScale { class CompressS8S4 { public: template - static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::int4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { - return ref::compress_s8_s4(srcptr, dstptr, row, col, ld_src, ld_dst); + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::int4x2* dstptr, size_t size) { + return ref::compress_s8_s4(srcptr, dstptr, size); } }; class CompressFp4 { public: template - static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::f4x2* dstptr, int row, int col, int ld_src, - int ld_dst) { - return ref::compress_f4(srcptr, dstptr, row, col, ld_src, ld_dst); + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::f4x2* dstptr, size_t size) { + return ref::compress_f4(srcptr, dstptr, size); + } +}; + +class CompressBit6 { + public: + template + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, + size_t size) { + return ref::compress_6bit(srcptr, bit4ptr, bit2ptr, size); + } +}; + +class CompressBit5 { + public: + template + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, + size_t size) { + return ref::compress_5bit(srcptr, bit4ptr, bit1ptr, size); } }; @@ -433,6 +449,54 @@ class DecompressKBlockS4S8 { } }; +template +class DecompressKBlockS6S8 { + public: + template + static inline BTLA_CODE forward(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, void* tmp, + size_t tmpsize) { +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::decompress_kblock_s6_s8(b4ptr, b2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::decompress_kblock_s6_s8(b4ptr, b2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +#endif + return ref::decompress_kblock_s6_s8(b4ptr, b2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +}; + +template +class DecompressKBlockS5S8 { + public: + template + static inline BTLA_CODE forward(utils::bit4x2* b4ptr, utils::bit1x8* b1ptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, void* tmp, + size_t tmpsize) { +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::decompress_kblock_s5_s8(b4ptr, b1ptr, zpptr, dstptr, 
blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::decompress_kblock_s5_s8(b4ptr, b1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +#endif + return ref::decompress_kblock_s5_s8(b4ptr, b1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +}; + template class DecompressKBlockS3S8 { public: @@ -440,6 +504,12 @@ class DecompressKBlockS3S8 { static inline BTLA_CODE forward(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, void* tmp, size_t tmpsize) { +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::decompress_kblock_s3_s8(b2ptr, b1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +#endif #if CompileAVX2() if constexpr (utils::isa_base::avx2) { return avx2::decompress_kblock_s3_s8(b2ptr, b1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, @@ -505,6 +575,64 @@ class DecompressKBlockS8Fp { } }; +template +class DecompressKBlockS6Fp { + public: + template + static inline BTLA_CODE forward(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr, DstT* dstptr, int row, int col, + void* scales, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int kblock, int NPad, void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::decompress_kblock_s6_fp(b4ptr, b2ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::decompress_kblock_s6_fp(b4ptr, b2ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } +#endif + ret = ref::decompress_kblock_s6_fp(b4ptr, b2ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + return ret; + } +}; + +template +class DecompressKBlockS5Fp { + public: + template + static inline BTLA_CODE forward(utils::bit4x2* b4ptr, utils::bit1x8* b1ptr, DstT* dstptr, int row, int col, + void* scales, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int kblock, int NPad, void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::decompress_kblock_s5_fp(b4ptr, b1ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::decompress_kblock_s5_fp(b4ptr, b1ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } +#endif + ret = ref::decompress_kblock_s5_fp(b4ptr, b1ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + return ret; + } +}; + template class DecompressKBlockS4Fp { public: @@ -928,6 +1056,44 @@ class GEMVWoqNBits { template static inline BTLA_CODE forward_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, void* tmp, size_t tmpsize) { + if (B.nbits == 6) { +#if 
CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_6bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_6bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_6bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_6bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 5) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_5bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_5bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_5bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_5bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } if (B.nbits == 4) { #if CompileAVX512VNNI() if (ISA_T >= BTLA_ISA::AVX512_VNNI) { @@ -991,6 +1157,34 @@ class GEMVWoqNBits { template static inline BTLA_CODE forward_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, void* tmp, size_t tmpsize) { + if (B.nbits == 6) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_6bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_6bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_6bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 5) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_5bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_5bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_5bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } if (B.nbits == 4) { #if CompileAVX512VNNI() if (ISA_T >= BTLA_ISA::AVX512_VNNI) { @@ -1039,6 +1233,34 @@ class GEMVWoqNBits { template static inline BTLA_CODE forward_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, int ldc, int k, int blocksize, void* tmp, size_t tmpsize) { + if (B.nbits == 6) { +#if CompileAVX512F() + if (ISA_T >= BTLA_ISA::AVX512F) { + return avx512f::gemv_6bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_6bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_6bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 5) { +#if CompileAVX512F() + if (ISA_T >= BTLA_ISA::AVX512F) { + return avx512f::gemv_5bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_5bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_5bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, 
tmpsize); + } if (B.nbits == 4) { #if CompileAVX512F() if (ISA_T >= BTLA_ISA::AVX512F) { diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp index 5d2a7fa04..56682ab67 100644 --- a/bestla/bestla/ut/bestla_benchmark.cpp +++ b/bestla/bestla/ut/bestla_benchmark.cpp @@ -1,7 +1,8 @@ #include #include "bestla_wrapper.h" #include "bestla_ut.h" - +#undef BTLA_UT_WRAPPER +#undef BTLA_UT_PROLOGUE_B namespace bestla { using namespace utils; namespace ut { @@ -440,9 +441,11 @@ class UTWOQ_CompFp32 { public: UTWOQ_CompFp32() { UT_START(); + ut_s6(); + /*ut_s5(); ut_s2(); ut_s4(); - ut_s3(); + ut_s3();*/ // ut_s8(); // ut_f4(); } @@ -458,7 +461,14 @@ class UTWOQ_CompFp32 { benchmark_all(1, 4096, 4096, BTLA_DTYPE::S4_CLIP); benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP); } - + void ut_s5() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S5_CLIP); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S5_CLIP); + } + void ut_s6() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S6_CLIP); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S6_CLIP); + } void ut_s8() { benchmark_all(1, 4096, 4096, BTLA_DTYPE::S8); benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S8); @@ -556,8 +566,8 @@ class UTWOQ_CompFp32 { } }; #ifdef BTLA_UT_PROLOGUE_B -static UTWOQ_CompFp32 sUTWOQ_CompFp32; #endif +static UTWOQ_CompFp32 sUTWOQ_CompFp32; class UTWOQ_CompBf16 { public: @@ -676,9 +686,11 @@ class UTWOQ_CompInt8 { public: UTWOQ_CompInt8() { UT_START(); + ut_s6(); + /*ut_s5(); ut_s2(); ut_s4(); - ut_s3(); + ut_s3();*/ // ut_s8(); } @@ -708,6 +720,24 @@ class UTWOQ_CompInt8 { // benchmark_all(2048, 4096, 4096, BTLA_DTYPE::S4_CLIP); } + void ut_s5() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S5_CLIP); + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S5_CLIP, true); + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S5_CLIP); + // benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP, true); + // benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2048, 4096, 4096, BTLA_DTYPE::S4_CLIP); + } + + void ut_s6() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S6_CLIP); + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S6_CLIP, true); + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S6_CLIP); + // benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP, true); + // benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP); + // benchmark_all(2048, 4096, 4096, BTLA_DTYPE::S4_CLIP); + } + void ut_s8() { benchmark_all(1, 4096, 4096, BTLA_DTYPE::S8); benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S8); @@ -810,8 +840,8 @@ class UTWOQ_CompInt8 { } }; #ifdef BTLA_UT_PROLOGUE_B -static UTWOQ_CompInt8 sUTWOQ_CompInt8; #endif +static UTWOQ_CompInt8 sUTWOQ_CompInt8; #if 0 typedef struct { diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 7d6544249..5ff317e47 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -166,11 +166,17 @@ class UT_BlockQunatize_F8 { static UT_BlockQunatize_F8 sUT_BlockQunatize_F8; #endif -class UT_BlockQunatize_S3S4 { +class UT_BlockQunatize_SN { public: - UT_BlockQunatize_S3S4() { + UT_BlockQunatize_SN() { UT_START(); CheckISA(AVX2); + ut(4096, 4096, 32, BTLA_DTYPE::S6_CLIP, true); + ut(4096, 4096, 32, BTLA_DTYPE::S6_CLIP); + ut(4096, 4096, 128, BTLA_DTYPE::S6_CLIP); + ut(4096, 4096, 32, BTLA_DTYPE::S5_CLIP, true); + ut(4096, 4096, 32, BTLA_DTYPE::S5_CLIP); + ut(4096, 4096, 128, BTLA_DTYPE::S5_CLIP); ut(4096, 4096, 32, BTLA_DTYPE::S4_CLIP, true); ut(4096, 4096, 32, BTLA_DTYPE::S4_CLIP); ut(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); 
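
The new S5_CLIP / S6_CLIP cases exercised above rely on a two-plane weight layout: the prologue and wrapper signatures in this patch split each quantized value into a low 4-bit plane (`bit4x2`) plus a high plane holding the remaining bit(s) (`bit1x8` for 5-bit, `bit2x4` for 6-bit), and the decompress unit tests later in the patch build exactly that encoding by hand. The following is a minimal scalar sketch of the 5-bit split using plain byte arrays; it is illustrative only (the names, the exact nibble/bit ordering, and the buffer types are simplifications, not BesTLA APIs or the production bit order).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: pack signed 5-bit values (range [-16, 15]) into a 4-bit
// plane plus a 1-bit plane, mirroring the reference buffers the unit tests in
// this patch construct. BesTLA's real kernels use bit4x2/bit1x8 structs and
// SIMD; this scalar code just shows the bias-and-split idea.
static void pack_s5(const std::vector<int8_t>& src, std::vector<uint8_t>& low4, std::vector<uint8_t>& high1) {
  assert(src.size() % 8 == 0);
  low4.assign(src.size() / 2, 0);   // two 4-bit fields per byte
  high1.assign(src.size() / 8, 0);  // eight 1-bit fields per byte
  for (size_t i = 0; i < src.size(); ++i) {
    uint8_t biased = static_cast<uint8_t>(src[i] + 16);  // shift [-16,15] to unsigned [0,31]
    low4[i / 2] |= (biased & 0x0f) << ((i % 2) * 4);     // low-nibble plane
    high1[i / 8] |= ((biased >> 4) & 0x01) << (i % 8);   // 5th-bit plane
  }
}

// dst must already be sized to the element count.
static void unpack_s5(const std::vector<uint8_t>& low4, const std::vector<uint8_t>& high1, std::vector<int8_t>& dst) {
  for (size_t i = 0; i < dst.size(); ++i) {
    uint8_t lo = (low4[i / 2] >> ((i % 2) * 4)) & 0x0f;
    uint8_t hi = (high1[i / 8] >> (i % 8)) & 0x01;
    dst[i] = static_cast<int8_t>((hi << 4) | lo) - 16;   // undo the +16 bias
  }
  // The 6-bit layout is analogous: a low-nibble plane plus a 2-bit plane
  // holding bits 4-5, with a +32 bias instead of +16.
}
```

Round-tripping a buffer through `pack_s5`/`unpack_s5` reproduces the input exactly, which is essentially the property the new `decompress_kblock_s5_s8` / `decompress_kblock_s6_s8` unit tests below check against the AVX2 and AVX-512 kernels.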
@@ -213,7 +219,7 @@ class UT_BlockQunatize_S3S4 { }; #ifdef BTLA_UT_PROLOGUE_B // no proper threshold for this UT -// static UT_BlockQunatize_S3S4 sUT_BlockQunatize_S3S4; +// static UT_BlockQunatize_SN sUT_BlockQunatize_SN; #endif class UT_S3_WOQ { @@ -637,6 +643,8 @@ class UT_CompFp32 { public: UT_CompFp32() { UT_START(); + ut_s6(); + ut_s5(); ut_s4(); ut_s2(); ut_s3(); @@ -645,6 +653,7 @@ class UT_CompFp32 { ut_f4(); ut_f8(); } + void ut_s2() { GetCPUDevice(); if (_cd->AVX2()) { @@ -691,6 +700,20 @@ class UT_CompFp32 { false); ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, false); + + CheckISA(AVX512F); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + true); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); } void ut_f8() { @@ -705,6 +728,7 @@ class UT_CompFp32 { ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F8_E8M0); ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F32); } + void ut_s4() { CheckISA(AVX2); ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, @@ -731,6 +755,72 @@ class UT_CompFp32 { false); } + void ut_s5() { + CheckISA(AVX2); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + true); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + CheckISA(AVX512F); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + true); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + } + + void ut_s6() { + CheckISA(AVX2); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + true); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + CheckISA(AVX512F); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + true); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(8, 4096, 4096, 32, 
BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + } + void ut_s8() { CheckISA(AVX2); ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S8, BTLA_DTYPE::BF16, false); @@ -854,6 +944,8 @@ class UT_CompInt8 { public: UT_CompInt8() { UT_START(); + ut_s6(); + ut_s5(); ut_s4(); ut_s2(); ut_s3(); @@ -861,6 +953,15 @@ class UT_CompInt8 { void ut_s2() { GetCPUDevice(); + if (_cd->AVX2()) { + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(1, 4096, 4096, 16, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + } if (_cd->AVX_VNNI()) { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); ut_newkblock>(1, 4096, 4096, 16, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::BF16); @@ -882,10 +983,20 @@ class UT_CompInt8 { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); } + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + } } void ut_s3() { GetCPUDevice(); + if (_cd->AVX2()) { + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + } if (_cd->AVX_VNNI()) { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true); @@ -900,6 +1011,10 @@ class UT_CompInt8 { true); ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); } + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + } } void ut_s4() { @@ -938,6 +1053,67 @@ class UT_CompInt8 { } } + void ut_s5() { + GetCPUDevice(); + if (_cd->AVX2()) { + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + } + if (_cd->AVX_VNNI()) { + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + } + if (_cd->AVX512_VNNI()) { + ut_newkblock>(1, 11008, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 
4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + } + + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + } + } + + void ut_s6() { + GetCPUDevice(); + if (_cd->AVX2()) { + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + } + if (_cd->AVX_VNNI()) { + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + } + if (_cd->AVX512_VNNI()) { + ut_newkblock>(1, 11008, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + } + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + } + } + template void ut_newkblock(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype, bool isAsym = false) { printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s Asym:%d\n", __FUNCTION__, m, n, k, blocksize, diff --git a/bestla/bestla/ut/kernel_intrin.cpp b/bestla/bestla/ut/kernel_intrin.cpp index 47135c893..173293232 100644 --- a/bestla/bestla/ut/kernel_intrin.cpp +++ b/bestla/bestla/ut/kernel_intrin.cpp @@ -136,6 +136,145 @@ class UT_avx512_decompress_s3_s8 { static UT_avx512_decompress_s3_s8 sUT_avx512_decompress_s3_s8; #endif +class UT_avx512_decompress_s5_s8 { + public: + UT_avx512_decompress_s5_s8() { + UT_START(); + CheckISA(AVX512F); + ut<1, 48>(32); + ut<4, 48>(32); + ut<1, 48>(32, true); + ut<2, 48>(32, true); + ut<4, 48>(32, true); + } + + template + void ut(int blocksize, bool isasym = false) { + int row = blocksize * 2; + int constexpr col = NTILE; + printf("Test Case %s: %d %d %d\n", __FUNCTION__, row, col, blocksize); + std::vector s4_wei(row * col / 2); + avector s1_wei(row * col / 8); + + std::vector s8_wei(col * row); + std::vector s8_ref(col * row); + int blks = row / blocksize; + int row_offset = 8; + assert(blocksize % 8 == 0); + std::vector zp(col * blks); + fill_buffer_randn(zp.data(), zp.size(), int8_t(-16), int8_t(15)); + std::vector rev(col * row); + fill_buffer_randn(s8_wei.data(), s8_wei.size(), int8_t(-16), int8_t(15)); + + for (int i = 0; i < col * row; i += 8) { + memcpy(&s8_ref[i], &s8_wei[i], 8 * sizeof(int8_t)); + s4_wei[i / 2].x = (s8_wei[i + 0] + 16) & 0xf; + s4_wei[i / 2].y = (s8_wei[i + 1] + 16) & 0xf; + s4_wei[i / 2 + 1].x = (s8_wei[i + 2] + 16) & 0xf; + s4_wei[i / 2 + 1].y = (s8_wei[i + 3] + 16) & 0xf; + s4_wei[i / 2 + 2].x = (s8_wei[i + 4] + 16) & 0xf; + s4_wei[i / 2 + 2].y = (s8_wei[i + 5] + 16) & 0xf; + 
s4_wei[i / 2 + 3].x = (s8_wei[i + 6] + 16) & 0xf; + s4_wei[i / 2 + 3].y = (s8_wei[i + 7] + 16) & 0xf; + + s1_wei[i / 8].a = ((s8_wei[i + 0] + 16) & 0x10) >> 4; + s1_wei[i / 8].b = ((s8_wei[i + 1] + 16) & 0x10) >> 4; + s1_wei[i / 8].c = ((s8_wei[i + 2] + 16) & 0x10) >> 4; + s1_wei[i / 8].d = ((s8_wei[i + 3] + 16) & 0x10) >> 4; + s1_wei[i / 8].e = ((s8_wei[i + 4] + 16) & 0x10) >> 4; + s1_wei[i / 8].f = ((s8_wei[i + 5] + 16) & 0x10) >> 4; + s1_wei[i / 8].g = ((s8_wei[i + 6] + 16) & 0x10) >> 4; + s1_wei[i / 8].h = ((s8_wei[i + 7] + 16) & 0x10) >> 4; + } + if (isasym) { + for (int i = 0; i < row; i += PackRow) { + for (int j = 0; j < NTILE; j++) { + for (int ip = 0; ip < PackRow; ip++) { + s8_ref[i * NTILE + j * PackRow + ip] -= zp[i / blocksize * NTILE + j]; + } + } + } + } + + kernel::avx512f::decompress_kblock_s5_s8(s4_wei.data(), s1_wei.data(), isasym ? zp.data() : nullptr, + rev.data(), blocksize, NTILE, 0, 0, row_offset, NTILE, + cache, CacheSize); + kernel::avx512f::decompress_kblock_s5_s8( + s4_wei.data() + row_offset * NTILE / 2, s1_wei.data() + row_offset * NTILE / 8, isasym ? zp.data() : nullptr, + rev.data() + row_offset * NTILE, blocksize, NTILE, 0, row_offset, row - row_offset, NTILE, cache, CacheSize); + ut::buffer_error(s8_ref.data(), rev.data(), rev.size(), int8_t(0)); + } +}; +#ifdef BTLA_UT_KERNEL_INTRIN +static UT_avx512_decompress_s5_s8 sUT_avx512_decompress_s5_s8; +#endif + +class UT_avx512_decompress_s6_s8 { + public: + UT_avx512_decompress_s6_s8() { + UT_START(); + CheckISA(AVX512F); + ut<1, 48>(32); + ut<4, 48>(32); + ut<1, 48>(32, true); + ut<2, 48>(32, true); + ut<4, 48>(32, true); + } + + template + void ut(int blocksize, bool isasym = false) { + int row = blocksize * 2; + int constexpr FullRange = 1 << (6 - 1); + int constexpr col = NTILE; + printf("Test Case %s: %d %d %d\n", __FUNCTION__, row, col, blocksize); + std::vector s4_wei(row * col / 2); + avector s2_wei(row * col / 4); + + std::vector s8_wei(col * row); + std::vector s8_ref(col * row); + int blks = row / blocksize; + int row_offset = 8; + assert(blocksize % 8 == 0); + std::vector zp(col * blks); + fill_buffer_randn(zp.data(), zp.size(), int8_t(-FullRange), int8_t(FullRange - 1)); + std::vector rev(col * row); + fill_buffer_randn(s8_wei.data(), s8_wei.size(), int8_t(-FullRange), int8_t(FullRange - 1)); + + for (int i = 0; i < col * row; i += 4) { + memcpy(&s8_ref[i], &s8_wei[i], 4 * sizeof(int8_t)); + s4_wei[i / 2].x = (s8_wei[i + 0] + FullRange) & 0xf; + s4_wei[i / 2].y = (s8_wei[i + 1] + FullRange) & 0xf; + s4_wei[i / 2 + 1].x = (s8_wei[i + 2] + FullRange) & 0xf; + s4_wei[i / 2 + 1].y = (s8_wei[i + 3] + FullRange) & 0xf; + + s2_wei[i / 4].a = ((s8_wei[i + 0] + FullRange) & 0x30) >> 4; + s2_wei[i / 4].b = ((s8_wei[i + 1] + FullRange) & 0x30) >> 4; + s2_wei[i / 4].c = ((s8_wei[i + 2] + FullRange) & 0x30) >> 4; + s2_wei[i / 4].d = ((s8_wei[i + 3] + FullRange) & 0x30) >> 4; + } + if (isasym) { + for (int i = 0; i < row; i += PackRow) { + for (int j = 0; j < NTILE; j++) { + for (int ip = 0; ip < PackRow; ip++) { + s8_ref[i * NTILE + j * PackRow + ip] -= zp[i / blocksize * NTILE + j]; + } + } + } + } + + kernel::avx512f::decompress_kblock_s6_s8(s4_wei.data(), s2_wei.data(), isasym ? zp.data() : nullptr, + rev.data(), blocksize, NTILE, 0, 0, row_offset, NTILE, + cache, CacheSize); + kernel::avx512f::decompress_kblock_s6_s8( + s4_wei.data() + row_offset * NTILE / 2, s2_wei.data() + row_offset * NTILE / 4, isasym ? 
zp.data() : nullptr, + rev.data() + row_offset * NTILE, blocksize, NTILE, 0, row_offset, row - row_offset, NTILE, cache, CacheSize); + ut::buffer_error(s8_ref.data(), rev.data(), rev.size(), int8_t(0)); + } +}; +#ifdef BTLA_UT_KERNEL_INTRIN +static UT_avx512_decompress_s6_s8 sUT_avx512_decompress_s6_s8; +#endif + class UT_avx512_decompress_s2_s8 { public: UT_avx512_decompress_s2_s8() { @@ -323,6 +462,36 @@ class UT_avx512_gemv { ut_3bit_fp32<1>(48, 128, 32, false); ut_3bit_fp32<4>(48, 128, 32, true); ut_3bit_fp32<4>(48, 128, 32, false); + + ut_6bit_fp32<1>(48, 128, 32, true); + ut_6bit_fp32<1>(48, 128, 32, false); + ut_6bit_fp32<4>(48, 128, 32, true); + ut_6bit_fp32<4>(48, 128, 32, false); + + ut_6bit_u8s8<1>(48, 128, 32, true); + ut_6bit_u8s8<1>(48, 128, 32, false); + ut_6bit_u8s8<4>(48, 128, 32, true); + ut_6bit_u8s8<4>(48, 128, 32, false); + + ut_6bit_s8s8<1>(48, 128, 32, true); + ut_6bit_s8s8<1>(48, 128, 32, false); + ut_6bit_s8s8<4>(48, 128, 32, true); + ut_6bit_s8s8<4>(48, 128, 32, false); + + ut_5bit_fp32<1>(48, 128, 32, true); + ut_5bit_fp32<1>(48, 128, 32, false); + ut_5bit_fp32<4>(48, 128, 32, true); + ut_5bit_fp32<4>(48, 128, 32, false); + + ut_5bit_u8s8<1>(48, 128, 32, true); + ut_5bit_u8s8<1>(48, 128, 32, false); + ut_5bit_u8s8<4>(48, 128, 32, true); + ut_5bit_u8s8<4>(48, 128, 32, false); + + ut_5bit_s8s8<1>(48, 128, 32, true); + ut_5bit_s8s8<1>(48, 128, 32, false); + ut_5bit_s8s8<4>(48, 128, 32, true); + ut_5bit_s8s8<4>(48, 128, 32, false); } template @@ -727,13 +896,438 @@ class UT_avx512_gemv { B, Cf32.data(), n, k, kblock, cache, CacheSize); buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); } + + template + void ut_5bit_fp32(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b1(n * k / 8); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b1.data(), b1.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(Af32.data(), Af32.size(), -0.5f, 0.5f); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-16), int8_t(15)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector b8(n * k); + kernel::ref::decompress_s5_s8(b4.data(), b1.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 1) { + int bid = i / kblock; + for (int j = 0; j < n; j += 4) { + if (iasym) { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0] - bzp[bid * n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1] - bzp[bid * n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2] - bzp[bid * n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3] - bzp[bid * n + j + 3]) * scaleb[bid * n + j + 3]; + } else { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3]) * scaleb[bid * n + j + 3]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), nullptr, (uint8_t*)b1.data(), scaleb.data(), iasym ? 
bzp.data() : nullptr, 5, n}; + kernel::avx512f::gemv_5bit_fp32_fp32(Af32.data(), k, B, Cf32.data(), n, k, kblock, cache, + CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_5bit_u8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b1(n * k / 8); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b1.data(), b1.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-16), int8_t(15)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k), azp(MTILE * blks); + fill_buffer_randn(A.data(), A.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(azp.data(), azp.size(), uint8_t(100), uint8_t(150)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j]) - azp[bid]) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s5_s8(b4.data(), b1.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), nullptr, (uint8_t*)b1.data(), scaleb.data(), iasym ? 
bzp.data() : nullptr, 2, n}; + kernel::avx512f::vnni::gemv_5bit_u8s8_fp32({A.data(), scalea.data(), azp.data(), k, blks}, B, + Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_5bit_s8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b1(n * k / 8); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b1.data(), b1.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-16), int8_t(15)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k); + fill_buffer_randn(A.data(), A.size(), int8_t(0), int8_t(127)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j])) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s5_s8(b4.data(), b1.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), nullptr, (uint8_t*)b1.data(), scaleb.data(), iasym ? 
bzp.data() : nullptr, 5, n}; + kernel::avx512f::vnni::gemv_5bit_s8s8_fp32({(uint8_t*)A.data(), scalea.data(), nullptr, k, blks}, + B, Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_6bit_fp32(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b2(n * k / 4); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b2.data(), b2.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(Af32.data(), Af32.size(), -0.5f, 0.5f); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-32), int8_t(31)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector b8(n * k); + kernel::ref::decompress_s6_s8(b4.data(), b2.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 1) { + int bid = i / kblock; + for (int j = 0; j < n; j += 4) { + if (iasym) { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0] - bzp[bid * n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1] - bzp[bid * n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2] - bzp[bid * n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3] - bzp[bid * n + j + 3]) * scaleb[bid * n + j + 3]; + } else { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3]) * scaleb[bid * n + j + 3]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), (uint8_t*)b2.data(), nullptr, scaleb.data(), iasym ? 
bzp.data() : nullptr, 6, n}; + kernel::avx512f::gemv_6bit_fp32_fp32(Af32.data(), k, B, Cf32.data(), n, k, kblock, cache, + CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_6bit_u8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b2(n * k / 4); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b2.data(), b2.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-32), int8_t(31)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k), azp(MTILE * blks); + fill_buffer_randn(A.data(), A.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(azp.data(), azp.size(), uint8_t(100), uint8_t(150)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j]) - azp[bid]) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s6_s8(b4.data(), b2.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), (uint8_t*)b2.data(), nullptr, scaleb.data(), iasym ? 
bzp.data() : nullptr, 6, n}; + kernel::avx512f::vnni::gemv_6bit_u8s8_fp32({A.data(), scalea.data(), azp.data(), k, blks}, B, + Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_6bit_s8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b2(n * k / 4); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b2.data(), b2.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-32), int8_t(31)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k); + fill_buffer_randn(A.data(), A.size(), int8_t(0), int8_t(127)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j])) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s6_s8(b4.data(), b2.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), (uint8_t*)b2.data(), nullptr, scaleb.data(), iasym ? 
bzp.data() : nullptr, 5, n}; + kernel::avx512f::vnni::gemv_6bit_s8s8_fp32({(uint8_t*)A.data(), scalea.data(), nullptr, k, blks}, + B, Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } +}; +#ifdef BTLA_UT_KERNEL_INTRIN +UT_avx512_gemv sUT_avx512_gemv; +#endif +#endif + +#if CompileAVX2() +class UT_avx2_decompress_s6_s8 { + public: + UT_avx2_decompress_s6_s8() { + UT_START(); + CheckISA(AVX2); + ut<1, 24>(32); + ut<4, 24>(32); + ut<1, 24>(32, true); + ut<2, 24>(32, true); + ut<4, 24>(32, true); + } + + template + void ut(int blocksize, bool isasym = false) { + int row = blocksize * 2; + int constexpr FullRange = 1 << (6 - 1); + int constexpr col = NTILE; + printf("Test Case %s: %d %d %d\n", __FUNCTION__, row, col, blocksize); + std::vector s4_wei(row * col / 2); + avector s2_wei(row * col / 4); + + std::vector s8_wei(col * row); + std::vector s8_ref(col * row); + int blks = row / blocksize; + int row_offset = 8; + assert(blocksize % 8 == 0); + std::vector zp(col * blks); + fill_buffer_randn(zp.data(), zp.size(), int8_t(-FullRange), int8_t(FullRange - 1)); + std::vector rev(col * row); + fill_buffer_randn(s8_wei.data(), s8_wei.size(), int8_t(-FullRange), int8_t(FullRange - 1)); + + for (int i = 0; i < col * row; i += 4) { + memcpy(&s8_ref[i], &s8_wei[i], 4 * sizeof(int8_t)); + s4_wei[i / 2].x = (s8_wei[i + 0] + FullRange) & 0xf; + s4_wei[i / 2].y = (s8_wei[i + 1] + FullRange) & 0xf; + s4_wei[i / 2 + 1].x = (s8_wei[i + 2] + FullRange) & 0xf; + s4_wei[i / 2 + 1].y = (s8_wei[i + 3] + FullRange) & 0xf; + + s2_wei[i / 4].a = ((s8_wei[i + 0] + FullRange) & 0x30) >> 4; + s2_wei[i / 4].b = ((s8_wei[i + 1] + FullRange) & 0x30) >> 4; + s2_wei[i / 4].c = ((s8_wei[i + 2] + FullRange) & 0x30) >> 4; + s2_wei[i / 4].d = ((s8_wei[i + 3] + FullRange) & 0x30) >> 4; + } + if (isasym) { + for (int i = 0; i < row; i += PackRow) { + for (int j = 0; j < NTILE; j++) { + for (int ip = 0; ip < PackRow; ip++) { + s8_ref[i * NTILE + j * PackRow + ip] -= zp[i / blocksize * NTILE + j]; + } + } + } + } + + kernel::avx2::decompress_kblock_s6_s8(s4_wei.data(), s2_wei.data(), isasym ? zp.data() : nullptr, + rev.data(), blocksize, NTILE, 0, 0, row_offset, NTILE, cache, + CacheSize); + kernel::avx2::decompress_kblock_s6_s8( + s4_wei.data() + row_offset * NTILE / 2, s2_wei.data() + row_offset * NTILE / 4, isasym ? 
zp.data() : nullptr, + rev.data() + row_offset * NTILE, blocksize, NTILE, 0, row_offset, row - row_offset, NTILE, cache, CacheSize); + ut::buffer_error(s8_ref.data(), rev.data(), rev.size(), int8_t(0)); + } +}; +#ifdef BTLA_UT_KERNEL_INTRIN +static UT_avx2_decompress_s6_s8 sUT_avx2_decompress_s6_s8; +#endif + +class UT_avx2_decompress_s5_s8 { + public: + UT_avx2_decompress_s5_s8() { + UT_START(); + CheckISA(AVX2); + ut<1, 24>(32); + ut<4, 24>(32); + ut<1, 24>(32, true); + ut<2, 24>(32, true); + ut<4, 24>(32, true); + } + + template + void ut(int blocksize, bool isasym = false) { + int row = blocksize * 2; + int constexpr col = NTILE; + printf("Test Case %s: %d %d %d\n", __FUNCTION__, row, col, blocksize); + std::vector s4_wei(row * col / 2); + avector s1_wei(row * col / 8); + + std::vector s8_wei(col * row); + std::vector s8_ref(col * row); + int blks = row / blocksize; + int row_offset = 8; + assert(blocksize % 8 == 0); + std::vector zp(col * blks); + fill_buffer_randn(zp.data(), zp.size(), int8_t(-16), int8_t(15)); + std::vector rev(col * row); + fill_buffer_randn(s8_wei.data(), s8_wei.size(), int8_t(-16), int8_t(15)); + + for (int i = 0; i < col * row; i += 8) { + memcpy(&s8_ref[i], &s8_wei[i], 8 * sizeof(int8_t)); + s4_wei[i / 2].x = (s8_wei[i + 0] + 16) & 0xf; + s4_wei[i / 2].y = (s8_wei[i + 1] + 16) & 0xf; + s4_wei[i / 2 + 1].x = (s8_wei[i + 2] + 16) & 0xf; + s4_wei[i / 2 + 1].y = (s8_wei[i + 3] + 16) & 0xf; + s4_wei[i / 2 + 2].x = (s8_wei[i + 4] + 16) & 0xf; + s4_wei[i / 2 + 2].y = (s8_wei[i + 5] + 16) & 0xf; + s4_wei[i / 2 + 3].x = (s8_wei[i + 6] + 16) & 0xf; + s4_wei[i / 2 + 3].y = (s8_wei[i + 7] + 16) & 0xf; + + s1_wei[i / 8].a = ((s8_wei[i + 0] + 16) & 0x10) >> 4; + s1_wei[i / 8].b = ((s8_wei[i + 1] + 16) & 0x10) >> 4; + s1_wei[i / 8].c = ((s8_wei[i + 2] + 16) & 0x10) >> 4; + s1_wei[i / 8].d = ((s8_wei[i + 3] + 16) & 0x10) >> 4; + s1_wei[i / 8].e = ((s8_wei[i + 4] + 16) & 0x10) >> 4; + s1_wei[i / 8].f = ((s8_wei[i + 5] + 16) & 0x10) >> 4; + s1_wei[i / 8].g = ((s8_wei[i + 6] + 16) & 0x10) >> 4; + s1_wei[i / 8].h = ((s8_wei[i + 7] + 16) & 0x10) >> 4; + } + if (isasym) { + for (int i = 0; i < row; i += PackRow) { + for (int j = 0; j < NTILE; j++) { + for (int ip = 0; ip < PackRow; ip++) { + s8_ref[i * NTILE + j * PackRow + ip] -= zp[i / blocksize * NTILE + j]; + } + } + } + } + + kernel::avx2::decompress_kblock_s5_s8(s4_wei.data(), s1_wei.data(), isasym ? zp.data() : nullptr, + rev.data(), blocksize, NTILE, 0, 0, row_offset, NTILE, cache, + CacheSize); + kernel::avx2::decompress_kblock_s5_s8( + s4_wei.data() + row_offset * NTILE / 2, s1_wei.data() + row_offset * NTILE / 8, isasym ? zp.data() : nullptr, + rev.data() + row_offset * NTILE, blocksize, NTILE, 0, row_offset, row - row_offset, NTILE, cache, CacheSize); + ut::buffer_error(s8_ref.data(), rev.data(), rev.size(), int8_t(0)); + } }; #ifdef BTLA_UT_KERNEL_INTRIN -UT_avx512_gemv sUT_avx512_gemv; -#endif +static UT_avx2_decompress_s5_s8 sUT_avx2_decompress_s5_s8; #endif -#if CompileAVX2() class UT_avx2_decompress_s4_s8 { public: UT_avx2_decompress_s4_s8() { @@ -997,6 +1591,9 @@ class UT_avx2_decompress_s4_fp { #ifdef BTLA_UT_KERNEL_INTRIN static UT_avx2_decompress_s4_fp sUT_avx2_decompress_s4_fp; #endif + +// s_fp share the same process: s->s8->fp, so it's not necessary to test all s_fp cases. +// test s_s8 is just fine. 
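
The comment above captures the layering these kernels rely on: every sN_fp decompress first expands the packed weights to int8 and only then applies the per-k-block scale (and optional zero point) to produce floats, so covering the sN_s8 stage in the unit tests is the part that varies per bit width. A scalar illustration of that shared second stage follows; it is a simplified sketch (row-major layout, one scale/zero point per block and column), not the BesTLA implementation.

```cpp
#include <cstdint>

// Illustrative scalar version of the s8 -> fp stage shared by all sN_fp paths:
// the bit-width-specific unpacking has already produced int8 weights, so this
// step only subtracts the optional per-block zero point and multiplies by the
// per-block scale. Simplified layout; not BesTLA code.
void s8_to_fp32(const int8_t* s8, float* out, int rows, int cols, int blocksize,
                const float* scales, const int8_t* zero_points /* nullptr if symmetric */) {
  for (int r = 0; r < rows; ++r) {
    int block = r / blocksize;  // one scale/zero point per (block, column)
    for (int c = 0; c < cols; ++c) {
      int v = s8[r * cols + c];
      if (zero_points) v -= zero_points[block * cols + c];
      out[r * cols + c] = v * scales[block * cols + c];
    }
  }
}
```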
#endif #if CompileAVXVNNI() @@ -1049,6 +1646,36 @@ class UT_avx2_gemv { ut_3bit_s8s8<1>(24, 128, 32, false); ut_3bit_s8s8<4>(24, 128, 32, true); ut_3bit_s8s8<4>(24, 128, 32, false); + + ut_6bit_fp32<1>(24, 128, 32, true); + ut_6bit_fp32<1>(24, 128, 32, false); + ut_6bit_fp32<4>(24, 128, 32, true); + ut_6bit_fp32<4>(24, 128, 32, false); + + ut_6bit_u8s8<1>(24, 128, 32, true); + ut_6bit_u8s8<1>(24, 128, 32, false); + ut_6bit_u8s8<4>(24, 128, 32, true); + ut_6bit_u8s8<4>(24, 128, 32, false); + + ut_6bit_s8s8<1>(24, 128, 32, true); + ut_6bit_s8s8<1>(24, 128, 32, false); + ut_6bit_s8s8<4>(24, 128, 32, true); + ut_6bit_s8s8<4>(24, 128, 32, false); + + ut_5bit_fp32<1>(24, 128, 32, true); + ut_5bit_fp32<1>(24, 128, 32, false); + ut_5bit_fp32<4>(24, 128, 32, true); + ut_5bit_fp32<4>(24, 128, 32, false); + + ut_5bit_u8s8<1>(24, 128, 32, true); + ut_5bit_u8s8<1>(24, 128, 32, false); + ut_5bit_u8s8<4>(24, 128, 32, true); + ut_5bit_u8s8<4>(24, 128, 32, false); + + ut_5bit_s8s8<1>(24, 128, 32, true); + ut_5bit_s8s8<1>(24, 128, 32, false); + ut_5bit_s8s8<4>(24, 128, 32, true); + ut_5bit_s8s8<4>(24, 128, 32, false); } template @@ -1450,6 +2077,290 @@ class UT_avx2_gemv { Cf32.data(), n, k, kblock, cache, CacheSize); buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); } + + template + void ut_6bit_fp32(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b2(n * k / 4); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b2.data(), b2.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(Af32.data(), Af32.size(), -0.5f, 0.5f); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-32), int8_t(31)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector b8(n * k); + kernel::ref::decompress_s6_s8(b4.data(), b2.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 1) { + int bid = i / kblock; + for (int j = 0; j < n; j += 4) { + if (iasym) { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0] - bzp[bid * n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1] - bzp[bid * n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2] - bzp[bid * n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3] - bzp[bid * n + j + 3]) * scaleb[bid * n + j + 3]; + } else { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3]) * scaleb[bid * n + j + 3]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), (uint8_t*)b2.data(), nullptr, scaleb.data(), iasym ? 
bzp.data() : nullptr, 6, n}; + kernel::avx2::gemv_6bit_fp32_fp32(Af32.data(), k, B, Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_6bit_u8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b2(n * k / 4); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b2.data(), b2.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-32), int8_t(31)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k), azp(MTILE * blks); + fill_buffer_randn(A.data(), A.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(azp.data(), azp.size(), uint8_t(100), uint8_t(150)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j]) - azp[bid]) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s6_s8(b4.data(), b2.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), (uint8_t*)b2.data(), nullptr, scaleb.data(), iasym ? 
bzp.data() : nullptr, 6, n}; + kernel::avx2::vnni::gemv_6bit_u8s8_fp32({A.data(), scalea.data(), azp.data(), k, blks}, B, + Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_6bit_s8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b2(n * k / 4); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b2.data(), b2.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-32), int8_t(31)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k); + fill_buffer_randn(A.data(), A.size(), int8_t(0), int8_t(127)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j])) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s6_s8(b4.data(), b2.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), (uint8_t*)b2.data(), nullptr, scaleb.data(), iasym ? 
bzp.data() : nullptr, 5, n}; + kernel::avx2::vnni::gemv_6bit_s8s8_fp32({(uint8_t*)A.data(), scalea.data(), nullptr, k, blks}, B, + Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_5bit_fp32(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b1(n * k / 8); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b1.data(), b1.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(Af32.data(), Af32.size(), -0.5f, 0.5f); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-16), int8_t(15)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector b8(n * k); + kernel::ref::decompress_s5_s8(b4.data(), b1.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 1) { + int bid = i / kblock; + for (int j = 0; j < n; j += 4) { + if (iasym) { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0] - bzp[bid * n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1] - bzp[bid * n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2] - bzp[bid * n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3] - bzp[bid * n + j + 3]) * scaleb[bid * n + j + 3]; + } else { + Bf32[(i)*n + j + 0] = (b8[(i)*n + j + 0]) * scaleb[bid * n + j + 0]; + Bf32[(i)*n + j + 1] = (b8[(i)*n + j + 1]) * scaleb[bid * n + j + 1]; + Bf32[(i)*n + j + 2] = (b8[(i)*n + j + 2]) * scaleb[bid * n + j + 2]; + Bf32[(i)*n + j + 3] = (b8[(i)*n + j + 3]) * scaleb[bid * n + j + 3]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), nullptr, (uint8_t*)b1.data(), scaleb.data(), iasym ? 
bzp.data() : nullptr, 5, n}; + kernel::avx2::gemv_5bit_fp32_fp32(Af32.data(), k, B, Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_5bit_u8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b1(n * k / 8); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b1.data(), b1.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-16), int8_t(15)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k), azp(MTILE * blks); + fill_buffer_randn(A.data(), A.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(azp.data(), azp.size(), uint8_t(100), uint8_t(150)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j]) - azp[bid]) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s5_s8(b4.data(), b1.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), nullptr, (uint8_t*)b1.data(), scaleb.data(), iasym ? 
bzp.data() : nullptr, 2, n}; + kernel::avx2::vnni::gemv_5bit_u8s8_fp32({A.data(), scalea.data(), azp.data(), k, blks}, B, + Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } + + template + void ut_5bit_s8s8(int n, int k, int kblock, bool iasym) { + printf("Test Case %s_%d: %d %d %d Asym:%d\n", __FUNCTION__, MTILE, n, k, kblock, iasym); + int blks = k / kblock; + avector b4(n * k / 2); + avector b1(n * k / 8); + avector scaleb(n * blks), scalea(MTILE * blks); + avector bzp(n * blks); + avector Af32(MTILE * k), Bf32(n * k), Cf32(MTILE * n), Cref(MTILE * n); + fill_buffer_randn((uint8_t*)b4.data(), b4.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn((uint8_t*)b1.data(), b1.size(), uint8_t(0), uint8_t(255)); + fill_buffer_randn(bzp.data(), bzp.size(), int8_t(-16), int8_t(15)); + fill_buffer_randn(scaleb.data(), scaleb.size(), 0.01f, 0.02f); + avector A(MTILE * k); + fill_buffer_randn(A.data(), A.size(), int8_t(0), int8_t(127)); + fill_buffer_randn(scalea.data(), scalea.size(), 0.01f, 0.02f); + for (int im = 0; im < MTILE; im++) { + for (int i = 0; i < k; i += 4) { + int bid = i / kblock + im * blks; + for (int j = 0; j < 4; j++) { + Af32[im * k + i + j] = (int(A[im * k + i + j])) * scalea[bid]; + } + } + } + + avector b8(n * k); + kernel::ref::decompress_s5_s8(b4.data(), b1.data(), b8.data(), b8.size(), cache, CacheSize); + for (int i = 0; i < k; i += 4) { + int bid = i / kblock; + for (int j = 0; j < n; j += 1) { + if (iasym) { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3]) - bzp[bid * n + j]) * scaleb[bid * n + j]; + } else { + Bf32[(i + 0) * n + j] = (int(b8[i * n + j * 4 + 0])) * scaleb[bid * n + j]; + Bf32[(i + 1) * n + j] = (int(b8[i * n + j * 4 + 1])) * scaleb[bid * n + j]; + Bf32[(i + 2) * n + j] = (int(b8[i * n + j * 4 + 2])) * scaleb[bid * n + j]; + Bf32[(i + 3) * n + j] = (int(b8[i * n + j * 4 + 3])) * scaleb[bid * n + j]; + } + } + } + gemmref_fp32fp32fp32(MTILE, n, k, Af32.data(), Bf32.data(), Cref.data(), k, n, n); + utils::GemvParamB B{ + (uint8_t*)b4.data(), nullptr, (uint8_t*)b1.data(), scaleb.data(), iasym ? 
bzp.data() : nullptr, 5, n}; + kernel::avx2::vnni::gemv_5bit_s8s8_fp32({(uint8_t*)A.data(), scalea.data(), nullptr, k, blks}, B, + Cf32.data(), n, k, kblock, cache, CacheSize); + buffer_error(Cref.data(), Cf32.data(), Cref.size(), FP32_ERR); + } }; #ifdef BTLA_UT_KERNEL_INTRIN UT_avx2_gemv sUT_avx2_gemv; diff --git a/docs/advanced_usage.md b/docs/advanced_usage.md index 278901b42..904db0bb5 100644 --- a/docs/advanced_usage.md +++ b/docs/advanced_usage.md @@ -6,9 +6,9 @@ Argument description of run.py ([supported MatMul combinations](#supported-matri | Argument | Description | | -------------- | --------------------------------------------------------------------- | | model | Directory containing model file or model id: String | -| --weight_dtype | Data type of quantized weight: int4/int8/fp8(=fp8_e4m3)/fp8_e5m2/fp4(=fp4e2m1)/nf4 (default int4) | +| --weight_dtype | Data type of quantized weight: int4/int3/int2/int5/int6/int8/fp8(=fp8_e4m3)/fp8_e5m2/fp4(=fp4e2m1)/nf4 (default int4) | | --alg | Quantization algorithm: sym/asym (default sym) | -| --group_size | Group size: Int, 32/128/-1 (per channel) (default: 32) | +| --group_size | Group size: Int, 16/32/64/128/-1 (per channel) (default: 32) | | --scale_dtype | Data type of scales: fp32/bf16/fp8 (default fp32) | | --compute_dtype | Data type of Gemm computation: int8/bf16/fp16/fp32 (default: fp32) | | --use_ggml | Enable ggml for quantization and inference | @@ -60,16 +60,16 @@ Argument description of quantize.py ([supported MatMul combinations](#supported- | --build_dir | Path to the build file: String | | --config | Path to the configuration file: String (default: "") | | --nthread | Number of threads to use: Int (default: 1) | -| --weight_dtype | Data type of quantized weight: int4/int8/fp8(=fp8_e4m3)/fp8_e5m2/fp4(=fp4_e2m1)/nf4 (default: int4) | +| --weight_dtype | Data type of quantized weight: int4/int3/int2/int5/int6/int8/fp8(=fp8_e4m3)/fp8_e5m2/fp4(=fp4_e2m1)/nf4 (default: int4) | | --alg | Quantization algorithm to use: sym/asym (default: sym) | -| --group_size | Group size: Int 32/128/-1 (per channel) (default: 32) | +| --group_size | Group size: Int 16/32/64/128/-1 (per channel) (default: 32) | | --scale_dtype | Data type of scales: bf16/fp32/fp8 (default: fp32) | | --compute_dtype | Data type of Gemm computation: int8/bf16/fp16/fp32 (default: fp32)| | --use_ggml | Enable ggml for quantization and inference | #### Supported Matrix Multiplication Data Types Combinations -Our Neural Speed supports INT4 / INT8 / FP8 (E4M3, E5M2) / FP4 (E2M1) / NF4 weight-only quantization and FP32 / FP16 / BF16 / INT8 computation forward matmul on the Intel platforms. Here are the all supported data types combinations for matmul operations (quantization and forward). +Our Neural Speed supports INT4 / INT3 / INT2 / INT5 / INT6 / INT8 / FP8 (E4M3, E5M2) / FP4 (E2M1) / NF4 weight-only quantization and FP32 / FP16 / BF16 / INT8 computation forward matmul on the Intel platforms. Here are the all supported data types combinations for matmul operations (quantization and forward). > This table will be updated frequently due to active development. 
For details you can refer to [BesTLA](../bestla#weight-only) | Weight dtype | Compute dtype (default value) | Scale dtype (default value) | Quantization scheme (default value) | @@ -77,6 +77,10 @@ Our Neural Speed supports INT4 / INT8 / FP8 (E4M3, E5M2) / FP4 (E2M1) / NF4 wei | FP32 | FP32 | NA | NA | | INT8 | INT8 / BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym / asym (sym) | | INT4 | INT8 / BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym / asym (sym) | +| INT3 | INT8 / BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym / asym (sym) | +| INT2 | INT8 / BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym / asym (sym) | +| INT5 | INT8 / BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym / asym (sym) | +| INT6 | INT8 / BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym / asym (sym) | | FP8 (E4M3, E5M2) | BF16 / FP16 / FP32 (FP32) | FP8 (FP8) | sym (sym) | | FP4 (E2M1) | BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym (sym) | | NF4 | BF16 / FP16 / FP32 (FP32) | BF16 / FP32 (FP32) | sym (sym) | diff --git a/neural_speed/application/common.cpp b/neural_speed/application/common.cpp index b9e7d9459..9a0313598 100644 --- a/neural_speed/application/common.cpp +++ b/neural_speed/application/common.cpp @@ -649,7 +649,7 @@ void quant_print_usage(int argc, char** argv, const quant_params& params) { fprintf(stderr, " --nthread number of threads to use (default: 1)\n"); fprintf(stderr, " --weight_dtype number of bits to use for quantization: int4/int8/fp8_e4m3/fp8_e5m2/" - "fp4_e2m1/nf4/int3 (default: int4)\n"); + "fp4_e2m1/nf4/int3/int2/int5/int6 (default: int4)\n"); fprintf(stderr, " --alg quantization algorithm to use: sym/asym (default: sym)\n"); fprintf(stderr, " --group_size group size: 32/128/-1 (per channel) (default: 32)\n"); fprintf(stderr, " --scale_dtype fp32/bf16/fp8 type for scales (default: fp32)\n"); diff --git a/neural_speed/core/README.md b/neural_speed/core/README.md index 9e36ddc65..22546cf05 100644 --- a/neural_speed/core/README.md +++ b/neural_speed/core/README.md @@ -24,6 +24,8 @@ dtype | algo | group size int4 | symmetric or asymmetric | multiplier of 8, -11 int3 | symmetric or asymmetric | multiplier of 8, -11 int2 | symmetric or asymmetric | multiplier of 8, -11 +int5 | symmetric or asymmetric | multiplier of 8, -11 +int6 | symmetric or asymmetric | multiplier of 8, -11 int8 | symmetric | multiplier of 8, -11 fp4 | | multiplier of 8 nf4 | | multiplier of 8 @@ -71,16 +73,16 @@ Referring [the fused-attention doc for details](../docs/fused_attention.md#suppo -## Fastest Configuration for CPUs +## Recommended Configuration for CPUs codename | weight config | runtime ISA ---|---|--- -Sapphire Rapids
Emerald Rapids | any int4<br>group size=-1<br>compute type=int8 | AMX_INT8
-Ice Lake<br>Cascade Lake<br>Cooper Lake<br>Tiger Lake<br>Rocket Lake | any int4<br>group size=-1<br>compute type=int8 | AVX512_VNNI
-Skylake | any 4bits<br>group size=-1<br>compute type=fp32 | AVX512F
-Alder Lake (12th Gen)<br>Raptor Lake (13th and 14th Gen)|any 4bits<br>group size=-1<br>compute type=int8 | AVX_VNNI
-Older architecture (before 12th Gen)| any 4bits<br>group size=-1<br>compute type=int8 | AVX2
+Sapphire Rapids<br>Emerald Rapids | sym int3<br>group size=128<br>compute type=int8 | AMX_INT8
+Ice Lake<br>Cascade Lake<br>Cooper Lake<br>Tiger Lake<br>Rocket Lake | sym int3<br>group size=128<br>compute type=int8 | AVX512_VNNI
+Skylake | sym int3<br>group size=128<br>compute type=fp32 | AVX512F
+Alder Lake (12th Gen)<br>Raptor Lake (13th and 14th Gen)| sym int3<br>group size=128<br>compute type=int8 | AVX_VNNI
+Older architecture (before 12th Gen)| sym int3<br>group size=128<br>compute type=int8 | AVX2
NOTE:
-1. group_size=-1 requires the INC's finetuned model, or it may have lower accuracy than small group sizes.
+1. group_size=-1 requires the INC's finetuned model, or it may have lower accuracy than small group sizes. It has the smallest model size, and the fastest first-token performance.
2. group_size=128 is a balance of accuracy and speed if you want RTN quantization only.
3. group_size=32, scale_dtype=bf16, compute_dtype=int8, alg=sym equals llama.cpp's Q4_0.
diff --git a/neural_speed/models/model_utils/quant_config.h index 469603709..41fbf6cd0 100644 --- a/neural_speed/models/model_utils/quant_config.h +++ b/neural_speed/models/model_utils/quant_config.h @@ -18,7 +18,7 @@ #include "core/data_types.h" #include "bestla/bestla.h" -enum class quant_bits : int { q4 = 0, q2, q3, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count }; +enum class quant_bits : int { q4 = 0, q1, q2, q3, q5, q6, q7, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count }; static inline quant_bits parse_bits(const std::string& bits) { if (bits == "int3") { return quant_bits::q3; @@ -26,6 +26,18 @@ static inline quant_bits parse_bits(const std::string& bits) { if (bits == "int2") { return quant_bits::q2; } + if (bits == "int1") { + return quant_bits::q1; + } + if (bits == "int5") { + return quant_bits::q5; + } + if (bits == "int6") { + return quant_bits::q6; + } + if (bits == "int7") { + return quant_bits::q7; + } if (bits == "int4") { return quant_bits::q4; }
diff --git a/neural_speed/models/model_utils/quant_utils.cpp index c037b91a5..daf88f648 100644 --- a/neural_speed/models/model_utils/quant_utils.cpp +++ b/neural_speed/models/model_utils/quant_utils.cpp @@ -274,27 +274,48 @@ size_t bestla_quantize(const float* f32ptr, void* dstpr, const quant_params_inte auto thdptr = bestla_get_thread_handle(); BTLA_DTYPE quant_type = BTLA_DTYPE::S4_CLIP; - if (params.bits == quant_bits::q3) { - quant_type = BTLA_DTYPE::S3_CLIP; - } - if (params.bits == quant_bits::q2) { - quant_type = BTLA_DTYPE::S2_CLIP; - } - if (params.bits == quant_bits::q8) { - quant_type = BTLA_DTYPE::S8; - } - if (params.bits == quant_bits::fp4_e2m1) { - quant_type = BTLA_DTYPE::F4_E2M1; - } - if (params.bits == quant_bits::nf4) { - quant_type = BTLA_DTYPE::F4_NF4; - } - if (params.bits == quant_bits::fp8_e4m3) { - quant_type = BTLA_DTYPE::F8_E4M3; - } - if (params.bits == quant_bits::fp8_e5m2) { - quant_type = BTLA_DTYPE::F8_E5M2; + switch (params.bits) { + case quant_bits::q6: + quant_type = BTLA_DTYPE::S6_CLIP; + break; + case quant_bits::q5: + quant_type = BTLA_DTYPE::S5_CLIP; + break; + case quant_bits::q4: + quant_type = BTLA_DTYPE::S4_CLIP; + break; + case quant_bits::q3: + quant_type = BTLA_DTYPE::S3_CLIP; + break; + case quant_bits::q2: + quant_type = BTLA_DTYPE::S2_CLIP; + break; + case quant_bits::q1: + quant_type = BTLA_DTYPE::S1_CLIP; + break; + case quant_bits::q7: + quant_type = BTLA_DTYPE::S7_CLIP; + break; + case quant_bits::q8: + quant_type = BTLA_DTYPE::S8; + break; + case quant_bits::fp4_e2m1: + quant_type = BTLA_DTYPE::F4_E2M1; + break; + case quant_bits::nf4: + quant_type = BTLA_DTYPE::F4_NF4; + break; + case quant_bits::fp8_e4m3: + quant_type = BTLA_DTYPE::F8_E4M3; + break; + case quant_bits::fp8_e5m2: + quant_type = BTLA_DTYPE::F8_E5M2; + break; + default: + printf("Unsupported quant bits:%d, falling back to int4\n", int(params.bits)); + break; + } + auto dtype_type = static_cast( 
bestla::utils::bestla_dtype_get_mask_val(quant_type, BTLA_DTYPE::TypeMask, BTLA_DTYPE::TypeShift)); if (dtype_type == BTLA_DTYPE::TypeFloat) { @@ -318,6 +339,10 @@ size_t bestla_quantize(const float* f32ptr, void* dstpr, const quant_params_inte } scale_type = BTLA_DTYPE::F8_E8M0; } + if (quant_type == BTLA_DTYPE::S1_CLIP || quant_type == BTLA_DTYPE::S7_CLIP) { + printf("int1 and int7 are not supported yet, falling back to int4\n"); + quant_type = BTLA_DTYPE::S4_CLIP; + } auto gsize = params.group_size == -1 ? k : params.group_size; auto size = BTLAGemmPackBSize(n, k, gsize, quant_type, scale_type, params.alg == quant_alg::asym, ctype, nullptr); bool constexpr IsTrans_TorchWeight = true;
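To make the 4-bit-plane plus 1-bit-plane storage used by the new 5-bit path easier to follow (see `compressBit5Weight`, `kernel::ref::decompress_s5_s8`, and the `gemv_5bit_*` tests above), here is a small stand-alone sketch. It is illustrative only: `pack_s5`/`unpack_s5` are hypothetical helper names, and the real BesTLA kernels use their own interleaved in-memory layout and tile grouping. The sketch only shows the plane-splitting and sign-extension idea and round-trips signed values in [-16, 15]; the 6-bit path is analogous, pairing a 4-bit plane with a 2-bit plane as in `compressBit6Weight`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Pack n signed 5-bit values (range [-16, 15]) into a 4-bit plane (two values
// per byte) and a 1-bit plane (eight values per byte). n must be a multiple of 8.
static void pack_s5(const int8_t* src, size_t n, uint8_t* bit4, uint8_t* bit1) {
  for (size_t i = 0; i < n; i += 2) {
    uint8_t lo0 = uint8_t(src[i]) & 0x0F;      // low 4 bits of value i
    uint8_t lo1 = uint8_t(src[i + 1]) & 0x0F;  // low 4 bits of value i+1
    bit4[i / 2] = uint8_t(lo0 | (lo1 << 4));
  }
  for (size_t i = 0; i < n; i += 8) {
    uint8_t byte = 0;
    for (size_t j = 0; j < 8; ++j) {
      byte |= uint8_t(((uint8_t(src[i + j]) >> 4) & 0x1) << j);  // 5th (sign) bit
    }
    bit1[i / 8] = byte;
  }
}

// Reverse the split: recombine both planes and sign-extend back to int8.
static void unpack_s5(const uint8_t* bit4, const uint8_t* bit1, size_t n, int8_t* dst) {
  for (size_t i = 0; i < n; ++i) {
    uint8_t lo = (i % 2 == 0) ? uint8_t(bit4[i / 2] & 0x0F) : uint8_t(bit4[i / 2] >> 4);
    uint8_t hi = uint8_t((bit1[i / 8] >> (i % 8)) & 0x1);
    int v5 = int(lo | (hi << 4));              // 5-bit value in [0, 31]
    dst[i] = int8_t(v5 >= 16 ? v5 - 32 : v5);  // sign-extend to [-16, 15]
  }
}

int main() {
  const size_t n = 16;
  std::vector<int8_t> src = {-16, -9, -1, 0, 1, 7, 15, -3, 4, -8, 12, -12, 2, -2, 6, -6};
  std::vector<uint8_t> bit4(n / 2), bit1(n / 8);
  std::vector<int8_t> dst(n);
  pack_s5(src.data(), n, bit4.data(), bit1.data());
  unpack_s5(bit4.data(), bit1.data(), n, dst.data());
  for (size_t i = 0; i < n; ++i) assert(dst[i] == src[i]);
  printf("5-bit round-trip OK: %zu-byte 4-bit plane + %zu-byte 1-bit plane for %zu values\n",
         bit4.size(), bit1.size(), n);
  return 0;
}
```

One motivation for splitting into separate planes, rather than a packed 5-bit stream, is that both planes stay byte-aligned, which keeps SIMD decompression simple.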