Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
support more epilogue classes
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed May 8, 2024
1 parent 7d49516 commit c3d9073
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 24 deletions.
32 changes: 11 additions & 21 deletions bestla/bestla/bestla_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,31 +191,26 @@ class LauncherBase {
int n = _param.problem.dims[2];
int k = _param.problem.dims[3];
int kblocksize = _param.problem.dims[4];
auto Cptr = _param.paramC.C + _config.loc[1];
SNbits::template updateBNStep<ScaleT>(paramB, _config.loc[1]);
int size_padded = utils::padto_le(_config.size[1], GemmCore::NTILE);
int in = 0;
for (; in < size_padded; in += GemmCore::NTILE) {
if constexpr (std::is_same_v<AType, float>) {
kernel::wrapper::GEMVWoqNBits::forward_fp32_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>(
Aptr, _param.paramA.lda, paramB, Cptr, _param.paramC.ldc, k, kblocksize, StackTmp, TmpSize);
Aptr, _param.paramA.lda, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize);
}

Cptr += GemmCore::NTILE;
Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, GemmCore::NTILE, _param.paramC,
StackTmp, TmpSize);
SNbits::template updateBNStep<ScaleT>(paramB, GemmCore::NTILE);
}
if (size_padded != _config.size[1]) {
if constexpr (std::is_same_v<AType, float>) {
kernel::wrapper::GEMVWoqNBits::forward_fp32_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>(
Aptr, _param.paramA.lda, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize);
}
for (int i = 0; i < MTILE; i++) {
memcpy(Cptr + i * _param.paramC.ldc, tmpc_ptr + i * GemmCore::NTILE,
(_config.size[1] - in) * sizeof(CType));
}
Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, (_config.size[1] - in),
_param.paramC, StackTmp, TmpSize);
}
Epilogue::forward(_param.paramC.C + _config.loc[1], _param.paramC.ldc, 0, _config.loc[1], MTILE,
_config.size[1], _param.paramC, StackTmp, TmpSize);
}
}

Expand Down Expand Up @@ -448,20 +443,19 @@ class LauncherIntKBlock {
int n = _param.problem.dims[2];
int k = _param.problem.dims[3];
int kblocksize = _param.problem.dims[4];
auto Cptr = _param.paramC.C + _config.loc[1];
SNbits::template updateBNStep<ScaleT>(paramB, _config.loc[1]);
int size_padded = utils::padto_le(_config.size[1], GemmCore::NTILE);
int in = 0;
for (; in < size_padded; in += GemmCore::NTILE) {
if constexpr (std::is_same_v<AType, uint8_t>) {
kernel::wrapper::GEMVWoqNBits::forward_u8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>(
paramA, paramB, Cptr, _param.paramC.ldc, k, kblocksize, StackTmp, TmpSize);
paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize);
} else if constexpr (std::is_same_v<AType, int8_t>) {
kernel::wrapper::GEMVWoqNBits::forward_s8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>(
paramA, paramB, Cptr, _param.paramC.ldc, k, kblocksize, StackTmp, TmpSize);
paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize);
}

Cptr += GemmCore::NTILE;
Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, GemmCore::NTILE, _param.paramC,
StackTmp, TmpSize);
SNbits::template updateBNStep<ScaleT>(paramB, GemmCore::NTILE);
}
if (size_padded != _config.size[1]) {
Expand All @@ -472,13 +466,9 @@ class LauncherIntKBlock {
kernel::wrapper::GEMVWoqNBits::forward_s8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>(
paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize);
}
for (int i = 0; i < MTILE; i++) {
memcpy(Cptr + i * _param.paramC.ldc, tmpc_ptr + i * GemmCore::NTILE,
(_config.size[1] - in) * sizeof(CType));
}
Epilogue::forward(tmpc_ptr, GemmCore::NTILE, 0, _config.loc[1] + in, MTILE, (_config.size[1] - in),
_param.paramC, StackTmp, TmpSize);
}
Epilogue::forward(_param.paramC.C + _config.loc[1], _param.paramC.ldc, 0, _config.loc[1], MTILE,
_config.size[1], _param.paramC, StackTmp, TmpSize);
}
}

Expand Down
8 changes: 5 additions & 3 deletions bestla/bestla/ut/bestla_benchmark.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include <stdio.h>
#include "bestla_wrapper.h"
#include "bestla_ut.h"
#undef BTLA_UT_WRAPPER
#undef BTLA_UT_PROLOGUE_B

namespace bestla {
using namespace utils;
namespace ut {
Expand Down Expand Up @@ -747,6 +746,9 @@ class UTWOQ_CompInt8 {
int blks = k / blocksize;
int nbits = utils::bestla_dtype_bits(qtype);
auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float);
if (isasym) {
memsize += n * blks * sizeof(int8_t);
}
tm.start();
while (tm.stop() < timems) {
for (int i = 0; i < batch; i++) {
Expand Down Expand Up @@ -808,8 +810,8 @@ class UTWOQ_CompInt8 {
}
};
#ifdef BTLA_UT_PROLOGUE_B
#endif
static UTWOQ_CompInt8 sUTWOQ_CompInt8;
#endif

#if 0
typedef struct {
Expand Down
2 changes: 2 additions & 0 deletions bestla/bestla/ut/kernel_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class UT_PaddingInterleaveMN {
// Constructor: announces the test via UT_START and invokes the fp16 -> bf16 ut case.
UT_PaddingInterleaveMN() {
UT_START();
// ut<48, 2, bf16, bf16>(128, 128, 2); // TO IMPLEMENT
// NOTE(review): presumably bails out of the constructor when the CPU lacks
// AVX512_FP16, so the fp16 case below is skipped -- confirm against the
// CheckISA macro definition.
CheckISA(AVX512_FP16);
// NOTE(review): template arguments presumably map to
// <NTile, RowPack, T_SRC, T_DST> per the template declaration that follows
// this constructor -- confirm.
ut<32, 2, fp16, bf16>(128, 128, 2);
}
template <int NTile, int RowPack, typename T_SRC, typename T_DST>
Expand Down Expand Up @@ -120,6 +121,7 @@ class UT_PaddingTransInterleaveMN {
public:
// Constructor: announces the test via UT_START and invokes the fp16 -> bf16 ut case.
UT_PaddingTransInterleaveMN() {
UT_START();
// NOTE(review): presumably bails out of the constructor when the CPU lacks
// AVX512_FP16, so the fp16 case below is skipped -- confirm against the
// CheckISA macro definition.
CheckISA(AVX512_FP16);
// ut<48, 2, bf16, bf16>(128, 128, 2); // TO IMPLEMENT
ut<32, 2, fp16, bf16>(128, 128, 2);
}
Expand Down

0 comments on commit c3d9073

Please sign in to comment.