Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
[BesTLA] First-token inference optimization (#271)
Browse files Browse the repository at this point in the history
* add per-channel kblock template

* add gemv support for pckblock. revise all benchmark cases and UT cases.

* fix bandwidth calc of CompFp32 and CompBf16

* use correct core number

* update thread pool

* fix bug

* fix bug

* update kernels with gemm and qkv fusion

* refactor epilogue (removing ISA from the class' template)

* update fnn and ip_add

* fix compile on gcc

* fix gcc template

* fix compile

* update amx template

* fix UT compile

* fix benchmark compile

* revert NTILE of amx_int8

* reduce templates

* fix deprecated UTs. optimize cache block strategy

* Enlarge stack size on windows

* revert NTILE of amx_int8

* update cache config

* add mul support

* add mul implementation

* support tensor mul tensor

* fix compile on gcc

* clang-format

* fix doc

* code scan fix

* fix compile

* fix batch bug

* comment add

* comment mul

* enable mul&add

* clang-format

* fix the code bug of mul and add. use new kernels in custom::epilogue

* clang-format

---------

Co-authored-by: yuchengliu1 <[email protected]>
  • Loading branch information
luoyu-intel and yuchengliu1 authored May 31, 2024
1 parent af22f2a commit 3757fda
Show file tree
Hide file tree
Showing 26 changed files with 1,515 additions and 2,534 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ endif()

if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX)
add_compile_options(/bigobj)
if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
Expand Down
4 changes: 4 additions & 0 deletions bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ if(UT_BUILD)
target_link_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address)
endif()
target_link_options(${PROJECT_NAME}_ut PRIVATE -lpthread)
else()
target_link_options(${PROJECT_NAME}_ut PUBLIC /STACK:5242880)
endif()

add_ut_flag(BTLA_UT_DEBUG)
Expand Down Expand Up @@ -137,6 +139,8 @@ if(BTLA_UT_BENCHMARK)
endif()
if(NOT WIN32)
target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread)
else()
target_link_options(${PROJECT_NAME}_benchmark PUBLIC /STACK:5242880)
endif()
target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME} ${sycl_libs})
endif(BTLA_UT_BENCHMARK)
3 changes: 2 additions & 1 deletion bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
#include "bestla_utils.h"
#ifdef _WIN32
#include <windows.h>
#define FIXED_CACHE 1
#else
#include <sched.h>
#define FIXED_CACHE 0
#endif

#define FIXED_CACHE_SIZE ((1 << 20) - (128 << 10))
#define FIXED_CACHE 1

namespace bestla {

Expand Down
209 changes: 101 additions & 108 deletions bestla/bestla/bestla_epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,102 @@ namespace bestla {
namespace epilogue {
namespace gemm {

// Quantization metadata consumed by PcKBlockCompInt8Epilogue to dequantize an
// int32 per-channel-kblock GEMM accumulator back to fp32.
struct ParamPcKBlockCompInt8Epilogue {
  void* scalesB;           // per-N weight scales; element type is given by scaleBdtype
  BTLA_DTYPE scaleBdtype;  // type tag for scalesB (the epilogue handles BF16 and F32)
  float* scalesA;          // per-M activation scales, indexed by M_offset
  // optional if A asym
  uint8_t* zpA = nullptr;        // activation zero points (per M); nullptr when A is symmetric
  void* reduceB = nullptr;       // per-N reduction of B; element type is given by reduceBdtype
  BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32;  // type tag for reduceB (BF16 or F32 handled)
  // optional if B asym
  int8_t* zpB = nullptr;    // weight zero points (per N); nullptr when B is symmetric
  float* reduceA = nullptr; // per-M reduction of A; used when B is asymmetric
  int K = 1;                // inner-dimension length, passed to the both-asym bias removal
};
// Epilogue for per-channel kblock int8 GEMM: dequantizes the int32 accumulator
// tile to fp32 *in place* (the int32 buffer is reinterpreted as float), removes
// zero-point bias for asymmetric A and/or B, then chains into the wrapped
// Fp32Epilogue (e.g. AccumulatorWriteBack) for the final store.
template <class Fp32Epilogue>
class PcKBlockCompInt8Epilogue {
 public:
  using Fp32Param = typename Fp32Epilogue::Param;
  // Combined parameter: param1 drives dequantization, param2 is forwarded
  // verbatim to the wrapped fp32 epilogue.
  struct Param {
    ParamPcKBlockCompInt8Epilogue param1;
    Fp32Param param2;
  };
  using Fp32Epi = Fp32Epilogue;
  // srcptr:    int32 accumulator tile (cachestep elements per row); overwritten
  //            with the dequantized fp32 result.
  // M_offset/N_offset: tile origin inside the full output, used to index the
  //            per-row/per-column quantization arrays.
  // tmpcache/cachesize: scratch; the first 2*N*sizeof(float) bytes may be used
  //            for bf16->fp32 conversion of scalesB and reduceB.
  template <BTLA_ISA ISA_T>
  static BTLA_CODE forward(const int32_t* srcptr, const int cachestep, const int M_offset, const int N_offset,
                           const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
    BTLA_CODE ret = BTLA_CODE::NotSupport;
    float* scab = nullptr;
    size_t ScaleBTmpSize = N * sizeof(float);
    size_t ReduceBTmpSize = N * sizeof(float);
    // scratch must hold fp32 copies of both the scale and reduce rows
    assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize));
    auto& param1 = _param.param1;
    if (param1.scaleBdtype == BTLA_DTYPE::BF16) {
      // widen the bf16 weight scales into the front of tmpcache
      auto scache = reinterpret_cast<float*>(tmpcache);
      ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
          reinterpret_cast<utils::bf16*>(param1.scalesB) + N_offset, scache, 1, N, N, N, false);
      assert(ret == BTLA_CODE::Success);
      scab = scache;
    } else if (param1.scaleBdtype == BTLA_DTYPE::F32) {
      scab = reinterpret_cast<float*>(param1.scalesB) + N_offset;
    }
    // NOTE(review): any other scaleBdtype leaves scab == nullptr and it is
    // passed to DequanS32Fp32 below — presumably unreachable; confirm.
    float* redb = nullptr;
    if (param1.reduceB) {
      if (param1.reduceBdtype == BTLA_DTYPE::BF16) {
        // widen bf16 reduceB into tmpcache just past the scale scratch
        auto rcache = reinterpret_cast<float*>(reinterpret_cast<char*>(tmpcache) + ScaleBTmpSize);
        ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
            reinterpret_cast<utils::bf16*>(param1.reduceB) + N_offset, rcache, 1, N, N, N, false);
        assert(ret == BTLA_CODE::Success);
        redb = rcache;
      } else if (param1.reduceBdtype == BTLA_DTYPE::F32) {
        redb = reinterpret_cast<float*>(param1.reduceB) + N_offset;
      }
    }
    // Dequantize in place: the int32 tile is reused as the fp32 destination.
    auto tmpfp32ptr = reinterpret_cast<float*>(const_cast<int32_t*>(srcptr));
    ret = kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(srcptr, cachestep, tmpfp32ptr, cachestep, M, N,
                                                                  param1.scalesA + M_offset, 1, scab);
    assert(ret == BTLA_CODE::Success);

    // Zero-point bias removal: four cases depending on which of A/B is
    // asymmetric (both symmetric needs no correction).
    if (param1.zpA == nullptr) {
      if (param1.zpB == nullptr) {
      } else {
        ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei<ISA_T>(
            tmpfp32ptr, cachestep, M, N, param1.zpB + N_offset, scab, 1, param1.reduceA + M_offset);
      }
    } else {
      if (param1.zpB == nullptr) {
        ret = kernel::wrapper::RemoveZeroPointBias::template forward_act<ISA_T>(
            tmpfp32ptr, cachestep, M, N, param1.zpA + M_offset, param1.scalesA + M_offset, 1, redb);
      } else {
        ret = kernel::wrapper::RemoveZeroPointBias::template forward_both<ISA_T>(
            tmpfp32ptr, cachestep, M, N, param1.zpA + M_offset, param1.zpB + N_offset, param1.scalesA + M_offset, scab,
            1, param1.K, param1.reduceA + M_offset, redb);
      }
    }
    // Hand the dequantized fp32 tile to the wrapped epilogue for the final
    // write-back. NOTE(review): its return code is discarded, and it receives
    // the same tmpcache whose front may still hold the converted scales —
    // confirm the wrapped epilogue does not rely on that region.
    Fp32Epilogue::template forward<ISA_T>(tmpfp32ptr, cachestep, M_offset, N_offset, M, N, _param.param2, tmpcache,
                                          cachesize);

    return ret;
  }
};

// Parameters for the accumulator write-back epilogues.
template <typename DT>
struct ParamAccumulatorWriteBack {
  DT* C;             // destination matrix to write the accumulator tile into
  int ldc;           // leading dimension (row stride, in elements) of C
  void* elt_const_v; // opaque constants for elementwise-op variants — presumably
                     // consumed by CustomAccumulatorWriteBackWithEltop; confirm at call sites
};

template <BTLA_ISA ISA_T, typename _SRC_T, typename _DST_T>
template <typename _SRC_T, typename _DST_T>
class AccumulatorWriteBack {
public:
using SType = _SRC_T;
using DType = _DST_T;
using Param = ParamAccumulatorWriteBack<DType>;
using PcCompInt8Epi = bestla::epilogue::gemm::PcKBlockCompInt8Epilogue<AccumulatorWriteBack<_SRC_T, _DST_T>>;

template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand All @@ -52,10 +134,13 @@ class AccumulatorWriteBack {
}
};

template <BTLA_ISA ISA_T, typename _SRC_T, typename _DST_T, BTLA_ELTWISEOP _OP>
template <typename _SRC_T, typename _DST_T, BTLA_ELTWISEOP _OP>
class CustomAccumulatorWriteBackWithEltop {
public:
using PcCompInt8Epi =
bestla::epilogue::gemm::PcKBlockCompInt8Epilogue<CustomAccumulatorWriteBackWithEltop<_SRC_T, _DST_T, _OP>>;
using Param = ParamAccumulatorWriteBack<_DST_T>;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand All @@ -68,39 +153,29 @@ class CustomAccumulatorWriteBackWithEltop {
}
}
};
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp32 = AccumulatorWriteBack<ISA_T, float, float>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackInt32 = AccumulatorWriteBack<ISA_T, int, int>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackBf16 = AccumulatorWriteBack<ISA_T, utils::bf16, utils::bf16>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp16 = AccumulatorWriteBack<ISA_T, utils::fp16, utils::fp16>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack<ISA_T, utils::bf16, float>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<ISA_T, utils::fp16, float>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack<ISA_T, float, utils::bf16>;
using AccumulatorWriteBackFp32 = AccumulatorWriteBack<float, float>;
using AccumulatorWriteBackInt32 = AccumulatorWriteBack<int, int>;
using AccumulatorWriteBackBf16 = AccumulatorWriteBack<utils::bf16, utils::bf16>;
using AccumulatorWriteBackFp16 = AccumulatorWriteBack<utils::fp16, utils::fp16>;
using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack<utils::bf16, float>;
using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<utils::fp16, float>;
using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack<float, utils::bf16>;

template <BTLA_ISA ISA_T>
using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop<ISA_T, float, float, BTLA_ELTWISEOP::GELU>;
using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop<float, float, BTLA_ELTWISEOP::GELU>;

template <BTLA_ISA ISA_T>
using AccumulatorWriteBackWithSwishFp32 =
CustomAccumulatorWriteBackWithEltop<ISA_T, float, float, BTLA_ELTWISEOP::SWISH>;
using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop<float, float, BTLA_ELTWISEOP::SWISH>;

template <typename DT>
struct ParamAlphaBetaProcess {
DT *C, *D;
int ldc, ldd;
float alpha, beta;
};
template <BTLA_ISA ISA_T>
class AlphaBetaProcessFp32 {
public:
using Param = ParamAlphaBetaProcess<float>;

template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto DOffset = M_offset * _param.ldd + N_offset;
Expand All @@ -120,10 +195,10 @@ struct ParamCompFp32BlockEpilogue {
float* reduce = nullptr;
int ldra;
};
template <BTLA_ISA ISA_T>
class CompFp32BlockEpilogue {
public:
using Param = ParamCompFp32BlockEpilogue;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset,
const int N_offset, const int K_offset, const int M, const int N, const Param& _param,
void* tmpcache, size_t cachesize) {
Expand Down Expand Up @@ -171,10 +246,10 @@ struct ParamDequantInt32ToFp32 {
float* scalesA;
float* scalesB;
};
template <BTLA_ISA ISA_T>
class DequantInt32ToFp32 {
public:
using Param = ParamDequantInt32ToFp32;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand All @@ -185,88 +260,6 @@ class DequantInt32ToFp32 {
}
};

struct ParamCompInt8BlockEpilogue {
void* scalesB;
BTLA_DTYPE scaleBdtype;
int ldsb;
float* scalesA;
int ldsa;
// optional if A asym
uint8_t* zpA = nullptr;
void* reduceB = nullptr;
BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32;
// optional if B asym
int8_t* zpB = nullptr;
float* reduceA = nullptr;
int K = 1;
};
template <BTLA_ISA ISA_T>
class CompInt8BlockEpilogue {
public:
using Param = ParamCompInt8BlockEpilogue;
static BTLA_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset,
const int N_offset, const int K_offset, const int M, const int N, const Param& _param,
void* tmpcache, size_t cachesize) {
BTLA_CODE ret = BTLA_CODE::NotSupport;
float* scab = nullptr;
size_t ScaleBTmpSize = N * sizeof(float);
size_t ReduceBTmpSize = N * sizeof(float);
assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize));
if (_param.scaleBdtype == BTLA_DTYPE::BF16) {
auto scache = reinterpret_cast<float*>(tmpcache);
ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
reinterpret_cast<utils::bf16*>(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N,
false);
assert(ret == BTLA_CODE::Success);
scab = scache;
} else if (_param.scaleBdtype == BTLA_DTYPE::F32) {
scab = reinterpret_cast<float*>(_param.scalesB) + N_offset + K_offset * _param.ldsb;
}
float* redb = nullptr;
if (_param.reduceB) {
if (_param.reduceBdtype == BTLA_DTYPE::BF16) {
auto rcache = reinterpret_cast<float*>(reinterpret_cast<char*>(tmpcache) + ScaleBTmpSize);
ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
reinterpret_cast<utils::bf16*>(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N,
false);
assert(ret == BTLA_CODE::Success);
redb = rcache;
} else if (_param.reduceBdtype == BTLA_DTYPE::F32) {
redb = reinterpret_cast<float*>(_param.reduceB) + N_offset + K_offset * _param.ldsb;
}
}
ret = kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(
srcptr, cachestep, reinterpret_cast<float*>(const_cast<int32_t*>(srcptr)), cachestep, M, N,
_param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab);
assert(ret == BTLA_CODE::Success);
ret = kernel::wrapper::AccumulateFp32::template forward<ISA_T>(reinterpret_cast<const float*>(srcptr), cachestep,
dstptr, cachestep, M, N);
assert(ret == BTLA_CODE::Success);

if (_param.zpA == nullptr) {
if (_param.zpB == nullptr) {
return ret;
} else {
ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei<ISA_T>(
dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa,
_param.reduceA + M_offset * _param.ldsa + K_offset);
}
} else {
if (_param.zpB == nullptr) {
ret = kernel::wrapper::RemoveZeroPointBias::template forward_act<ISA_T>(
dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
_param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb);
} else {
ret = kernel::wrapper::RemoveZeroPointBias::template forward_both<ISA_T>(
dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
_param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab,
_param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb);
}
}
return ret;
}
};

struct ParamZpDequantInt32ToFp32 {
// necessary
float* C;
Expand All @@ -282,10 +275,10 @@ struct ParamZpDequantInt32ToFp32 {
float* reduceA = nullptr;
int K = 1;
};
template <BTLA_ISA ISA_T>
class ZpDequantInt32ToFp32 {
public:
using Param = ParamZpDequantInt32ToFp32;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand Down Expand Up @@ -323,10 +316,10 @@ struct ParamAlphaBetaProcessS32U8 {
float scaleAcc, scaleC;
int zpC;
};
template <BTLA_ISA ISA_T>
class AlphaBetaProcessS32U8 {
public:
using Param = ParamAlphaBetaProcessS32U8;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/bestla_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -4816,7 +4816,7 @@ class CoreCodeBase {
static auto constexpr KTILE = Code::KTILE;
static auto constexpr PACK_ROW = Code::PackRow;
static auto constexpr COMP = Code::COMPUTE;
static int constexpr PREFERRED_N = NTILE * 3;
static int constexpr PREFERRED_N = NTILE * 4;
static auto constexpr ISA = Code::ISA;
static auto constexpr ID = CoreAttr::make_core_id(NTILE, PACK_ROW, COMP, ISA);
void configure(int _M, int _N, int _K) { (void)(0); }
Expand All @@ -4842,7 +4842,7 @@ class CoreCodeBaseAMX {
static auto constexpr KTILE = Code::KTILE;
static auto constexpr PACK_ROW = Code::PackRow;
static auto constexpr COMP = Code::COMPUTE;
static int constexpr PREFERRED_N = NTILE * 3;
static int constexpr PREFERRED_N = NTILE * 4;
static auto constexpr ISA = Code::ISA;
static auto constexpr ID = CoreAttr::make_core_id(_NTILE, PACK_ROW, COMP, ISA);
Xbyak::CodeGenerator cfgcode;
Expand Down
Loading

0 comments on commit 3757fda

Please sign in to comment.