diff --git a/CMakeLists.txt b/CMakeLists.txt index c341d83c7..3ace8bf36 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,7 @@ endif() if (MSVC) add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX) + add_compile_options(/bigobj) if (BUILD_SHARED_LIBS) set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) endif() diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 9c44e2fcc..f6d58049a 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -108,6 +108,8 @@ if(UT_BUILD) target_link_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address) endif() target_link_options(${PROJECT_NAME}_ut PRIVATE -lpthread) + else() + target_link_options(${PROJECT_NAME}_ut PUBLIC /STACK:5242880) endif() add_ut_flag(BTLA_UT_DEBUG) @@ -137,6 +139,8 @@ if(BTLA_UT_BENCHMARK) endif() if(NOT WIN32) target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread) + else() + target_link_options(${PROJECT_NAME}_benchmark PUBLIC /STACK:5242880) endif() target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME} ${sycl_libs}) endif(BTLA_UT_BENCHMARK) diff --git a/bestla/bestla/bestla_device.h b/bestla/bestla/bestla_device.h index aaa1c3b28..d7c1f2fbb 100644 --- a/bestla/bestla/bestla_device.h +++ b/bestla/bestla/bestla_device.h @@ -20,12 +20,13 @@ #include "bestla_utils.h" #ifdef _WIN32 #include +#define FIXED_CACHE 1 #else #include +#define FIXED_CACHE 0 #endif #define FIXED_CACHE_SIZE ((1 << 20) - (128 << 10)) -#define FIXED_CACHE 1 namespace bestla { diff --git a/bestla/bestla/bestla_epilogue.h b/bestla/bestla/bestla_epilogue.h index 3360688f5..11fa5db99 100644 --- a/bestla/bestla/bestla_epilogue.h +++ b/bestla/bestla/bestla_epilogue.h @@ -23,6 +23,86 @@ namespace bestla { namespace epilogue { namespace gemm { +struct ParamPcKBlockCompInt8Epilogue { + void* scalesB; + BTLA_DTYPE scaleBdtype; + float* scalesA; + // optional if A asym + uint8_t* zpA = nullptr; + void* reduceB = nullptr; + BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32; + // optional if B asym + 
int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; +}; +template +class PcKBlockCompInt8Epilogue { + public: + using Fp32Param = typename Fp32Epilogue::Param; + struct Param { + ParamPcKBlockCompInt8Epilogue param1; + Fp32Param param2; + }; + using Fp32Epi = Fp32Epilogue; + template + static BTLA_CODE forward(const int32_t* srcptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; + float* scab = nullptr; + size_t ScaleBTmpSize = N * sizeof(float); + size_t ReduceBTmpSize = N * sizeof(float); + assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); + auto& param1 = _param.param1; + if (param1.scaleBdtype == BTLA_DTYPE::BF16) { + auto scache = reinterpret_cast(tmpcache); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(param1.scalesB) + N_offset, scache, 1, N, N, N, false); + assert(ret == BTLA_CODE::Success); + scab = scache; + } else if (param1.scaleBdtype == BTLA_DTYPE::F32) { + scab = reinterpret_cast(param1.scalesB) + N_offset; + } + float* redb = nullptr; + if (param1.reduceB) { + if (param1.reduceBdtype == BTLA_DTYPE::BF16) { + auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(param1.reduceB) + N_offset, rcache, 1, N, N, N, false); + assert(ret == BTLA_CODE::Success); + redb = rcache; + } else if (param1.reduceBdtype == BTLA_DTYPE::F32) { + redb = reinterpret_cast(param1.reduceB) + N_offset; + } + } + auto tmpfp32ptr = reinterpret_cast(const_cast(srcptr)); + ret = kernel::wrapper::DequanS32Fp32::template forward(srcptr, cachestep, tmpfp32ptr, cachestep, M, N, + param1.scalesA + M_offset, 1, scab); + assert(ret == BTLA_CODE::Success); + + if (param1.zpA == nullptr) { + if (param1.zpB == nullptr) { + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template 
forward_wei( + tmpfp32ptr, cachestep, M, N, param1.zpB + N_offset, scab, 1, param1.reduceA + M_offset); + } + } else { + if (param1.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + tmpfp32ptr, cachestep, M, N, param1.zpA + M_offset, param1.scalesA + M_offset, 1, redb); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + tmpfp32ptr, cachestep, M, N, param1.zpA + M_offset, param1.zpB + N_offset, param1.scalesA + M_offset, scab, + 1, param1.K, param1.reduceA + M_offset, redb); + } + } + Fp32Epilogue::template forward(tmpfp32ptr, cachestep, M_offset, N_offset, M, N, _param.param2, tmpcache, + cachesize); + + return ret; + } +}; + template struct ParamAccumulatorWriteBack { DT* C; @@ -30,13 +110,15 @@ struct ParamAccumulatorWriteBack { void* elt_const_v; }; -template +template class AccumulatorWriteBack { public: using SType = _SRC_T; using DType = _DST_T; using Param = ParamAccumulatorWriteBack; + using PcCompInt8Epi = bestla::epilogue::gemm::PcKBlockCompInt8Epilogue>; + template static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; @@ -52,10 +134,13 @@ class AccumulatorWriteBack { } }; -template +template class CustomAccumulatorWriteBackWithEltop { public: + using PcCompInt8Epi = + bestla::epilogue::gemm::PcKBlockCompInt8Epilogue>; using Param = ParamAccumulatorWriteBack<_DST_T>; + template static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; @@ -68,27 +153,17 @@ class CustomAccumulatorWriteBackWithEltop { } } }; -template -using AccumulatorWriteBackFp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackInt32 = 
AccumulatorWriteBack; -template -using AccumulatorWriteBackBf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; +using AccumulatorWriteBackFp32 = AccumulatorWriteBack; +using AccumulatorWriteBackInt32 = AccumulatorWriteBack; +using AccumulatorWriteBackBf16 = AccumulatorWriteBack; +using AccumulatorWriteBackFp16 = AccumulatorWriteBack; +using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack; +using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; +using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; +using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; -template -using AccumulatorWriteBackWithSwishFp32 = - CustomAccumulatorWriteBackWithEltop; +using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; template struct ParamAlphaBetaProcess { @@ -96,11 +171,11 @@ struct ParamAlphaBetaProcess { int ldc, ldd; float alpha, beta; }; -template class AlphaBetaProcessFp32 { public: using Param = ParamAlphaBetaProcess; + template static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto DOffset = M_offset * _param.ldd + N_offset; @@ -120,10 +195,10 @@ struct ParamCompFp32BlockEpilogue { float* reduce = nullptr; int ldra; }; -template class CompFp32BlockEpilogue { public: using Param = ParamCompFp32BlockEpilogue; + template static BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { @@ 
-171,10 +246,10 @@ struct ParamDequantInt32ToFp32 { float* scalesA; float* scalesB; }; -template class DequantInt32ToFp32 { public: using Param = ParamDequantInt32ToFp32; + template static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; @@ -185,88 +260,6 @@ class DequantInt32ToFp32 { } }; -struct ParamCompInt8BlockEpilogue { - void* scalesB; - BTLA_DTYPE scaleBdtype; - int ldsb; - float* scalesA; - int ldsa; - // optional if A asym - uint8_t* zpA = nullptr; - void* reduceB = nullptr; - BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; -}; -template -class CompInt8BlockEpilogue { - public: - using Param = ParamCompInt8BlockEpilogue; - static BTLA_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, - const int N_offset, const int K_offset, const int M, const int N, const Param& _param, - void* tmpcache, size_t cachesize) { - BTLA_CODE ret = BTLA_CODE::NotSupport; - float* scab = nullptr; - size_t ScaleBTmpSize = N * sizeof(float); - size_t ReduceBTmpSize = N * sizeof(float); - assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); - if (_param.scaleBdtype == BTLA_DTYPE::BF16) { - auto scache = reinterpret_cast(tmpcache); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, - false); - assert(ret == BTLA_CODE::Success); - scab = scache; - } else if (_param.scaleBdtype == BTLA_DTYPE::F32) { - scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; - } - float* redb = nullptr; - if (_param.reduceB) { - if (_param.reduceBdtype == BTLA_DTYPE::BF16) { - auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); - ret = 
kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, - false); - assert(ret == BTLA_CODE::Success); - redb = rcache; - } else if (_param.reduceBdtype == BTLA_DTYPE::F32) { - redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; - } - } - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); - assert(ret == BTLA_CODE::Success); - ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, - dstptr, cachestep, M, N); - assert(ret == BTLA_CODE::Success); - - if (_param.zpA == nullptr) { - if (_param.zpB == nullptr) { - return ret; - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa + K_offset); - } - } else { - if (_param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, - _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); - } - } - return ret; - } -}; - struct ParamZpDequantInt32ToFp32 { // necessary float* C; @@ -282,10 +275,10 @@ struct ParamZpDequantInt32ToFp32 { float* reduceA = nullptr; int K = 1; }; -template class ZpDequantInt32ToFp32 { public: using Param = ParamZpDequantInt32ToFp32; + template static BTLA_CODE forward(const 
int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; @@ -323,10 +316,10 @@ struct ParamAlphaBetaProcessS32U8 { float scaleAcc, scaleC; int zpC; }; -template class AlphaBetaProcessS32U8 { public: using Param = ParamAlphaBetaProcessS32U8; + template static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; diff --git a/bestla/bestla/bestla_gemm.h b/bestla/bestla/bestla_gemm.h index 3316127f5..fe521c4ab 100644 --- a/bestla/bestla/bestla_gemm.h +++ b/bestla/bestla/bestla_gemm.h @@ -4816,7 +4816,7 @@ class CoreCodeBase { static auto constexpr KTILE = Code::KTILE; static auto constexpr PACK_ROW = Code::PackRow; static auto constexpr COMP = Code::COMPUTE; - static int constexpr PREFERRED_N = NTILE * 3; + static int constexpr PREFERRED_N = NTILE * 4; static auto constexpr ISA = Code::ISA; static auto constexpr ID = CoreAttr::make_core_id(NTILE, PACK_ROW, COMP, ISA); void configure(int _M, int _N, int _K) { (void)(0); } @@ -4842,7 +4842,7 @@ class CoreCodeBaseAMX { static auto constexpr KTILE = Code::KTILE; static auto constexpr PACK_ROW = Code::PackRow; static auto constexpr COMP = Code::COMPUTE; - static int constexpr PREFERRED_N = NTILE * 3; + static int constexpr PREFERRED_N = NTILE * 4; static auto constexpr ISA = Code::ISA; static auto constexpr ID = CoreAttr::make_core_id(_NTILE, PACK_ROW, COMP, ISA); Xbyak::CodeGenerator cfgcode; diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index b60a81d89..04a310a5e 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -212,12 +212,14 @@ class StdThreading : public IThreading { memcpy(reinterpret_cast(core_order.data() + 
_cd->getPcoreNum() + _cd->getEcoreNum()), reinterpret_cast(_cd->getSMTCores()), _cd->getSMTcoreNum() * sizeof(int)); } else { - core_order.resize(mThreadNum); + core_order.resize(_cd->getCores() * 2); // *2 for SMT if (_cd->isClient()) { - for (int i = 0; i < _cd->getCores(); i++) core_order[i] = 2 * i; - for (int i = _cd->getCores(); i < mThreadNum; i++) core_order[i] = 2 * (i - _cd->getCores()) + 1; + for (int i = 0; i < _cd->getCores(); i++) { + core_order[i] = 2 * i; + core_order[i + _cd->getCores()] = 2 * i + 1; + } } else { - for (int i = 0; i < mThreadNum; i++) core_order[i] = i; + for (int i = 0; i < _cd->getCores() * 2; i++) core_order[i] = i; } } _cd->core_bond(core_order[0]); @@ -483,8 +485,8 @@ class SchedulerBase : public Scheduler2D { update_cache_blocking(); Scheduler2D::set(mThdSize, mSize, mStep); mL2Use = static_cast(mBlock[0]) * mBlock[1] * mEleSize[2]; - mL2Use += static_cast(mBlock[1]) * mBlock[2] * mEleSize[1]; - mL2Use += static_cast(mStep[0]) * mBlock[2] * mEleSize[0]; + mL2Use += static_cast(mBlock[1]) * mBlock[2] * mEleSize[1] * 2; + mL2Use += static_cast(mStep[0]) * mBlock[2] * mEleSize[0] * 2; } static float constexpr DensityThres = 16; static size_t constexpr ReservedSize = 32ULL * 1024ULL; @@ -520,11 +522,11 @@ class SchedulerBase : public Scheduler2D { } virtual void cache_blocking_compute() { - int constexpr KRef = 256; + size_t constexpr KRef = 256; size_t valid_total = mL2Size - ReservedSize; - auto asize = mStep[0] * KRef * mEleSize[0]; - size_t csize_total = valid_total - _GemmCore_T::PREFERRED_N * KRef * mEleSize[1] - asize; - int maxM = static_cast(csize_total / _GemmCore_T::PREFERRED_N / mEleSize[2]); + size_t asize = KRef * mStep[0] * mEleSize[0] * 2; + size_t bsize = _GemmCore_T::PREFERRED_N * KRef * mEleSize[1] * 2; + int maxM = static_cast((valid_total - bsize - asize) / (_GemmCore_T::PREFERRED_N * mEleSize[2])); maxM = utils::downdiv(maxM, mStep[0]); int nthdm = mThdSize[0] / mStep[0]; if (maxM < nthdm) { @@ -533,7 +535,7 
@@ class SchedulerBase : public Scheduler2D { } else { mBlock[0] = mThdSize[0]; } - int maxN = static_cast((valid_total - asize) / (mBlock[0] * mEleSize[2] + KRef * mEleSize[1])); + int maxN = static_cast((valid_total - asize) / (mBlock[0] * mEleSize[2] + KRef * mEleSize[1] * 2)); maxN = utils::downdiv(maxN, mStep[1]); int nthdn = mThdSize[1] / mStep[1]; if (maxN < nthdn) { @@ -542,8 +544,9 @@ class SchedulerBase : public Scheduler2D { } else { mBlock[1] = mThdSize[1]; } - auto rawk = static_cast((valid_total - mBlock[0] * mBlock[1] * mEleSize[2]) / - (mStep[0] * mEleSize[0] + mBlock[1] * mEleSize[1])); + bsize = KRef * mBlock[1] * mEleSize[1] * 2; + size_t csize = static_cast(mBlock[0]) * mBlock[1] * mEleSize[2]; + auto rawk = static_cast((valid_total - csize) / (mStep[0] * mEleSize[0] + mBlock[1] * mEleSize[1]) / 2); rawk = std::min(rawk, mSizePadded[2]); mBlock[2] = utils::padto_le(rawk, mStep[2]); } @@ -570,195 +573,6 @@ class SchedulerBase : public Scheduler2D { int mBlock[3] = {0, 0, 0}; }; -template -class SchedulerKBlock : public Scheduler2D { - // Block[2]: block size of K must be multiplier of mKBlock - // or factor of mKBlock - public: - using ThreadProblem = ThreadProblemBase; - SchedulerKBlock() = default; - SchedulerKBlock(const Config& config) { update(config); } - virtual void getIndex(ThreadProblem& problem) { - problem.stacksize = mL2Size; - problem.tmpcachesize = mL2Size - mL2Use; - problem.block[0] = mBlock[0]; - problem.block[1] = mBlock[1]; - problem.block[2] = mBlock[2]; - Scheduler2D::getIndex(problem); - } - - void update(const Config& config) { - for (size_t i = 0; i < 3; i++) { - mSize[i] = config.problem.dims[i + 1]; - mSizePadded[i] = utils::padto(mSize[i], mStep[i]); - } - mThdCount = config.threads; - mL2Size = config.l2cache; - mL1Size = config.l1cache; - moffset[0] = config.offset[0]; - moffset[1] = config.offset[1]; - mKBlock = config.problem.dims[4]; - if (mSize[0] <= 0 || mSize[1] <= 0 || mSize[2] <= 0) { - return; - } - 
schedule(); - assert(this->mL2Use <= this->mL2Size); - assert(this->mBlock[0] > 0); - assert(this->mBlock[1] > 0); - assert(this->mBlock[2] > 0); - } - - constexpr static BTLA_ISA gemm_ISA() { return _GemmCore_T::ISA; } - - constexpr int valid_theads() { return mThdValid; } - - void print() { - printf("Thread Block:(%d,%d)\n", mThdSize[0], mThdSize[1]); - printf("Thread in use:%d of %d, Nx%d\n", mThdValid, mThdCount, mThdPerRow); - printf("GEMM MStep:%d NStep:%d KStep:%d\n", mBlock[0], mBlock[1], mBlock[2]); - printf("Cache Size:%zu used:%zu\n", mL2Size, mL2Use); - } - - template - friend class SchedulerDispatcher; - - protected: - void schedule() { - int rownum = utils::updiv(mSize[0], mStep[0]); - int colnum = utils::updiv(mSize[1], mStep[1]); - mDensity = static_cast(mSize[0]) * mSize[1] / (mSize[0] + mSize[1]); - int maxN = 0; - float maxScore = std::numeric_limits::min(); - int core_enum = static_cast(std::sqrt(mThdCount)); - for (int i = 1; i <= core_enum; i += 1) { - generate_by_cores(i, mThdCount / i, rownum, colnum); - auto thdscore = calculate_score(); - if (maxScore < thdscore) { - maxScore = thdscore; - maxN = i; - } - generate_by_cores(mThdCount / i, i, rownum, colnum); - thdscore = calculate_score(); - if (maxScore < thdscore) { - maxScore = thdscore; - maxN = mThdCount / i; - } - } - generate_by_cores(maxN, mThdCount / maxN, rownum, colnum); - update_cache_blocking(); - Scheduler2D::set(mThdSize, mSize, mStep); - mL2Use = static_cast(mBlock[0]) * mBlock[1] * mEleSize[2] * 2; - mL2Use += static_cast(mBlock[1]) * mBlock[2] * mEleSize[1]; - mL2Use += static_cast(mStep[0]) * mBlock[2] * mEleSize[0]; - } - static float constexpr DensityThres = 16; - - float calculate_score() { - int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? 
mThdSize[1] : _GemmCore_T::PREFERRED_N; - float threadratio = static_cast(mThdValid) / mThdCount; - float density = static_cast(tmpnstep) * mThdSize[0] / (tmpnstep + mThdSize[0]); - if (mDensity < DensityThres) { - return threadratio * 1.f; - } - return (threadratio * 1.f + density * 0.0016f); - } - - void generate_by_cores(int ny, int nx, int rownum, int colnum) { - mThdSize[0] = utils::updiv(rownum, ny) * mStep[0]; - mThdSize[1] = utils::updiv(colnum, nx) * mStep[1]; - mThdPerRow = utils::updiv(mSize[1], mThdSize[1]); - mThdValid = utils::updiv(mSize[0], mThdSize[0]) * mThdPerRow; - } - - // C-KBlock Accumulator=MBlock*NBlock - // C-K Accumulator=MBlock*NBlock - // B=MBlock*KBlock - // A=MTILE*KBlock - void update_cache_blocking() { - if (mDensity <= DensityThres) { - return cache_blocking_memory(); - } else { - return cache_blocking_compute(); - } - } - - void cache_blocking_compute() { - int constexpr KRef = 256; - int constexpr NRef = _GemmCore_T::PREFERRED_N; - int constexpr MTile = _GemmCore_T::MTILE; - int constexpr KSplitStage = 16; - int BlkNum = utils::updiv(mSize[2], mKBlock); - int KSplitSize = utils::padto(utils::updiv(mSize[2], KSplitStage), mStep[2]); - mBlock[1] = NRef < mThdSize[1] ? 
NRef : mThdSize[1]; - if (KSplitStage * mStep[2] >= mSize[2]) { - mBlock[2] = mSize[2]; - } else if (KSplitSize >= mKBlock) { - mBlock[2] = mKBlock; - } else { - int scale = utils::downdiv(KSplitStage, BlkNum); - for (; scale >= 1; scale--) { - if (mKBlock % scale == 0) { - break; - } - } - mBlock[2] = utils::downdiv(mKBlock, scale); - mBlock[2] = utils::padto_le(mBlock[2], mStep[2]); - } - size_t size_remain = mL2Size - mBlock[1] * mBlock[2] * mEleSize[1]; - // MBlock*KBlock*ASize+MBlock*NBlock*CSize*2<=size_remain - int maxMBlock = static_cast(size_remain / (mBlock[1] * mEleSize[2] * 2 + mBlock[2] * mEleSize[0])); - int maxM = utils::downdiv(maxMBlock, mStep[0]); - int nthdm = mThdSize[0] / mStep[0]; - if (maxM < nthdm) { - int niter = utils::updiv(nthdm, maxM); - mBlock[0] = utils::updiv(nthdm, niter) * mStep[0]; - } else { - mBlock[0] = mThdSize[0]; - } - } - - void cache_blocking_memory() { - mBlock[0] = _GemmCore_T::MTILE; - size_t startK = std::max(16, _GemmCore_T::KTILE); - auto getMaxN = [&](size_t refk) { - size_t sizeA = refk * mEleSize[0] * mBlock[0]; - size_t maxN = (mL1Size - sizeA) / (mBlock[0] * mEleSize[2] * 2 + refk * mEleSize[1]); - return maxN; - }; - auto getMaxK = [&](size_t refN) { - size_t sizeC = refN * mEleSize[2] * mBlock[0] * 2; - size_t maxK = (mL1Size - sizeC) / (mBlock[0] * mEleSize[0] + refN * mEleSize[1]); - return maxK; - }; - auto maxN = getMaxN(startK); - if (maxN <= mThdSize[1]) { - mBlock[1] = static_cast(maxN); - mBlock[1] = utils::padto_le(mBlock[1], mStep[1]); - mBlock[2] = static_cast(startK); - } else { - mBlock[1] = mThdSize[1]; - mBlock[2] = static_cast(getMaxK(mBlock[1])); - mBlock[2] = utils::padto_le(mBlock[2], mStep[2]); - mBlock[2] = std::min(mKBlock, mBlock[2]); - auto tmp = utils::updiv(mKBlock, mBlock[2]); - while (mKBlock % tmp != 0) tmp++; // TODO(Yu) optimize - mBlock[2] = utils::downdiv(mKBlock, tmp); - } - } - size_t mL2Size = 0, mL1Size = 0, mL2Use = 0; - float mDensity = 0.f; - int mKBlock = 0; - - 
private: - int mSize[3] = {0, 0, 0}; - int mThdSize[3] = {0, 0, 0}; - static constexpr int mStep[3] = {_GemmCore_T::MTILE, _GemmCore_T::NTILE, _GemmCore_T::KTILE}; - static constexpr int mEleSize[3] = {sizeof(typename _GemmCore_T::AType), sizeof(typename _GemmCore_T::BType), - sizeof(typename _GemmCore_T::CType)}; - int mSizePadded[3] = {0, 0, 0}; - int mBlock[3] = {0, 0, 0}; -}; - template class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { // Block[2]: block size of K must be multiplier of mKBlock @@ -792,12 +606,14 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { static size_t constexpr ReservedSize = 32ULL * 1024ULL; void cache_blocking_compute() override { - int constexpr KRef = 256; - int constexpr CorSize = sizeof(float) + sizeof(int8_t) + sizeof(float); + size_t constexpr KRef = 256; + size_t constexpr CorSize = sizeof(float) + sizeof(int8_t) + sizeof(float); size_t valid_total = this->mL2Size - ReservedSize; auto blks = utils::updiv(KRef, this->mKBlock); - auto asize = this->mStep[0] * KRef * this->mEleSize[0] + this->mStep[0] * blks * CorSize; + auto asize = KRef * this->mStep[0] * this->mEleSize[0] + blks * this->mStep[0] * CorSize; + asize *= 2; auto bsize = _GemmCore_T::PREFERRED_N * KRef * this->mEleSize[1] + _GemmCore_T::PREFERRED_N * blks * CorSize; + bsize *= 2; size_t csize_total = valid_total - asize - bsize; int maxM = static_cast(csize_total / _GemmCore_T::PREFERRED_N / this->mEleSize[2]); maxM = utils::downdiv(maxM, this->mStep[0]); @@ -808,8 +624,8 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { } else { this->mBlock[0] = this->mThdSize[0]; } - int maxN = static_cast((valid_total - asize) / - (this->mBlock[0] * this->mEleSize[2] + KRef * this->mEleSize[1] + blks * CorSize)); + int maxN = static_cast((valid_total - asize) / (this->mBlock[0] * this->mEleSize[2] + + (KRef * this->mEleSize[1] + blks * CorSize) * 2)); maxN = utils::downdiv(maxN, this->mStep[1]); int nthdn = this->mThdSize[1] / 
this->mStep[1]; if (maxN < nthdn) { @@ -818,13 +634,13 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { } else { this->mBlock[1] = this->mThdSize[1]; } - auto rawk = static_cast((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2]) / + size_t csize = static_cast(this->mBlock[0]) * this->mBlock[1] * this->mEleSize[2]; + auto rawk = static_cast((valid_total - csize) / 2 / (this->mStep[0] * this->mEleSize[0] + float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock + this->mBlock[1] * this->mEleSize[1])); if (rawk < this->mKBlock) { - rawk = static_cast((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] - - 1 * CorSize * (this->mStep[0] + this->mBlock[1])) / + rawk = static_cast((valid_total - csize - 1 * CorSize * (this->mStep[0] + this->mBlock[1])) / 2 / (this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1])); } rawk = std::min(rawk, this->mSizePadded[2]); diff --git a/bestla/bestla/bestla_wrapper.h b/bestla/bestla/bestla_wrapper.h index af2c93675..00c7e8d40 100644 --- a/bestla/bestla/bestla_wrapper.h +++ b/bestla/bestla/bestla_wrapper.h @@ -200,18 +200,73 @@ class S1 { } }; +class NBitsHelper { + public: + template + static inline utils::GemvParamB createB(storage::gemm::StorageWeightKBlockNInteger* packedW) { + if (packedW->mDType == BTLA_DTYPE::S4_CLIP) { + return S4::createB(packedW); + } + if (packedW->mDType == BTLA_DTYPE::S3_CLIP) { + return S3::createB(packedW); + } + if (packedW->mDType == BTLA_DTYPE::S5_CLIP) { + return S5::createB(packedW); + } + if (packedW->mDType == BTLA_DTYPE::S2_CLIP) { + return S2::createB(packedW); + } + if (packedW->mDType == BTLA_DTYPE::S6_CLIP) { + return S6::createB(packedW); + } + if (packedW->mDType == BTLA_DTYPE::S7_CLIP) { + return S7::createB(packedW); + } + if (packedW->mDType == BTLA_DTYPE::S1_CLIP) { + return S1::createB(packedW); + } + assert(0); + return utils::GemvParamB(); + } + template + static void updateBNStep(utils::GemvParamB& 
paramB, int n_offset) { + if (paramB.nbits == 4) { + return S4::updateBNStep(paramB, n_offset); + } + if (paramB.nbits == 3) { + return S3::updateBNStep(paramB, n_offset); + } + if (paramB.nbits == 5) { + return S5::updateBNStep(paramB, n_offset); + } + if (paramB.nbits == 2) { + return S2::updateBNStep(paramB, n_offset); + } + if (paramB.nbits == 6) { + return S6::updateBNStep(paramB, n_offset); + } + if (paramB.nbits == 7) { + return S7::updateBNStep(paramB, n_offset); + } + if (paramB.nbits == 1) { + return S1::updateBNStep(paramB, n_offset); + } + assert(0); + } +}; + } // namespace gemv_nbits namespace gemm { template class _PrologueA_T, - template class _PrologueB_T, template class _Epilogue_T> + template class _PrologueB_T, class _Epilogue_T> class LauncherBase { public: using GemmCore = _GemmCore_T; static constexpr BTLA_ISA ISA = _RT_ISA_T; using PrologueA = _PrologueA_T; using PrologueB = _PrologueB_T; - using Epilogue = _Epilogue_T<_RT_ISA_T>; + using Epilogue = _Epilogue_T; using AType = typename GemmCore::AType; using AParam = typename PrologueA::Param; using BType = typename GemmCore::BType; @@ -228,7 +283,6 @@ class LauncherBase { _GemmCore_T mGemmCore; PrologueA mProA; PrologueB mProB; - Epilogue mEpilogue; class GEMVWrapper { public: @@ -239,22 +293,53 @@ class LauncherBase { if constexpr (!std::is_same_v> && !std::is_same_v> && + !std::is_same_v> && + !std::is_same_v> && !std::is_same_v>) { return false; } if constexpr (GemmCore::ISA == BTLA_ISA::AVX2) { #if CompileAVX2() - static_assert(GemmCore::PACK_ROW == 1); if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_FP32) { + static_assert(GemmCore::PACK_ROW == 1); + return true; + } + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_INT32) { + static_assert(GemmCore::PACK_ROW == 4); + return true; + } +#endif + } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX512_VNNI || GemmCore::ISA == BTLA_ISA::AMX_INT8) { +#if CompileAVX512VNNI() + if constexpr (GemmCore::COMP == 
bestla::gemm::CompType::COMP_INT8_US_INT32) { + static_assert(GemmCore::PACK_ROW == 4); + return true; + } +#endif + } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX_VNNI) { +#if CompileAVXVNNI() + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_INT32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } #endif } if constexpr (GemmCore::ISA == BTLA_ISA::AVX512F) { #if CompileAVX512F() - static_assert(GemmCore::PACK_ROW == 1); if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_FP32) { + static_assert(GemmCore::PACK_ROW == 1); + return true; + } +#endif + } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX512BW) { +#if CompileAVX512F() + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_INT32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } #endif @@ -281,46 +366,69 @@ class LauncherBase { return impl; } - template + template static void gemv_kblock(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { if constexpr (support()) { auto constexpr TmpSize = 16 * 1024LL; auto constexpr CSize = 8 * 1024LL; auto StackTmp_ = alloca(TmpSize + CSize); auto StackTmp = utils::cpu_pointer_align(StackTmp_); - auto tmpc_ptr = reinterpret_cast((char*)StackTmp + TmpSize); + auto tmpc_ptr = reinterpret_cast((char*)StackTmp + TmpSize); static_assert(CSize >= (MTILE * GemmCore::NTILE * sizeof(float))); - utils::GemvParamB paramB = SNbits::template createB(_param.paramB.packedW); - const float* Aptr = _param.paramA.A; - if constexpr (std::is_same_v>) { - if (_param.paramA.reordered && _param.paramA.reordered->template APtr()) { - Aptr = _param.paramA.reordered->template APtr(); - } - } + utils::GemvParamB paramB = gemv_nbits::NBitsHelper::template createB(_param.paramB.packedW); int m = _param.problem.dims[1]; int n = _param.problem.dims[2]; int k = _param.problem.dims[3]; int kblocksize = _param.problem.dims[4]; - SNbits::template updateBNStep(paramB, _config.loc[1]); + gemv_nbits::NBitsHelper::template 
updateBNStep(paramB, _config.loc[1]); int size_padded = utils::padto_le(_config.size[1], GemmCore::NTILE); int in = 0; for (; in < size_padded; in += GemmCore::NTILE) { - if constexpr (std::is_same_v) { + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_INT32) { + utils::GemvParamA paramA{ + _param.paramA.quan->template APtr(), _param.paramA.quan->template SPtr(), + _param.paramA.quan->template ZPtr(), _param.paramA.quan->mKPad, _param.paramA.quan->CStep()}; + kernel::wrapper::GEMVWoqNBits::forward_u8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); + Epilogue::Fp32Epi::template forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, + GemmCore::NTILE, _param.paramC.param2, StackTmp, TmpSize); + } else { + const float* Aptr = _param.paramA.A; + if constexpr (std::is_same_v>) { + if (_param.paramA.reordered && _param.paramA.reordered->template APtr()) { + Aptr = _param.paramA.reordered->template APtr(); + } + } kernel::wrapper::GEMVWoqNBits::forward_fp32_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( Aptr, _param.paramA.lda, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); + Epilogue::template forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, GemmCore::NTILE, + _param.paramC, StackTmp, TmpSize); } - Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, GemmCore::NTILE, _param.paramC, - StackTmp, TmpSize); - SNbits::template updateBNStep(paramB, GemmCore::NTILE); + gemv_nbits::NBitsHelper::template updateBNStep(paramB, GemmCore::NTILE); } if (size_padded != _config.size[1]) { - if constexpr (std::is_same_v) { + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_INT32) { + utils::GemvParamA paramA{ + _param.paramA.quan->template APtr(), _param.paramA.quan->template SPtr(), + _param.paramA.quan->template ZPtr(), _param.paramA.quan->mKPad, _param.paramA.quan->CStep()}; + 
kernel::wrapper::GEMVWoqNBits::forward_u8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); + Epilogue::Fp32Epi::template forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, + (_config.size[1] - in), _param.paramC.param2, StackTmp, TmpSize); + } else { + const float* Aptr = _param.paramA.A; + if constexpr (std::is_same_v>) { + if (_param.paramA.reordered && _param.paramA.reordered->template APtr()) { + Aptr = _param.paramA.reordered->template APtr(); + } + } kernel::wrapper::GEMVWoqNBits::forward_fp32_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( Aptr, _param.paramA.lda, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); + Epilogue::template forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, + (_config.size[1] - in), _param.paramC, StackTmp, TmpSize); } - Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, (_config.size[1] - in), - _param.paramC, StackTmp, TmpSize); } } } @@ -329,187 +437,28 @@ class LauncherBase { if constexpr (support()) { assert(_param.problem.dims[4] > 0); auto& m = _param.problem.dims[1]; - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - 
if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) 
gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + if constexpr (Reg32) { + if (m == 5) gemv_kblock(_param, _config); + if (m == 6) gemv_kblock(_param, _config); + if (m == 7) gemv_kblock(_param, _config); + if (m == 8) gemv_kblock(_param, _config); } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S7_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + if constexpr (Reg32) { + if (m == 5) gemv_kblock(_param, _config); + if (m == 6) gemv_kblock(_param, _config); + if (m == 7) gemv_kblock(_param, _config); + if (m == 8) gemv_kblock(_param, _config); } - return; - } - if (_param.paramB.packedW->mDType == 
BTLA_DTYPE::S3_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S1_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP) { - if 
(_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; } } } @@ -582,20 +531,20 @@ class LauncherBase { } } } - mEpilogue.forward(tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, blk_nsize, - _param.paramC, tmpcache, _config.tmpcachesize); + Epilogue::template forward(tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, + blk_nsize, _param.paramC, tmpcache, _config.tmpcachesize); } }; template class _PrologueA_T, - template class _PrologueB_T, template class _Epilogue_T> + template class _PrologueB_T, class _Epilogue_T> class LauncherIntKBlock { public: using GemmCore = _GemmCore_T; static constexpr BTLA_ISA ISA = _RT_ISA_T; using PrologueA = _PrologueA_T; using PrologueB = _PrologueB_T; - using Epilogue = _Epilogue_T<_RT_ISA_T>; + using Epilogue = _Epilogue_T; using AType = typename GemmCore::AType; using AParam = typename PrologueA::Param; using BType = typename GemmCore::BType; @@ -613,7 +562,6 @@ class LauncherIntKBlock { _GemmCore_T mGemmCore; PrologueA mProA; PrologueB mProB; - Epilogue mEpilogue; class GEMVWrapper { public: @@ -628,41 +576,44 @@ class 
LauncherIntKBlock { } if constexpr (GemmCore::ISA == BTLA_ISA::AVX_VNNI) { #if CompileAVXVNNI() - static_assert(GemmCore::PACK_ROW == 4); if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_SS_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } #endif } if constexpr (GemmCore::ISA == BTLA_ISA::AVX2) { #if CompileAVX2() - static_assert(GemmCore::PACK_ROW == 4); if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_SS_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } #endif } if constexpr (GemmCore::ISA == BTLA_ISA::AVX512BW) { #if CompileAVX512F() - static_assert(GemmCore::PACK_ROW == 4); if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } #endif } if constexpr (GemmCore::ISA == BTLA_ISA::AVX512_VNNI || GemmCore::ISA == BTLA_ISA::AMX_INT8) { #if CompileAVX512VNNI() - static_assert(GemmCore::PACK_ROW == 4); if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_SS_FP32) { + static_assert(GemmCore::PACK_ROW == 4); return true; } #endif @@ -687,7 +638,7 @@ class LauncherIntKBlock { return impl; } - template + template static void gemv_kblock(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { if constexpr (support()) { auto constexpr TmpSize = 16 * 1024LL; @@ -696,7 +647,7 @@ class LauncherIntKBlock { auto StackTmp_ = alloca(TmpSize + CSize); auto StackTmp = utils::cpu_pointer_align(StackTmp_); auto tmpc_ptr = reinterpret_cast((char*)StackTmp + TmpSize); - utils::GemvParamB paramB = SNbits::template 
createB(_param.paramB.packedW); + utils::GemvParamB paramB = gemv_nbits::NBitsHelper::template createB(_param.paramB.packedW); utils::GemvParamA paramA{ _param.paramA.quan->template APtr(), _param.paramA.quan->template SPtr(), _param.paramA.quan->template ZPtr(), _param.paramA.quan->mKPad, _param.paramA.quan->CStep()}; @@ -705,7 +656,7 @@ class LauncherIntKBlock { int n = _param.problem.dims[2]; int k = _param.problem.dims[3]; int kblocksize = _param.problem.dims[4]; - SNbits::template updateBNStep(paramB, _config.loc[1]); + gemv_nbits::NBitsHelper::template updateBNStep(paramB, _config.loc[1]); int size_padded = utils::padto_le(_config.size[1], GemmCore::NTILE); int in = 0; for (; in < size_padded; in += GemmCore::NTILE) { @@ -716,9 +667,9 @@ class LauncherIntKBlock { kernel::wrapper::GEMVWoqNBits::forward_s8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); } - Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, GemmCore::NTILE, _param.paramC, - StackTmp, TmpSize); - SNbits::template updateBNStep(paramB, GemmCore::NTILE); + Epilogue::template forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, GemmCore::NTILE, + _param.paramC, StackTmp, TmpSize); + gemv_nbits::NBitsHelper::template updateBNStep(paramB, GemmCore::NTILE); } if (size_padded != _config.size[1]) { if constexpr (std::is_same_v) { @@ -728,8 +679,8 @@ class LauncherIntKBlock { kernel::wrapper::GEMVWoqNBits::forward_s8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); } - Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, (_config.size[1] - in), - _param.paramC, StackTmp, TmpSize); + Epilogue::template forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, + (_config.size[1] - in), _param.paramC, StackTmp, TmpSize); } } } @@ -737,190 +688,28 @@ class LauncherIntKBlock { static 
void gemv(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { if constexpr (support()) { auto& m = _param.problem.dims[1]; - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + if constexpr (Reg32) { + if (m == 5) gemv_kblock(_param, _config); + if (m == 6) gemv_kblock(_param, _config); + if (m == 7) gemv_kblock(_param, _config); + if (m == 8) gemv_kblock(_param, _config); } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S5_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, 
_config); - if (m == 8) gemv_kblock(_param, _config); - } - - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S6_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S7_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - 
} - - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S1_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == 
BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } - return; - } - if (_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP) { - if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } - } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { - if (m == 1) gemv_kblock(_param, _config); - if (m == 2) gemv_kblock(_param, _config); - if (m == 3) gemv_kblock(_param, _config); - if (m == 4) gemv_kblock(_param, _config); - if constexpr (Reg32) { - if (m == 5) gemv_kblock(_param, _config); - if (m == 6) gemv_kblock(_param, _config); - if (m == 7) gemv_kblock(_param, _config); - if (m == 8) gemv_kblock(_param, _config); - } + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + if constexpr (Reg32) { + if (m == 5) gemv_kblock(_param, _config); + if (m == 6) gemv_kblock(_param, _config); + if (m == 7) gemv_kblock(_param, _config); + if (m == 8) gemv_kblock(_param, _config); } - return; } } } @@ -1023,8 +812,8 @@ class LauncherIntKBlock { bcache_stride, ccache_stride, iterk, 1.f, tmp_, _config.tmpcachesize); } } - mEpilogue.forward(tmpC, _config.block[1], 
(_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, blk_nsize, - _param.paramC, tmpcache, _config.tmpcachesize); + Epilogue::template forward(tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, + blk_nsize, _param.paramC, tmpcache, _config.tmpcachesize); } // _config.block[2](tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, + blk_nsize, _param.paramC, tmpcache, _config.tmpcachesize); } }; } // namespace gemm diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index 361577d73..29110cf0d 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -6634,10 +6634,57 @@ static inline BTLA_CODE gemv_7bit_s8s8_fp32(const utils::GemvParamA& A, const ut #endif } // namespace vnni +template +static inline BTLA_CODE mul(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + int constexpr VLen = 8; + size_t velt = utils::padto_le(size, VLen); + size_t i = 0; + auto vfunc = [&]() { + auto v0 = load_T_fp32(src0ptr + i); + auto v1 = load_T_fp32(src1ptr + i); + auto out = _mm256_mul_ps(v0, v1); + store_fp_T(out, dstptr + i); + }; + for (; i < velt; i += VLen) vfunc(); + if (i < size) { + if (size >= VLen) { + i = size - VLen; + vfunc(); + } else { + ref::mul(src0ptr + i, src1ptr + i, dstptr + i, size - i); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE add(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + int constexpr VLen = 8; + size_t velt = utils::padto_le(size, VLen); + size_t i = 0; + auto vfunc = [&]() { + auto v0 = load_T_fp32(src0ptr + i); + auto v1 = load_T_fp32(src1ptr + i); + auto out = _mm256_add_ps(v0, v1); + store_fp_T(out, dstptr + i); + }; + for (; i < velt; i += VLen) vfunc(); + if (i < size) { + if (size >= VLen) { + i = size - VLen; + vfunc(); + } else { + ref::add(src0ptr + i, src1ptr + i, dstptr + i, size - i); + } + } + return BTLA_CODE::Success; +} + #ifdef __GNUC__ #pragma GCC pop_options 
#else #endif + #endif } // namespace avx2 } // namespace kernel diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 590024f93..eef6c96fc 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -7784,6 +7784,51 @@ static inline BTLA_CODE gemv_7bit_s8s8_fp32(const utils::GemvParamA& A, const ut #endif } // namespace vnni +template +static inline BTLA_CODE mul(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + int constexpr VLen = 16; + size_t velt = utils::padto_le(size, VLen); + size_t i = 0; + auto vfunc = [&]() { + auto v0 = load_T_fp32(src0ptr + i); + auto v1 = load_T_fp32(src1ptr + i); + auto out = _mm512_mul_ps(v0, v1); + store_fp_T(out, dstptr + i); + }; + for (; i < velt; i += VLen) vfunc(); + if (i < size) { + if (size >= VLen) { + i = size - VLen; + vfunc(); + } else { + ref::mul(src0ptr + i, src1ptr + i, dstptr + i, size - i); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE add(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + int constexpr VLen = 16; + size_t velt = utils::padto_le(size, VLen); + size_t i = 0; + auto vfunc = [&]() { + auto v0 = load_T_fp32(src0ptr + i); + auto v1 = load_T_fp32(src1ptr + i); + auto out = _mm512_add_ps(v0, v1); + store_fp_T(out, dstptr + i); + }; + for (; i < velt; i += VLen) vfunc(); + if (i < size) { + if (size >= VLen) { + i = size - VLen; + vfunc(); + } else { + ref::add(src0ptr + i, src1ptr + i, dstptr + i, size - i); + } + } + return BTLA_CODE::Success; +} #ifdef __GNUC__ #pragma GCC pop_options #else diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 43373c6c3..eb04a2d8d 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -3392,6 +3392,23 @@ static inline BTLA_CODE gemv_7bit_s8s8_fp32(const utils::GemvParamA& A, const ut return BTLA_CODE::Success; } +template +static inline BTLA_CODE mul(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + for 
(size_t i = 0; i < size; i++) { + float tmp = float(src0ptr[i]) * float(src1ptr[i]); + dstptr[i] = tmp; + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE add(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + for (size_t i = 0; i < size; i++) { + float tmp = float(src0ptr[i]) + float(src1ptr[i]); + dstptr[i] = tmp; + } + return BTLA_CODE::Success; +} } // namespace ref } // namespace kernel } // namespace bestla diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 3b65e6cb5..4491e6515 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -1558,6 +1558,66 @@ class GEMVWoqNBits { } }; +template +class Mul { + public: + template + static inline BTLA_CODE forward(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::mul(src0ptr, src1ptr, dstptr, size); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::mul(src0ptr, src1ptr, dstptr, size); + } +#endif + return ref::mul(src0ptr, src1ptr, dstptr, size); + } + + static inline BTLA_CODE forward_auto(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { + GetCPUDevice(); + if (_cd->AVX512F()) { + return forward(src0ptr, src1ptr, dstptr, size); + } + if (_cd->AVX2()) { + return forward(src0ptr, src1ptr, dstptr, size); + } + return forward(src0ptr, src1ptr, dstptr, size); + } +}; + +template +class Add { + public: + template + static inline BTLA_CODE forward(const T* src0ptr, const T* src1ptr, T* dstptr, size_t size) { +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::add(src0ptr, src1ptr, dstptr, size); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::add(src0ptr, src1ptr, dstptr, size); + } +#endif + return ref::add(src0ptr, src1ptr, dstptr, size); + } + + static inline BTLA_CODE forward_auto(const T* src0ptr, const T* 
src1ptr, T* dstptr, size_t size) { + GetCPUDevice(); + if (_cd->AVX512F()) { + return forward(src0ptr, src1ptr, dstptr, size); + } + if (_cd->AVX2()) { + return forward(src0ptr, src1ptr, dstptr, size); + } + return forward(src0ptr, src1ptr, dstptr, size); + } +}; + } // namespace wrapper } // namespace kernel } // namespace bestla diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp index 24a952dbe..8d62b1fac 100644 --- a/bestla/bestla/ut/bestla_benchmark.cpp +++ b/bestla/bestla/ut/bestla_benchmark.cpp @@ -1,8 +1,7 @@ #include #include "bestla_wrapper.h" #include "bestla_ut.h" -#undef BTLA_UT_WRAPPER -#undef BTLA_UT_PROLOGUE_B + namespace bestla { using namespace utils; namespace ut { @@ -13,7 +12,6 @@ class Benchmark_Fp32Fp32 { UT_START(); benchmark_all(1, 4096, 4096); benchmark_all(1024, 4096, 4096); - benchmark_all(2048, 4096, 4096); } using AType = float; @@ -78,10 +76,10 @@ class Benchmark_Fp32Fp32 { auto threads_cfg = UT_Threading::get_threads_config(); for (auto threads : threads_cfg) { if (_cd->AVX512F()) { - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); } if (_cd->AVX2()) { - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); } } } @@ -96,7 +94,6 @@ class Benchmark_U8S8S32 { UT_START(); benchmark_all(1, 4096, 4096); benchmark_all(1024, 4096, 4096); - benchmark_all(2048, 4096, 4096); } using AType = uint8_t; @@ -166,8 +163,6 @@ class Benchmark_U8S8S32 { benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); } if (_cd->AVX512_VNNI()) { - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, - threads); benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); } @@ -177,20 +172,22 @@ class Benchmark_U8S8S32 { if (_cd->AVX_VNNI()) { 
benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); } + if (_cd->AVX2()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + } } } }; #ifdef BTLA_UT_WRAPPER -#endif static Benchmark_U8S8S32 sBenchmark_U8S8S32; +#endif class Benchmark_S8S8S32 { public: Benchmark_S8S8S32() { UT_START(); - // benchmark_all(1, 4096, 4096); + benchmark_all(1, 4096, 4096); benchmark_all(1024, 4096, 4096); - // benchmark_all(2048, 4096, 4096); } using AType = int8_t; @@ -254,10 +251,14 @@ class Benchmark_S8S8S32 { GetCPUDevice(); auto threads_cfg = UT_Threading::get_threads_config(); for (auto threads : threads_cfg) { + if (_cd->AVX2()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + } if (_cd->AVX_VNNI()) { benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); - benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); } if (_cd->AMX_INT8()) { @@ -281,7 +282,6 @@ class Benchmark_Bf16Bf16Fp32 { UT_START(); benchmark_all(1, 4096, 4096); benchmark_all(1024, 4096, 4096); - benchmark_all(2048, 4096, 4096); } using AType = utils::bf16; @@ -363,7 +363,6 @@ class Benchmark_Fp16Fp16Fp16 { UT_START(); benchmark_all(1, 4096, 4096); benchmark_all(1024, 4096, 4096); - benchmark_all(2048, 4096, 4096); } using AType = utils::fp16; @@ -443,57 +442,42 @@ class UTWOQ_CompFp32 { public: UTWOQ_CompFp32() { UT_START(); - ut_s1(); - ut_s7(); - ut_s6(); - /*ut_s5(); - ut_s2(); - ut_s4(); - ut_s3();*/ - // ut_s8(); - // ut_f4(); - } - void ut_s1() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S1_CLIP); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S1_CLIP); - } - void ut_s2() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S2_CLIP); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S2_CLIP); - } - void ut_s3() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S3_CLIP); - benchmark_all(1024, 
4096, 4096, BTLA_DTYPE::S3_CLIP); - } - void ut_s4() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S4_CLIP); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP); - } - void ut_s5() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S5_CLIP); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S5_CLIP); - } - void ut_s6() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S6_CLIP); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S6_CLIP); + ut_s4_full(); + ut_new_dtype(BTLA_DTYPE::S1_CLIP); + ut_new_dtype(BTLA_DTYPE::S2_CLIP); + ut_new_dtype(BTLA_DTYPE::S3_CLIP); + ut_new_dtype(BTLA_DTYPE::S5_CLIP); + ut_new_dtype(BTLA_DTYPE::S6_CLIP); + ut_new_dtype(BTLA_DTYPE::S7_CLIP); + ut_new_dtype(BTLA_DTYPE::S8); + ut_f4(); } - void ut_s7() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S7_CLIP); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S7_CLIP); + + void ut_new_dtype(BTLA_DTYPE qtype) { + benchmark_all(1, 4096, 4096, qtype, true); + benchmark_all(1, 4096, 4096, qtype); + benchmark_all(1, 4096, 4096, qtype, true); } - void ut_s8() { - benchmark_all(1, 4096, 4096, BTLA_DTYPE::S8); - benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S8); + + void ut_s4_full() { + BTLA_DTYPE qtype = BTLA_DTYPE::S4_CLIP; + benchmark_all(1, 4096, 4096, qtype, true); + benchmark_all(1, 4096, 4096, qtype); + benchmark_all(1, 4096, 4096, qtype, true); + benchmark_all(1, 4096, 4096, qtype); + benchmark_all(1024, 4096, 4096, qtype); } void ut_f4() { benchmark_all(1, 4096, 4096, BTLA_DTYPE::F4_BNB); + benchmark_all(1, 4096, 4096, BTLA_DTYPE::F4_E2M1); + benchmark_all(1, 4096, 4096, BTLA_DTYPE::F4_NF4); benchmark_all(1024, 4096, 4096, BTLA_DTYPE::F4_BNB); } template class Wei, typename Scale_T> void benchmark(int m, int n, int k, int batch, int blocksize, float* A, float* B, float* C, float timems, int threads, - BTLA_DTYPE qtype) { + BTLA_DTYPE qtype, bool isasym) { LOG_T log; using Parallel = parallel::gemm::SchedulerBase; using Launcher = wrapper::gemm::LauncherBase, prologue_b::gemm::WeightKBlockNInteger>) { - tmpB = 
kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, isasym); } else if constexpr (std::is_same_v, prologue_b::gemm::WeightKBlockNFloat>) { @@ -524,7 +508,9 @@ class UTWOQ_CompFp32 { memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); } auto psize = (size_t)m * n * k * 2; - auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); + int blks = k / blocksize; + int nbits = utils::bestla_dtype_bits(qtype); + auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float); tm.start(); while (tm.stop() < timems) { for (int i = 0; i < batch; i++) { @@ -541,16 +527,17 @@ class UTWOQ_CompFp32 { log.record(); double flops = double(psize) / log.min_val / 1e6; double band = double(memsize) / log.min_val / 1e6; + int cores = std::min(threads, device::CpuDevice::getInstance()->getCores()); printf("Threads %d Block %d %s %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, blocksize, - corestr, log.get_log_str(), flops, flops / threads, band); + corestr, log.get_log_str(), flops, flops / cores, band); } template