Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
[BesTLA] Support int5&int6 for kernels and models (#259)
Browse files Browse the repository at this point in the history
* add initial of int5

* add all gemv of int5

* finish all avx2 kernels of int5

* add benchmark of int5

* add avx512f s5_s8, s5_fp

* add avx512f kernels of int5

* test LLaMa2-7B with int5, sym and asym.

* fix code scan

* clang-format

* add avx2 decompress kernels of int6

* add avx2 gemv kernels for int6

* add avx512f kernels for int6

* clang-format

* fix UT

* fix UT bug

* add UTs for new bits

* update doc

* fix UT bug

* fix ISA check

* fix bug of AVX2 s6_s8

* update dtypes in advanced_usage.md
  • Loading branch information
luoyu-intel authored May 21, 2024
1 parent 3257516 commit 68d2cff
Show file tree
Hide file tree
Showing 18 changed files with 7,447 additions and 2,585 deletions.
31 changes: 28 additions & 3 deletions bestla/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ BesTLA provides weight-only linear computational capabilities for LLM inference.
| Weight dtype | Compute dtype | Scale dtype | algo |
| ---------------------- | :----------------: | :---------------: | :--------: |
| INT8 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| INT4 (CLIP, FULLRANGE) | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| INT4 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| INT3 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| INT2 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| INT5 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| INT6 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
| FP8 (E4M3, E5M2) | BF16 / FP32 | FP32 / FP8 (E8M0) | sym |
| FP4 (E2M1) | BF16 / FP32 | BF16 / FP32 | sym |
| NF4 | BF16 / FP32 | BF16 / FP32 | sym |
Expand All @@ -47,11 +51,32 @@ BesTLA provides assembly-level postop-fusion through epilogue to minimize the ov
## Compilation Requirements and Usage
Compile:

- GCC version >=8.5.0
- CMake version >=3.5
- GCC version >= 9.0
- CMake version >= 3.12
- MSVC version >= 1900
- oneAPI version >= 2024.0

Best Performance:

- GCC >= 11.0.0
- MSVC >= 1930
- DPCPP >= 2024.0


Usage:
```cmake
add_subdirectory(bestla)
target_link_libraries("${YOUR_PROJECT}" bestla::bestla)
```

# Benchmark
Build with:
```shell
mkdir build
cd build
cmake .. -DBTLA_UT_BENCHMARK=ON -DBTLA_UT_ALL=ON -DCMAKE_BUILD_TYPE=Release
cmake --build . -j
./bestla_benchmark
```

For more template usage examples, please refer to the code in [bestla_benchmark](bestla/ut/bestla_benchmark.cpp).
8 changes: 8 additions & 0 deletions bestla/bestla/bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ enum class BTLA_DTYPE : uint32_t {
EleBitsMask = 0xff,
EleBitsShift = 0,
EleBitsUndef = 0,
EleBits1 = 1,
EleBits2 = 2,
EleBits3 = 3,
EleBits4 = 4,
EleBits5 = 5,
EleBits6 = 6,
EleBits7 = 7,
EleBits8 = 8,
EleBits16 = 16,
EleBits32 = 32,
Expand All @@ -66,9 +70,13 @@ enum class BTLA_DTYPE : uint32_t {
DQ8_BNB = EleBits8 | TypeFloat | SubType4,
S8 = EleBits8 | TypeInt,
U8 = EleBits8 | TypeInt | SubType1,
S1_CLIP = EleBits1 | TypeInt,
S2_CLIP = EleBits2 | TypeInt,
S3_CLIP = EleBits3 | TypeInt,
S4_CLIP = EleBits4 | TypeInt,
S5_CLIP = EleBits5 | TypeInt,
S6_CLIP = EleBits6 | TypeInt,
S7_CLIP = EleBits7 | TypeInt,
F4_E2M1 = EleBits4 | TypeFloat,
F4_BNB = EleBits4 | TypeFloat | SubType1,
F4_NF4 = EleBits4 | TypeFloat | SubType2,
Expand Down
177 changes: 132 additions & 45 deletions bestla/bestla/bestla_prologue_b.h

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions bestla/bestla/bestla_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -712,11 +712,20 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase {
InfoType::resize(NPad, KPad, Block, N, K, qtype);
auto bits = utils::bestla_dtype_bits(qtype);
auto elesize = static_cast<size_t>(NPad) * KPad;
auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here
if (qtype == BTLA_DTYPE::S3_CLIP)
elesize =
static_cast<size_t>(utils::padto(KPad, 128)) * NPad; // pad K-dim to 128 because 128pack round2 interleave.
// round2 interleave ld_dim == pad_to(KPad,128) * NTILE
auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here
bytes =
utils::updiv(static_cast<size_t>(KPad) * NPad * 2, 8) + utils::updiv(static_cast<size_t>(KPad) * NPad * 1, 8);
else if (qtype == BTLA_DTYPE::S5_CLIP)
bytes =
utils::updiv(static_cast<size_t>(KPad) * NPad * 4, 8) + utils::updiv(static_cast<size_t>(KPad) * NPad * 1, 8);
else if (qtype == BTLA_DTYPE::S6_CLIP)
bytes =
utils::updiv(static_cast<size_t>(KPad) * NPad * 4, 8) + utils::updiv(static_cast<size_t>(KPad) * NPad * 2, 8);
else if (qtype == BTLA_DTYPE::S7_CLIP)
bytes = utils::updiv(static_cast<size_t>(KPad) * NPad * 4, 8) +
utils::updiv(static_cast<size_t>(KPad) * NPad * 2, 8) +
utils::updiv(static_cast<size_t>(KPad) * NPad * 1, 8);
mQBuf.resize(bytes);
int nk_scale = utils::updiv(KPad, Block);
auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId);
Expand Down
8 changes: 8 additions & 0 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,20 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) {
return "signed_int8";
case BTLA_DTYPE::U8:
return "unsigned_int8";
case BTLA_DTYPE::S7_CLIP:
return "int7_clip";
case BTLA_DTYPE::S6_CLIP:
return "int6_clip";
case BTLA_DTYPE::S5_CLIP:
return "int5_clip";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S3_CLIP:
return "int3_clip";
case BTLA_DTYPE::S2_CLIP:
return "int2_clip";
case BTLA_DTYPE::S1_CLIP:
return "int1_clip";
case BTLA_DTYPE::F4_E2M1:
return "fp4_e2m1";
case BTLA_DTYPE::F4_BNB:
Expand Down
115 changes: 115 additions & 0 deletions bestla/bestla/bestla_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,58 @@
namespace bestla {
namespace wrapper {
namespace gemv_nbits {
// GEMV parameter helper for 6-bit integer weights (selected when the packed
// weight dtype is BTLA_DTYPE::S6_CLIP). The weight buffer is addressed as two
// planes: a 4-bit plane at the start of the buffer and a 2-bit plane that
// begins NPad * KPad / 2 elements later.
class S6 {
 public:
  static int constexpr NBits = 6;

  // Builds the B-side GEMV parameter pack from a packed integer weight storage.
  //   packedW: packed weight storage; its weight, scale and zero-point pointers
  //            and its mNPad / mKPad / CStep() values are read.
  // Returns a GemvParamB whose zero-point pointer is nullptr unless the storage
  // reports IsAsym().
  // NOTE(review): the initializer order presumably maps to
  // {b4ptr, b2ptr, b1ptr, sptr, zpptr, nbits, ldzp, kpad} - confirm against
  // utils::GemvParamB.
  template <typename ScaleT>
  static inline utils::GemvParamB<ScaleT> createB(storage::gemm::StorageWeightKBlockNInteger* packedW) {
    auto isasym = packedW->IsAsym();
    auto bzptr = packedW->template ZPtr<int8_t>();
    int ld_scaleb = packedW->CStep();
    auto bwptr = packedW->template WPtr<uint8_t>();
    // Widen before multiplying: NPad * KPad can exceed the 32-bit int range for
    // large weight matrices, which would overflow before the division.
    auto bit2_offset = static_cast<int64_t>(packedW->mNPad) * packedW->mKPad / 2;
    utils::GemvParamB<ScaleT> paramB{
        bwptr, bwptr + bit2_offset, nullptr, packedW->template SPtr<ScaleT>(), isasym ? bzptr : nullptr,
        NBits, ld_scaleb, packedW->mKPad};
    return paramB;
  }

  // Advances every pointer in paramB by n_offset columns along N.
  // The 4-bit plane moves n_offset * kpad / 2 elements and the 2-bit plane
  // n_offset * kpad / 4 elements; scales (and zero points, when present) move
  // n_offset elements.
  template <typename ScaleT>
  static void updateBNStep(utils::GemvParamB<ScaleT>& paramB, int n_offset) {
    // 64-bit product so that n_offset * kpad cannot overflow a 32-bit int.
    const auto plane_elems = static_cast<int64_t>(n_offset) * paramB.kpad;
    paramB.b4ptr += plane_elems / 2;
    paramB.b2ptr += plane_elems / 4;
    paramB.sptr += n_offset;
    if (paramB.zpptr) {
      paramB.zpptr += n_offset;
    }
  }
};

// GEMV parameter helper for 5-bit integer weights (selected when the packed
// weight dtype is BTLA_DTYPE::S5_CLIP). The weight buffer is addressed as two
// planes: a 4-bit plane at the start of the buffer and a 1-bit plane that
// begins NPad * KPad / 2 elements later.
class S5 {
 public:
  static int constexpr NBits = 5;

  // Builds the B-side GEMV parameter pack from a packed integer weight storage.
  //   packedW: packed weight storage; its weight, scale and zero-point pointers
  //            and its mNPad / mKPad / CStep() values are read.
  // Returns a GemvParamB whose zero-point pointer is nullptr unless the storage
  // reports IsAsym(). The second pointer slot is left as nullptr here.
  // NOTE(review): the initializer order presumably maps to
  // {b4ptr, b2ptr, b1ptr, sptr, zpptr, nbits, ldzp, kpad} - confirm against
  // utils::GemvParamB.
  template <typename ScaleT>
  static inline utils::GemvParamB<ScaleT> createB(storage::gemm::StorageWeightKBlockNInteger* packedW) {
    auto isasym = packedW->IsAsym();
    auto bzptr = packedW->template ZPtr<int8_t>();
    int ld_scaleb = packedW->CStep();
    auto bwptr = packedW->template WPtr<uint8_t>();
    // Widen before multiplying: NPad * KPad can exceed the 32-bit int range for
    // large weight matrices, which would overflow before the division.
    auto bit1_offset = static_cast<int64_t>(packedW->mNPad) * packedW->mKPad / 2;
    utils::GemvParamB<ScaleT> paramB{
        bwptr, nullptr, bwptr + bit1_offset, packedW->template SPtr<ScaleT>(), isasym ? bzptr : nullptr,
        NBits, ld_scaleb, packedW->mKPad};
    return paramB;
  }

  // Advances every pointer in paramB by n_offset columns along N.
  // The 4-bit plane moves n_offset * kpad / 2 elements and the 1-bit plane
  // n_offset * kpad / 8 elements; scales (and zero points, when present) move
  // n_offset elements.
  template <typename ScaleT>
  static void updateBNStep(utils::GemvParamB<ScaleT>& paramB, int n_offset) {
    // 64-bit product so that n_offset * kpad cannot overflow a 32-bit int.
    const auto plane_elems = static_cast<int64_t>(n_offset) * paramB.kpad;
    paramB.b4ptr += plane_elems / 2;
    paramB.b1ptr += plane_elems / 8;
    paramB.sptr += n_offset;
    if (paramB.zpptr) {
      paramB.zpptr += n_offset;
    }
  }
};

class S4 {
public:
static int constexpr NBits = 4;
Expand Down Expand Up @@ -161,6 +213,8 @@ class LauncherBase {
bool impl = true;
impl &= _param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP;
if constexpr (support()) {
impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 ||
Expand Down Expand Up @@ -233,6 +287,36 @@ class LauncherBase {
}
return;
}
if (_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP) {
if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) {
if (m == 1) gemv_kblock<float, 1, gemv_nbits::S5>(_param, _config);
if (m == 2) gemv_kblock<float, 2, gemv_nbits::S5>(_param, _config);
if (m == 3) gemv_kblock<float, 3, gemv_nbits::S5>(_param, _config);
if (m == 4) gemv_kblock<float, 4, gemv_nbits::S5>(_param, _config);

} else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) {
if (m == 1) gemv_kblock<utils::bf16, 1, gemv_nbits::S5>(_param, _config);
if (m == 2) gemv_kblock<utils::bf16, 2, gemv_nbits::S5>(_param, _config);
if (m == 3) gemv_kblock<utils::bf16, 3, gemv_nbits::S5>(_param, _config);
if (m == 4) gemv_kblock<utils::bf16, 4, gemv_nbits::S5>(_param, _config);
}
return;
}
if (_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP) {
if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) {
if (m == 1) gemv_kblock<float, 1, gemv_nbits::S6>(_param, _config);
if (m == 2) gemv_kblock<float, 2, gemv_nbits::S6>(_param, _config);
if (m == 3) gemv_kblock<float, 3, gemv_nbits::S6>(_param, _config);
if (m == 4) gemv_kblock<float, 4, gemv_nbits::S6>(_param, _config);

} else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) {
if (m == 1) gemv_kblock<utils::bf16, 1, gemv_nbits::S6>(_param, _config);
if (m == 2) gemv_kblock<utils::bf16, 2, gemv_nbits::S6>(_param, _config);
if (m == 3) gemv_kblock<utils::bf16, 3, gemv_nbits::S6>(_param, _config);
if (m == 4) gemv_kblock<utils::bf16, 4, gemv_nbits::S6>(_param, _config);
}
return;
}
if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) {
if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) {
if (m == 1) gemv_kblock<float, 1, gemv_nbits::S3>(_param, _config);
Expand Down Expand Up @@ -418,6 +502,8 @@ class LauncherIntKBlock {
static bool implemented(const Param& _param) {
bool impl = true;
impl &= _param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP ||
_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP;
impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 ||
Expand Down Expand Up @@ -490,7 +576,36 @@ class LauncherIntKBlock {
}
return;
}
if (_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP) {
if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) {
if (m == 1) gemv_kblock<float, 1, gemv_nbits::S5>(_param, _config);
if (m == 2) gemv_kblock<float, 2, gemv_nbits::S5>(_param, _config);
if (m == 3) gemv_kblock<float, 3, gemv_nbits::S5>(_param, _config);
if (m == 4) gemv_kblock<float, 4, gemv_nbits::S5>(_param, _config);

} else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) {
if (m == 1) gemv_kblock<utils::bf16, 1, gemv_nbits::S5>(_param, _config);
if (m == 2) gemv_kblock<utils::bf16, 2, gemv_nbits::S5>(_param, _config);
if (m == 3) gemv_kblock<utils::bf16, 3, gemv_nbits::S5>(_param, _config);
if (m == 4) gemv_kblock<utils::bf16, 4, gemv_nbits::S5>(_param, _config);
}
return;
}
if (_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP) {
if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) {
if (m == 1) gemv_kblock<float, 1, gemv_nbits::S6>(_param, _config);
if (m == 2) gemv_kblock<float, 2, gemv_nbits::S6>(_param, _config);
if (m == 3) gemv_kblock<float, 3, gemv_nbits::S6>(_param, _config);
if (m == 4) gemv_kblock<float, 4, gemv_nbits::S6>(_param, _config);

} else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) {
if (m == 1) gemv_kblock<utils::bf16, 1, gemv_nbits::S6>(_param, _config);
if (m == 2) gemv_kblock<utils::bf16, 2, gemv_nbits::S6>(_param, _config);
if (m == 3) gemv_kblock<utils::bf16, 3, gemv_nbits::S6>(_param, _config);
if (m == 4) gemv_kblock<utils::bf16, 4, gemv_nbits::S6>(_param, _config);
}
return;
}
if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) {
if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) {
if (m == 1) gemv_kblock<float, 1, gemv_nbits::S3>(_param, _config);
Expand Down
Loading

0 comments on commit 68d2cff

Please sign in to comment.