From dfdfb0f937b7600ad3aedc4f2038676f53ccecf8 Mon Sep 17 00:00:00 2001
From: ThanatosShinji <108169286+ThanatosShinji@users.noreply.github.com>
Date: Thu, 25 Apr 2024 10:32:24 +0800
Subject: [PATCH] [BesTLA] The initial SYCL support (#229)

* sycl init
* add helper
* add epilogue base
* launcher done
* dequant code
* add s4sgemm
* add sgemv
* add trans B support
* add hgemm
* finish half gemms
* add dequant kernels
* keep TILEK code
* enable all cases
* fix perf on MTL
* update gemv
* update fp16 performance
* add half for getweight
* add tail process for gemm and epilogue
* remove sycl sources when disabled
* protect mha from unsupported compilers
---
 CMakePresets.json                            |  12 +
 bestla/CMakeLists.txt                        |  37 +-
 bestla/CMakePresets.json                     |  28 +
 bestla/bestla/bestla_utils.h                 |  14 +-
 bestla/bestla/kernel_avx2.h                  |  12 +-
 bestla/bestla/kernel_avx512_bf16.h           |   6 +
 bestla/bestla/kernel_wrapper.h               |  32 +-
 bestla/bestla/sycl/sycl_device.h             |  73 ++
 bestla/bestla/sycl/sycl_epilogue.h           |  59 ++
 bestla/bestla/sycl/sycl_gemm.h               | 227 ++++++
 bestla/bestla/sycl/sycl_prologue_a.h         |  41 ++
 bestla/bestla/sycl/sycl_prologue_b.h         | 455 ++++++++++++
 bestla/bestla/sycl/sycl_utils.h              | 110 +++
 bestla/bestla/sycl/sycl_wrapper.h            | 216 ++++++
 bestla/bestla/ut/kernel_intrin.cpp           |   4 +
 bestla/bestla/ut/sycl_benchmark.cpp          | 729 +++++++++++++++++++
 bestla/bestla/ut/sycl_gemm.cpp               | 485 ++++++++++++
 bestla/bestla/ut/sycl_misc.cpp               |  34 +
 bestla/bestla/ut/sycl_ut.h                   |  16 +
 bestla/cmake/sycl.cmake                      |   3 +
 neural_speed/core/layers/mha_dense.cpp       |   2 +
 neural_speed/core/layers/mha_dense_wrapper.h |  25 +-
 22 files changed, 2590 insertions(+), 30 deletions(-)
 create mode 100644 bestla/bestla/sycl/sycl_device.h
 create mode 100644 bestla/bestla/sycl/sycl_epilogue.h
 create mode 100644 bestla/bestla/sycl/sycl_gemm.h
 create mode 100644 bestla/bestla/sycl/sycl_prologue_a.h
 create mode 100644 bestla/bestla/sycl/sycl_prologue_b.h
 create mode 100644 bestla/bestla/sycl/sycl_utils.h
 create mode 100644 bestla/bestla/sycl/sycl_wrapper.h
 create mode 100644 bestla/bestla/ut/sycl_benchmark.cpp
 create mode 100644 bestla/bestla/ut/sycl_gemm.cpp
 create mode 100644 bestla/bestla/ut/sycl_misc.cpp
 create mode 100644 bestla/bestla/ut/sycl_ut.h
 create mode 100644 bestla/cmake/sycl.cmake

diff --git a/CMakePresets.json b/CMakePresets.json
index 6ad6836b1..3a0694af9 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -106,6 +106,18 @@
       "cacheVariables": {
         "BTLA_UT_OPENMP": "OFF"
       }
+    },
+    {
+      "name": "x64-release-sycl",
+      "displayName": "x64 Release SYCL",
+      "description": "x64 SYCL",
+      "inherits": "x64-debug",
+      "cacheVariables": {
+        "CMAKE_CXX_COMPILER": "icx-cl",
+        "CMAKE_C_COMPILER": "icx-cl",
+        "CMAKE_BUILD_TYPE": "Release",
+        "BTLA_UT_ALL": "ON"
+      }
     }
   ]
 }
diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt
index a3082acca..e11ea875c 100644
--- a/bestla/CMakeLists.txt
+++ b/bestla/CMakeLists.txt
@@ -5,6 +5,7 @@ file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
 file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)
 
 option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF)
+option(BTLA_SYCL "Compile SYCL kernels with a SYCL-enabled compiler" OFF)
 
 option(BTLA_UT_ALL "Enable all unit tests" OFF)
 option(BTLA_UT_DEBUG "Enable debug unit tests" OFF)
@@ -21,6 +22,10 @@ option(BTLA_UT_NOASAN "Disable sanitize" OFF)
 option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
 option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF)
 
+
+
+
+
 add_library(${PROJECT_NAME} INTERFACE)
add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) target_include_directories( @@ -28,7 +33,15 @@ target_include_directories( "$" "$" ) - +set(sycl_headers) +set(sycl_libs) +if(BTLA_SYCL) + include(cmake/sycl.cmake) + file(GLOB sycl_headers ${PROJECT_NAME}/sycl/*.h ${PROJECT_NAME}/sycl/*.hpp) + add_compile_definitions(BTLA_SYCL) + list(APPEND sycl_libs IntelSYCL::SYCL_CXX) + #add_link_options(-fsycl-targets=spir64 -Xsycl-target-backend "-options -ze-opt-large-register-file") +endif(BTLA_SYCL) if(BTLA_ENABLE_OPENMP) message(STATUS "BesTLA enable OpenMP ThreadPool") @@ -69,12 +82,20 @@ function(add_ut_flag UT_OPTION) endif() endfunction() +set(benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp) +# list(APPEND benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/sycl_benchmark.cpp) + + if(UT_BUILD) file(GLOB srcs ${PROJECT_NAME}/ut/*.cc ${PROJECT_NAME}/ut/*.cpp) #compile everything even run parts of UTs - list(REMOVE_ITEM srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp) + file(GLOB sycl_srcs ${PROJECT_NAME}/ut/sycl*) + if(NOT BTLA_SYCL) + list(REMOVE_ITEM srcs ${sycl_srcs}) + endif() + list(REMOVE_ITEM srcs ${benchmark_srcs}) file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) include_directories(${PROJECT_NAME}) - add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers}) + add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${sycl_headers} ${ut_headers}) if(BTLA_UT_OPENMP) include(FindOpenMP) target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) @@ -98,14 +119,16 @@ if(UT_BUILD) add_ut_flag(BTLA_UT_KERNEL_INTRIN) add_ut_flag(BTLA_UT_KERNEL_JIT) add_ut_flag(BTLA_UT_KERNEL_WRAPPER) - target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME}) + if(BTLA_SYCL) + add_compile_definitions(BTLA_UT_SYCL) + endif() + target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME} ${sycl_libs}) endif(UT_BUILD) if(BTLA_UT_BENCHMARK) - file(GLOB srcs ${PROJECT_NAME}/ut/bestla_benchmark.cpp) #compile everything even run parts of UTs file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) include_directories(${PROJECT_NAME}) - add_executable(${PROJECT_NAME}_benchmark ${srcs} ${headers} ${ut_headers}) + add_executable(${PROJECT_NAME}_benchmark ${benchmark_srcs} ${headers} ${ut_headers}) if(BTLA_UT_OPENMP) include(FindOpenMP) target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) @@ -114,5 +137,5 @@ if(BTLA_UT_BENCHMARK) if(NOT WIN32) target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread) endif() - target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME}) + target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME} ${sycl_libs}) endif(BTLA_UT_BENCHMARK) diff --git a/bestla/CMakePresets.json b/bestla/CMakePresets.json index 3fa3071ae..7187120ff 100644 --- a/bestla/CMakePresets.json +++ b/bestla/CMakePresets.json @@ -83,6 +83,34 @@ "description": "Target Windows (64-bit) with the Visual Studio development environment. 
(RelWithDebInfo)", "inherits": "x64-release", "cacheVariables": { "BTLA_UT_ALL": "ON" } + }, + { + "name": "x64-debug-sycl", + "displayName": "x64 Debug SYCL", + "description": "x64 Debug SYCL", + "inherits": "windows-base", + "architecture": { + "value": "x64", + "strategy": "external" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "BTLA_UT_DEBUG": "ON", + "BTLA_UT_ALL": "OFF", + "BTLA_SYCL": "ON", + "BTLA_UT_BENCHMARK": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "icx" + } + }, + { + "name": "x64-release-sycl", + "displayName": "x64 Release for SYCL", + "description": "x64 SYCL", + "inherits": "x64-debug-sycl", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release" + } } ] } diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 284fadb5f..17e24b75e 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -70,7 +70,7 @@ #define CompileAMXINT8() (CompileAMX()) #endif -#ifdef _MSC_VER +#if defined(_MSC_VER) && !defined(__INTEL_LLVM_COMPILER) #define CompileAVX512F() _MSC_VER && (_MSC_VER >= 1911) #define CompileAVX2() _MSC_VER && (_MSC_VER >= 1900) #define CompileAMX() 0 @@ -80,12 +80,12 @@ #define CompileAMXINT8() 0 #endif -#ifdef __clang_major__ -#define CompileAVX512F() (__clang_major__ >= 4) -#define CompileAVX2() (__clang_major__ >= 3) -#define CompileAMX() (__clang_major__ >= 11) -#define CompileBF16() (__clang_major__ >= 11) -#define CompileFP16() (__clang_major__ >= 16) +#if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER) +#define CompileAVX512F() defined(__AVX512F__) +#define CompileAVX2() defined(__AVX2__) && defined(__F16C__) && defined(__FMA__) +#define CompileAMX() 0 +#define CompileBF16() 0 +#define CompileFP16() 0 #define CompileAMXBF16() (CompileAMX()) #define CompileAMXINT8() (CompileAMX()) #endif diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index a6899d8f5..e980fa90a 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -23,12 +23,12 @@ namespace bestla { namespace kernel { namespace avx2 { #if CompileAVX2() -#ifdef __GNUC__ +#if defined(__GNUC__) #pragma GCC push_options #pragma GCC target("avx2", "fma", "f16c") -#else +#elif defined(ICX) +#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function) #endif - template static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) { auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr)); @@ -74,7 +74,7 @@ inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) { inline __m128i ymm_cvtepi32_epi16(__m256i src) { __m128i tmp; -#ifdef __GNUC__ +#if defined(__GNUC__) || defined(__clang_major__) for (size_t i = 0; i < 8; i++) { (reinterpret_cast(&tmp))[i] = (reinterpret_cast(&src))[i]; } @@ -443,7 +443,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int e_revert = _mm256_srli_epi32(e_revert, mantissabit); if constexpr (WITH_SCALE && std::is_same_v<_S_T, utils::f8>) { auto scale = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(sptr + j / _PACK_ROW))); - if constexpr (_PACK_ROW == 2) scale = _mm256_permutexvar_epi32(packrow2_permute_idx, scale); + if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_epi32(packrow2_permute_idx, scale); e_revert = _mm256_add_epi32(e_revert, scale); } e_revert = _mm256_sub_epi32(e_revert, e_revert_shift); @@ -454,7 +454,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int fp_v = _mm256_or_ps(fp_v, _mm256_castsi256_ps(mantissa_revert)); if constexpr 
(WITH_SCALE && std::is_same_v<_S_T, float>) { auto scale = _mm256_loadu_ps(sptr + j / _PACK_ROW); - if constexpr (_PACK_ROW == 2) scale = _mm256_permutexvar_ps(packrow2_permute_idx, scale); + if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_ps(scale, packrow2_permute_idx); fp_v = _mm256_mul_ps(fp_v, scale); } if constexpr (std::is_same_v<_DST_T, float>) { diff --git a/bestla/bestla/kernel_avx512_bf16.h b/bestla/bestla/kernel_avx512_bf16.h index 453b88afd..ece55a5dd 100644 --- a/bestla/bestla/kernel_avx512_bf16.h +++ b/bestla/bestla/kernel_avx512_bf16.h @@ -47,7 +47,10 @@ static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, } return BTLA_CODE::Success; #endif +#if CompileAVX512F() return avx512f::bf16_cvt_fp32_2D_write_back(src_ptr, dst_ptr, row, col, src_step, dst_step, zeropadding); +#endif + return BTLA_CODE::NotSupport; } static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, @@ -83,7 +86,10 @@ static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void } return BTLA_CODE::Success; #endif +#if CompileAVX512F() return avx512f::fp32_cvt_bf16_2D_write_back(raw_srcptr, raw_dstptr, row, col, srcstride, dststride, zeropadding); +#endif + return BTLA_CODE::NotSupport; } #if CompileBF16() #pragma GCC pop_options diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index a9726f28b..f8751b3c9 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -34,11 +34,13 @@ class PaddingInterleaveMN { template static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, int dst_step) { +#if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { const auto kern_ret = kernel::avx512f::padding_interleave_cvt::forward( src, dst, NTile, row, col, row_pad, col_pad, src_step, dst_step); if (kern_ret != BTLA_CODE::NotSupport) return kern_ret; } +#endif return ref::padding_interleave(src, dst, row, col, row_pad, col_pad, src_step, dst_step, NTile, RowPack); } }; @@ -62,12 +64,14 @@ class PaddingTransInterleaveMN { template static BTLA_CODE forward(const T_SRC* src, T_DST* dst, int row, int col, int row_pad, int col_pad, int src_step, int dst_step) { +#if CompileAVX512F() // Note: rows/cols and i/j are in terms of src if constexpr (utils::isa_base::avx512f) { const auto kern_ret = kernel::avx512f::padding_trans_interleave_cvt::forward( src, dst, MTile, row, col, row_pad, col_pad, src_step, dst_step); if (kern_ret != BTLA_CODE::NotSupport) return kern_ret; } +#endif return ref::padding_trans_interleave(src, dst, row, col, row_pad, col_pad, src_step, dst_step, MTile, ColPack); } }; @@ -85,7 +89,6 @@ class Memcpy2D { return ret; } } -#if CompileAVX2() if constexpr (utils::isa_base::avx2) { auto align_col = col * sizeof(_SRC_T) / 32 * 32 / sizeof(_SRC_T); ret = kernel::jit::JitMemcpy2DAvx2::forward<_SRC_T, _DST_T>(srcptr, dstptr, row, align_col, srcstep, dststep, @@ -97,7 +100,6 @@ class Memcpy2D { return ret; } } -#endif return kernel::ref::memcpy2d(srcptr, dstptr, row, col * sizeof(_SRC_T), srcstep * sizeof(_SRC_T), dststep * sizeof(_DST_T)); } @@ -106,7 +108,6 @@ class Memcpy2D { static BTLA_CODE forward1(const _SRC_T* srcptr, _DST_T* dstptr, int row, int col, int srcstep, int dststep, void* const_elt_v = nullptr) { auto ret = BTLA_CODE::NotSupport; -#if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { ret = kernel::jit::JitMemcpy2DAvx512f::forward1<_SRC_T, _DST_T, OP_T>(srcptr, dstptr, 
row, col, srcstep, dststep, const_elt_v); @@ -114,8 +115,6 @@ class Memcpy2D { return ret; } } -#endif -#if CompileAVX2() if constexpr (utils::isa_base::avx2) { auto align_col = col * sizeof(_SRC_T) / 32 * 32 / sizeof(_SRC_T); ret = kernel::jit::JitMemcpy2DAvx2::forward1<_SRC_T, _DST_T, OP_T>(srcptr, dstptr, row, align_col, srcstep, @@ -128,7 +127,6 @@ class Memcpy2D { return ret; } } -#endif return ref::memcpy2d_withop<_SRC_T, _DST_T, OP_T>(srcptr, dstptr, row, col, srcstep, dststep, const_elt_v); } }; @@ -504,10 +502,12 @@ class DecompressKBlockS4S8Fp { reinterpret_cast(tmp), tmpsize); } #endif +#if CompileAVX2() if constexpr (utils::isa_base::avx2) { return avx2::decompress_kblock_s4_s8fp(srcptr, dstptr, row, col, ld_src, ld_dst, reinterpret_cast(tmp), tmpsize); } +#endif return ref::decompress_kblock_s4_s8fp(srcptr, dstptr, row, col, ld_src, ld_dst, reinterpret_cast(tmp), tmpsize); } @@ -605,14 +605,18 @@ class DecompressKBlockF4FpNoscale { static inline BTLA_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { return avx512f::decompress_kblock_f4_fp_noscale(srcptr, dstptr, row, col, ld_src, ld_dst, reinterpret_cast(tmp), tmpsize); } +#endif +#if CompileAVX2() if constexpr (utils::isa_base::avx2) { return avx2::decompress_kblock_f4_fp_noscale(srcptr, dstptr, row, col, ld_src, ld_dst, reinterpret_cast(tmp), tmpsize); } +#endif return ref::decompress_kblock_f4_fp_noscale(srcptr, dstptr, row, col, ld_src, ld_dst, reinterpret_cast(tmp), tmpsize); } @@ -669,12 +673,10 @@ class DecompressKBlockS8Fp { static inline BTLA_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { -#if CompileAVX512F() if constexpr (utils::isa_base::avx512f && std::is_same_v) { // TODO Scale type support return jit::DequanKBlockS8Fp::forward_avx512f(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad); } -#endif #if CompileAVX2() // PACK_ROW must be 1/4 when using avx2 proB. 
if constexpr (utils::isa_base::avx2 && std::is_same_v && @@ -694,12 +696,16 @@ class DecompressKBlockS8S8Fp { template static inline BTLA_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, void* tmp, size_t tmpsize) { +#if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { // TODO Scale type support return avx512f::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); } +#endif +#if CompileAVX2() if constexpr (utils::isa_base::avx2) { // TODO Scale type support return avx2::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); } +#endif return ref::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); } }; @@ -756,9 +762,11 @@ class CompFp32BlockScale { return avx512f::accum_alphaN_f32_f32(alpha, srcptr, srcstep, dstptr, dststep, M, N); } #endif +#if CompileAVX2() if constexpr (utils::isa_base::avx2) { return avx2::accum_alphaN_f32_f32(alpha, srcptr, srcstep, dstptr, dststep, M, N); } +#endif return ref::accum_alphaN_f32_f32(alpha, srcptr, srcstep, dstptr, dststep, M, N); } }; @@ -845,12 +853,16 @@ class ColBlockReduceSum { template static inline BTLA_CODE forward(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, float* reduce, int ldr) { +#if CompileAVX512F() if constexpr (utils::isa_base::avx512f && std::is_same_v) { return avx512f::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr); } +#endif +#if CompileAVX2() if constexpr (utils::isa_base::avx2 && std::is_same_v) { return avx2::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr); } +#endif return ref::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr); } }; @@ -911,12 +923,16 @@ class LayerNormalization { template static inline BTLA_CODE forward(const T* srcptr, const T* scaleptr, const T* biasptr, T epsilon, int norm_size, T* dstptr, T* mean, T* mean_square, bool simplified) { +#if CompileAVX512F() if constexpr (utils::isa_base::avx512f && std::is_same_v) { return avx512f::layernorm(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square, simplified); } +#endif +#if CompileAVX2() if constexpr (utils::isa_base::avx2 && std::is_same_v) { return avx2::layernorm(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square, simplified); } +#endif return ref::layernorm(srcptr, scaleptr, biasptr, epsilon, norm_size, dstptr, mean, mean_square, simplified); } template diff --git a/bestla/bestla/sycl/sycl_device.h b/bestla/bestla/sycl/sycl_device.h new file mode 100644 index 000000000..c23d241c1 --- /dev/null +++ b/bestla/bestla/sycl/sycl_device.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
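Editorially, the kernel_wrapper.h hunks above all apply one pattern: wrap each ISA-specific branch in a Compile*() preprocessor guard so that compilers which cannot emit those instructions (such as icx on Windows, per the bestla_utils.h change) never see the code, while the portable ref:: kernel remains the unconditional fallback. A minimal self-contained sketch of that pattern follows; all names in it are stand-ins for illustration, not real BesTLA APIs.

// Sketch of the guard-plus-fallback dispatch used throughout kernel_wrapper.h.
#define COMPILE_AVX512F 1  // stand-in for CompileAVX512F() from bestla_utils.h

enum class Code { Success, NotSupport };

namespace avx512f { inline Code kernel() { return Code::Success; } }
namespace ref { inline Code kernel() { return Code::Success; } }

template <bool kAvx512f>  // stand-in for utils::isa_base<ISA_T>::avx512f
Code forward() {
#if COMPILE_AVX512F  // the whole branch vanishes on compilers without AVX-512
  if constexpr (kAvx512f) return avx512f::kernel();
#endif
  return ref::kernel();  // portable reference fallback always compiles
}

int main() { return forward<true>() == Code::Success ? 0 : 1; }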
+#pragma once
+#include <sycl/sycl.hpp>
+#include <exception>
+#include <iostream>
+#include <string>
+
+namespace bestla {
+
+namespace sycl_device {
+
+class SyclDevice {
+ public:
+  SyclDevice(bool profile) {
+    // Create an exception handler for asynchronous SYCL exceptions
+    static auto exception_handler = [](sycl::exception_list e_list) {
+      for (std::exception_ptr const& e : e_list) {
+        try {
+          std::rethrow_exception(e);
+        } catch (std::exception const& e) {
+#if _DEBUG
+          std::cout << "Failure" << std::endl;
+#endif
+          std::terminate();
+        }
+      }
+    };
+
+    auto d_selector{sycl::default_selector_v};
+    if (profile) {
+      sycl::property_list prop = {sycl::property::queue::enable_profiling()};
+      mQueue = sycl::queue(d_selector, exception_handler, prop);
+    } else {
+      mQueue = sycl::queue(d_selector, exception_handler);
+    }
+  }
+
+  inline sycl::queue* getQueue() { return &mQueue; }
+
+  inline std::string getName() { return mQueue.get_device().get_info<sycl::info::device::name>(); }
+
+  void print() {
+    std::cout << "Running on device: " << mQueue.get_device().get_info<sycl::info::device::name>() << "\n";
+    std::cout << "EU count:" << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_eu_count>()
+              << "\n";  // 448
+    std::cout << "EU count per subslice:"
+              << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_eu_count_per_subslice>() << "\n";  // 8
+    std::cout << "EU SIMD width:" << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_eu_simd_width>()
+              << "\n";  // 8
+    std::cout << "HW threads per EU:"
+              << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_hw_threads_per_eu>() << "\n";  // 8
+    std::cout << "GPU slices:" << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_slices>()
+              << "\n";  // 7
+    std::cout << "Subslice per slice:"
+              << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_subslices_per_slice>() << "\n";  // 8
+  }
+  sycl::queue mQueue;
+};
+
+}  // namespace sycl_device
+}  // namespace bestla
diff --git a/bestla/bestla/sycl/sycl_epilogue.h b/bestla/bestla/sycl/sycl_epilogue.h
new file mode 100644
index 000000000..caa3ef062
--- /dev/null
+++ b/bestla/bestla/sycl/sycl_epilogue.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
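As a quick orientation, here is a minimal usage sketch for the SyclDevice helper defined just above; the include path and the main() wrapper are assumptions for illustration, not part of the patch.

#include "sycl/sycl_device.h"  // assumed include path

int main() {
  // true requests a queue with profiling enabled, which the event-based
  // timing helpers in sycl_utils.h rely on.
  bestla::sycl_device::SyclDevice dev(true);
  dev.print();                      // dump EU/slice topology of the device
  sycl::queue* q = dev.getQueue();  // queue shared by all kernel launches
  return q->get_device().is_gpu() ? 0 : 1;
}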
+#pragma once
+
+#ifdef BTLA_SYCL
+#include <sycl/sycl.hpp>
+
+#include "sycl_utils.h"
+
+namespace bestla {
+namespace sycl_epilogue {
+template <typename DstT>
+struct ParamOutputBase {
+  DstT* C;
+  int ldc;
+};
+template <class GemmCoreT, typename DstT>
+class OutputBase {
+ public:
+  using CType = typename GemmCoreT::TACC;
+  using DstType = DstT;
+  using Param = ParamOutputBase<DstType>;
+  static inline void store(const Param& _param, CType* tmpAcc, const sycl_utils::nd_item_helper<GemmCoreT>& helper) {
+#pragma unroll
+    for (int im = 0; im < GemmCoreT::TileM; im++) {
+#pragma unroll
+      for (int in = 0; in < GemmCoreT::TileN; in++) {
+        _param.C[(helper.item_g_m() + im) * _param.ldc + helper.item_g_n() + in] = tmpAcc[im * GemmCoreT::TileN + in];
+      }
+    }
+  }
+
+  static inline void store_tail(const Param& _param, CType* tmpAcc, const sycl_utils::nd_item_helper<GemmCoreT>& helper,
+                                int m_tail) {
+    if (m_tail) {
+      for (int im = 0; im < m_tail; im++) {
+#pragma unroll
+        for (int in = 0; in < GemmCoreT::TileN; in++) {
+          _param.C[(helper.item_g_m() + im) * _param.ldc + helper.item_g_n() + in] = tmpAcc[im * GemmCoreT::TileN + in];
+        }
+      }
+    }
+  }
+};
+
+}  // namespace sycl_epilogue
+}  // namespace bestla
+#endif
diff --git a/bestla/bestla/sycl/sycl_gemm.h b/bestla/bestla/sycl/sycl_gemm.h
new file mode 100644
index 000000000..7ba1e7963
--- /dev/null
+++ b/bestla/bestla/sycl/sycl_gemm.h
@@ -0,0 +1,227 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
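The sycl_gemm.h body that follows defines per-datatype tiling configs and the work-group constants derived from them. As orientation for those definitions below, these illustrative static checks (mine, not the patch's; include path assumed) spell out how the Config_Fp32Fp32Fp32 numbers compose.

#include "sycl/sycl_gemm.h"  // assumed include path

using Core = bestla::sycl_gemm::xve::DefaultSGemmCore;
// A work-group of WgM x WgN lanes computes a WgMEle x WgNEle block of C.
static_assert(Core::WgMEle == 8 * 16, "wg_m * sg_m rows of C per work-group");
static_assert(Core::WgNEle == 32 * 2, "wg_n * sg_n columns of C per work-group");
// Shared local memory holds one TileK-deep slice of B for the whole group.
static_assert(Core::SLM_B_Size == Core::WgNEle * Core::TileK, "B slice in SLM");
static_assert(Core::SLM_A_Size == 0, "A is streamed through registers, not SLM");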
+#pragma once + +#ifdef BTLA_SYCL +#include + +#include "bestla_utils.h" +#include + +namespace bestla { +namespace sycl_gemm { +namespace xve { +class Config_Fp32Fp32Fp32 { + public: + static int constexpr sg_size = 16; + static int constexpr sg_m = 16; + static int constexpr sg_n = 2; + static int constexpr sg_k = 32; + static int constexpr unroll_k = 4; + static int constexpr wg_m = 8; + static int constexpr wg_n = 32; + + using data_type_a = float; + using data_type_b = float; + using data_type_c = float; + using data_type_acc = float; +}; + +template +class SGemmCoreSharedB { + public: + static int constexpr SgSize = ConfigT::sg_size; + static int constexpr WgM = ConfigT::wg_m; + static int constexpr WgN = ConfigT::wg_n; + static int constexpr SgNStride = WgN / SgSize; + static int constexpr WgWorkers = WgM * WgN; + static int constexpr SgCount = WgWorkers / SgSize; + static int constexpr TileM = ConfigT::sg_m; + static int constexpr TileN = ConfigT::sg_n; + static int constexpr TileK = ConfigT::sg_k; + static int constexpr UnrollK = ConfigT::unroll_k; + static int constexpr WgNEle = WgN * TileN; + static int constexpr WgMEle = WgM * TileM; + static int constexpr SgNEle = SgSize * TileN; + static int constexpr SLM_B_Size = WgNEle * TileK; + static int constexpr SLM_A_Size = 0; + + using TA = typename ConfigT::data_type_a; + using TB = typename ConfigT::data_type_b; + using TC = typename ConfigT::data_type_c; + using TACC = typename ConfigT::data_type_acc; + + using SLM_B_Acc = sycl::local_accessor; + + static inline void compute(const TA* aptr, int lda, const SLM_B_Acc& bacc, TACC* accptr, + const sycl_utils::nd_item_helper>& helper) { +#pragma unroll(1) + for (int ik = 0; ik < TileK; ik += UnrollK) { + int constexpr MReg = TileM / SgSize; + TA regA[UnrollK * MReg]; + for (int im = 0; im < MReg; im++) { + *(sycl::vec*)®A[im * UnrollK] = + *(sycl::vec*)&aptr[(helper.sg_id() + im * SgSize) * lda + ik]; + } + +#pragma unroll + for (int ikk = 0; ikk < UnrollK; ikk++) { + TB tmpB[TileN]; +#pragma unroll + for (int in = 0; in < TileN; in++) { + tmpB[in] = bacc[helper.sg_idx_n() * SgNEle + helper.sg_id() * TileN + in + (ik + ikk) * WgNEle]; + } +#pragma unroll + for (size_t im = 0; im < TileM; im++) { + auto tmpA = helper.sg.shuffle(regA[ikk + im / SgSize * UnrollK], im % SgSize); +#pragma unroll + for (size_t in = 0; in < TileN; in++) { + accptr[im * TileN + in] += tmpA * tmpB[in]; + } + } + } + } + } + + static inline void compute_mtail(const TA* aptr, int lda, const SLM_B_Acc& bacc, TACC* accptr, + const sycl_utils::nd_item_helper>& helper, int& m_tail) { + if (m_tail > 0) { +#pragma unroll(1) + for (int ik = 0; ik < TileK; ik += UnrollK) { + for (int ikk = 0; ikk < UnrollK; ikk++) { + TB tmpB[TileN]; +#pragma unroll + for (int in = 0; in < TileN; in++) { + tmpB[in] = bacc[helper.sg_idx_n() * SgNEle + helper.sg_id() * TileN + in + (ik + ikk) * WgNEle]; + } + for (size_t im = 0; im < m_tail; im++) { + auto tmpA = aptr[im * lda + ik + ikk]; +#pragma unroll + for (size_t in = 0; in < TileN; in++) { + accptr[im * TileN + in] += tmpA * tmpB[in]; + } + } + } + } + } + } +}; + +using DefaultSGemmCore = SGemmCoreSharedB; + +class Config_Fp16Fp16Fp16 { + public: + static int constexpr sg_size = 16; + static int constexpr sg_m = 16; + static int constexpr sg_n = 4; + static int constexpr sg_k = 32; + static int constexpr unroll_k = 4; + static int constexpr wg_m = 16; + static int constexpr wg_n = 32; + + using data_type_a = sycl::half; + using data_type_b = sycl::half; + using data_type_c = 
sycl::half; + using data_type_acc = sycl::half; +}; + +template +class HGemmCoreSharedB { + public: + static int constexpr SgSize = ConfigT::sg_size; + static int constexpr WgM = ConfigT::wg_m; + static int constexpr WgN = ConfigT::wg_n; + static int constexpr SgNStride = WgN / SgSize; + static int constexpr WgWorkers = WgM * WgN; + static int constexpr SgCount = WgWorkers / SgSize; + static int constexpr TileM = ConfigT::sg_m; + static int constexpr TileN = ConfigT::sg_n; + static int constexpr TileK = ConfigT::sg_k; + static int constexpr UnrollK = ConfigT::unroll_k; + static int constexpr WgNEle = WgN * TileN; + static int constexpr WgMEle = WgM * TileM; + static int constexpr SgNEle = SgSize * TileN; + static int constexpr SLM_B_Size = WgNEle * TileK; + static int constexpr SLM_A_Size = 0; + + using TA = typename ConfigT::data_type_a; + using TB = typename ConfigT::data_type_b; + using TC = typename ConfigT::data_type_c; + using TACC = typename ConfigT::data_type_acc; + + using SLM_B_Acc = sycl::local_accessor; + + static inline void compute(const TA* aptr, int lda, const SLM_B_Acc& bacc, TACC* accptr, + const sycl_utils::nd_item_helper>& helper) { +#pragma unroll(1) + for (int ik = 0; ik < TileK; ik += UnrollK) { + static_assert((UnrollK * sizeof(TA)) % sizeof(float) == 0); + int constexpr MReg = TileM / SgSize; + static_assert(MReg == 1); + TA regA[UnrollK * MReg]; + for (int im = 0; im < MReg; im++) { + *(sycl::vec*)®A[im * UnrollK] = + *(sycl::vec*)&aptr[(helper.sg_id() + im * SgSize) * lda + ik]; + } +#pragma unroll + for (int ikk = 0; ikk < UnrollK; ikk++) { + TB tmpB[TileN]; +#pragma unroll + for (int in = 0; in < TileN; in++) { + tmpB[in] = bacc[helper.sg_idx_n() * SgNEle + helper.sg_id() * TileN + in + (ik + ikk) * WgNEle]; + } +#pragma unroll + for (size_t im = 0; im < TileM; im++) { + auto tmpA = helper.sg.shuffle(regA[ikk + im / SgSize * UnrollK], im % SgSize); +#pragma unroll + for (size_t in = 0; in < TileN; in++) { + accptr[im * TileN + in] += tmpA * tmpB[in]; + } + } + } + } + } + + static inline void compute_mtail(const TA* aptr, int lda, const SLM_B_Acc& bacc, TACC* accptr, + const sycl_utils::nd_item_helper>& helper, + const int& m_tail) { + if (m_tail > 0) { +#pragma unroll(1) + for (int ik = 0; ik < TileK; ik += UnrollK) { +#pragma unroll + for (int ikk = 0; ikk < UnrollK; ikk++) { + TB tmpB[TileN]; +#pragma unroll + for (int in = 0; in < TileN; in++) { + tmpB[in] = bacc[helper.sg_idx_n() * SgNEle + helper.sg_id() * TileN + in + (ik + ikk) * WgNEle]; + } + for (size_t im = 0; im < m_tail; im++) { + auto tmpA = aptr[im * lda + ik + ikk]; +#pragma unroll + for (size_t in = 0; in < TileN; in++) { + accptr[im * TileN + in] += tmpA * tmpB[in]; + } + } + } + } + } + } +}; + +using DefaultHGemmCore = HGemmCoreSharedB; +} // namespace xve + +} // namespace sycl_gemm +} // namespace bestla +#endif diff --git a/bestla/bestla/sycl/sycl_prologue_a.h b/bestla/bestla/sycl/sycl_prologue_a.h new file mode 100644 index 000000000..28350f276 --- /dev/null +++ b/bestla/bestla/sycl/sycl_prologue_a.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#ifdef BTLA_SYCL +#include + +#include "bestla_utils.h" +#include + +namespace bestla { +namespace sycl_prologue_a { + +template +struct ParamActivationBase { + const SrcT* A; + int lda; +}; +template +class ActivationBase { + public: + using AType = typename GemmCoreT::TA; + using SrcType = SrcT; + using Param = ParamActivationBase; + static inline void getActivation(const Param& _param, AType* aptr, sycl_utils::nd_item_helper& helper) {} +}; + +} // namespace sycl_prologue_a +} // namespace bestla +#endif diff --git a/bestla/bestla/sycl/sycl_prologue_b.h b/bestla/bestla/sycl/sycl_prologue_b.h new file mode 100644 index 000000000..089a81dd5 --- /dev/null +++ b/bestla/bestla/sycl/sycl_prologue_b.h @@ -0,0 +1,455 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#ifdef BTLA_SYCL +#include + +#include "bestla_utils.h" +#include + +namespace bestla { +namespace sycl_prologue_b { + +template +struct ParamWeightBase { + const SrcT* B; + int ldb; +}; +template +class WeightBase { + public: + using BType = typename GemmCoreT::TB; + using SRCType = SrcT; + using Param = ParamWeightBase; + + static inline void getWeight(const Param& _param, const sycl::local_accessor& dstptr, int koffset, + const sycl_utils::nd_item_helper& helper) { + int constexpr Iter_PerWorker = (GemmCoreT::TileK + GemmCoreT::WgM - 1) / GemmCoreT::WgM; +#pragma unroll + for (int icp = 0; icp < Iter_PerWorker; icp++) { + { + for (size_t in = 0; in < GemmCoreT::TileN; in++) { + dstptr[(helper.sg_idx_m() + icp * GemmCoreT::WgM) * GemmCoreT::WgNEle + + (helper.sg_idx_n() * GemmCoreT::SgSize + helper.sg_id()) * GemmCoreT::TileN + in] = + _param.B[helper.item_g_n() + in + (koffset + helper.sg_idx_m() + icp * GemmCoreT::WgM) * _param.ldb]; + } + } + } + } +}; + +class KernelConfigBase { + public: + static int constexpr SgSize = 16; + static int constexpr TileK = 16; + static int constexpr TileN = 2; +}; + +template +struct ParamWeightS4 { + const uint8_t* B; + const ScaleT* scale; + int ldb; +}; + +template +class WeightS4 { + public: + using BType = typename GemmCoreT::TB; + using Param = ParamWeightS4; + + static inline void getWeight(const Param& _param, const sycl::local_accessor& dstptr, int koffset, + int blocksize, const sycl_utils::nd_item_helper& helper) { + int constexpr Iter_PerWorker = (GemmCoreT::TileK + GemmCoreT::WgM - 1) / GemmCoreT::WgM; + ScaleT scale[GemmCoreT::TileN]; + for (size_t in = 0; in < GemmCoreT::TileN; in += 1) + scale[in] = _param.scale[helper.item_g_n() + in + koffset / blocksize * _param.ldb]; +#pragma 
unroll + for (int icp = 0; icp < Iter_PerWorker; icp++) { + { + for (size_t in = 0; in < GemmCoreT::TileN; in += 2) { + auto tmps8 = + _param + .B[(helper.item_g_n() + in + (koffset + helper.sg_idx_m() + icp * GemmCoreT::WgM) * _param.ldb) / 2]; + dstptr[(helper.sg_idx_m() + icp * GemmCoreT::WgM) * GemmCoreT::WgNEle + + (helper.sg_idx_n() * GemmCoreT::SgSize + helper.sg_id()) * GemmCoreT::TileN + in] = + static_cast((tmps8 & 0x0f) << 4) * scale[in]; + dstptr[(helper.sg_idx_m() + icp * GemmCoreT::WgM) * GemmCoreT::WgNEle + + (helper.sg_idx_n() * GemmCoreT::SgSize + helper.sg_id()) * GemmCoreT::TileN + in + 1] = + static_cast((tmps8 & 0xf0)) * scale[in + 1]; + } + } + } + } + + template + static inline sycl::event dequant_s4(int n, int k, int blocksize, const Param& in, BType* outptr, sycl::queue* q) { + int constexpr SgSize = KernelConfigBase::SgSize; + int constexpr TileK = KernelConfigBase::TileK; + int constexpr TileN = KernelConfigBase::TileN; + int constexpr GroupN = SgSize * TileN; + int constexpr GroupK = TileK; + static_assert(TileN % 2 == 0); + assert(blocksize % TileK == 0); + + int nsg_k = k / GroupK; + int nsg_n = n / GroupN; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + auto B_d = in.B; + auto S_d = in.scale; + int ldb = in.ldb; + auto deq_kernel = [&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_idx_n = g_idx % nsg_n; + int g_idx_k = g_idx / nsg_n; + int g_n = g_idx_n * GroupN; + int g_k = g_idx_k * GroupK; + auto sptr = S_d + g_k / blocksize * ldb + g_n; + auto bptr = B_d + (g_k * ldb + g_n) / 2; + auto dbptr = outptr + g_k * n + g_n; + float tmp[TileK * TileN]; + float scale[TileN]; + for (int in = 0; in < TileN; in += 1) { + scale[in] = sptr[sg_id * TileN + in]; + } + for (int ik = 0; ik < TileK; ik += 1) { + for (int in = 0; in < TileN; in += 2) { + uint8_t srcu8 = *(bptr + (ik * ldb + sg_id * TileN + in) / 2); + tmp[ik * TileN + in] = static_cast((srcu8 & 0x0f) << 4) * scale[in]; + tmp[ik * TileN + in + 1] = static_cast((srcu8 & 0xf0)) * scale[in + 1]; + } + } + for (int ik = 0; ik < TileK; ik += 1) { + for (int in = 0; in < TileN; in += 1) { + dbptr[ik * n + sg_id * TileN + in] = tmp[ik * TileN + in]; + } + } + }); + }; + return q->submit(deq_kernel); + } +}; + +class KernelConfigTrans { + public: + static int constexpr SgSize = 16; + static int constexpr TileK = 32; + static int constexpr TileN = 1; +}; + +template +class WeightS4Trans { + public: + using AType = typename GemmCoreT::TA; + using BType = typename GemmCoreT::TB; + using CType = typename GemmCoreT::TC; + using Param = ParamWeightS4; + + static inline void getWeight(const Param& _param, const sycl::local_accessor& dstptr, int koffset, + int blocksize, const sycl_utils::nd_item_helper& helper) { + int constexpr LoadTileK = 2; + static_assert(GemmCoreT::TileK == (LoadTileK * GemmCoreT::SgSize)); + int constexpr Iter_PerWorker = GemmCoreT::WgNEle / GemmCoreT::SgCount; + auto wldb = _param.ldb * blocksize; + int sgn = helper.wg_g_n() + helper.sg_group_id(); + int sg_off = helper.sg_id() * LoadTileK * GemmCoreT::WgNEle; +#pragma unroll + for (int icp = 0; icp < Iter_PerWorker; icp++) { + { + auto scale = _param.scale[(sgn + icp * GemmCoreT::SgCount) * _param.ldb + koffset / blocksize]; + auto tmps8 = _param.B[((sgn + icp * GemmCoreT::SgCount) * wldb + (koffset + 
helper.sg_id() * LoadTileK)) / 2]; + if constexpr (std::is_same_v) { + sycl::half2 tmpBf = {static_cast((tmps8 & 0x0f) << 4), static_cast((tmps8 & 0xf0))}; + tmpBf *= scale; + dstptr[sg_off + helper.sg_group_id() + icp * GemmCoreT::SgCount] = tmpBf[0]; + dstptr[sg_off + GemmCoreT::WgNEle + helper.sg_group_id() + icp * GemmCoreT::SgCount] = tmpBf[1]; + } else { + dstptr[sg_off + helper.sg_group_id() + icp * GemmCoreT::SgCount] = + static_cast((tmps8 & 0x0f) << 4) * scale; + dstptr[sg_off + GemmCoreT::WgNEle + helper.sg_group_id() + icp * GemmCoreT::SgCount] = + static_cast((tmps8 & 0xf0)) * scale; + } + } + } + } + + template + static inline sycl::event dequant_s4(int n, int k, int blocksize, const Param& in, BType* outptr, sycl::queue* q) { + int constexpr SgSize = KernelConfigBase::SgSize; + int constexpr TileK = KernelConfigBase::TileK; + int constexpr TileN = KernelConfigBase::TileN; + int constexpr GroupN = TileN; + int constexpr SubGroupK = SgSize * TileK; + int constexpr GroupK = SgSize * TileK; + static_assert(TileN == 1); + assert(blocksize % TileK == 0); + + int nsg_k = k / GroupK; + int nsg_n = n / GroupN; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + auto B_d = in.B; + auto S_d = in.scale; + int ldb = in.ldb; + int ldbn = in.ldb * blocksize; + auto deq_kernel = [&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(problem, group), [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int sg_group_id = sg.get_group_id()[0]; + int g_idx_n = g_idx / nsg_k; + int g_idx_k = g_idx % nsg_k; + int g_n = g_idx_n * GroupN; + int g_k = g_idx_k * GroupK; + int sg_k = g_k + sg_group_id * SubGroupK; + auto sptr = S_d + sg_k / blocksize + g_n * ldb; + auto bptr = B_d + (sg_k + g_n * ldbn) / 2; + auto dbptr = outptr + sg_k + g_n * k; + float tmp[TileK]; + int constexpr Unroll = 4; +#pragma unroll + for (int ik = 0; ik < TileK; ik += Unroll) { + float dst[Unroll]; + float scale = sptr[(ik * SgSize + sg_id * Unroll) / blocksize]; + for (int ir = 0; ir < Unroll; ir += 2) { + uint8_t srcu8 = *(bptr + (ik * SgSize + sg_id * Unroll + ir) / 2); + dst[ir] = static_cast((srcu8 & 0x0f) << 4) * scale; + dst[ir + 1] = static_cast((srcu8 & 0xf0)) * scale; + } + *(sycl::vec*)&dbptr[ik * SgSize + sg_id * Unroll] = *(sycl::vec*)dst; + } + }); + }; + return q->submit(deq_kernel); + } + +#if 0 + template + static inline sycl::event dequant_s4_trans(int n, int k, int blocksize, const Param& in, BType* outptr, + sycl::queue* q) { + int constexpr SgSize = 16; + int constexpr TileK = 2; + int constexpr TileN = 16; + int constexpr GroupN = TileN; + int constexpr GroupK = SgSize * TileK; + assert(blocksize % TileK == 0); + static_assert(TileN == SgSize); + int nsg_k = k / GroupK; + int nsg_n = n / GroupN; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + auto B_d = in.B; + auto S_d = in.scale; + int ldb = in.ldb; + int ldbn = in.ldb * blocksize; + auto deq_kernel = [&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_idx_n = g_idx / nsg_k; + int g_idx_k = g_idx % nsg_k; + int g_n = g_idx_n * GroupN; + int g_k = g_idx_k * GroupK; + auto sptr = S_d + g_k / blocksize + g_n * ldb; + auto bptr = B_d + (g_k + g_n * ldbn) / 2; + auto 
dbptr = outptr + g_k * n + g_n; + float tmp[TileN * TileK]; + for (int in = 0; in < TileN; in++) { + float scale = sptr[sg_id * TileK / blocksize + in * ldb]; + for (int ik = 0; ik < TileK; ik += 2) { + uint8_t srcu8 = *(bptr + (sg_id * TileK + ik + in * ldbn) / 2); + tmp[in * TileK + ik] = static_cast((srcu8 & 0x0f) << 4) * scale; + tmp[in * TileK + ik + 1] = static_cast((srcu8 & 0xf0)) * scale; + } + } + + float tmpT[TileN * TileK]; + for (int ik = 0; ik < TileK; ik++) { + for (int in = 0; in < TileN; in++) { + for (int is = 0; is < SgSize; is++) { + auto shlv = sg.shuffle(tmp[in * TileK + ik], is); + if (sg_id == in) { + tmpT[ik * TileN + is] = shlv; + } + } + } + } + for (int in = 0; in < TileN; in++) { + for (int ik = 0; ik < TileK; ik++) { + dbptr[sg_id + (in * TileK + ik) * n] = tmpT[ik * TileN + in]; + } + } + }); + }; + return q->submit(deq_kernel); + } +#else + template + static inline sycl::event dequant_s4_trans(int n, int k, int blocksize, const Param& in, BType* outptr, + sycl::queue* q) { + int constexpr SgSize = 16; + int constexpr TileK = 1; + int constexpr TileN = 16; + int constexpr GroupN = TileN; + int constexpr GroupK = SgSize * TileK; + assert(blocksize % TileK == 0); + static_assert(TileN == SgSize); + static_assert(TileK == 1); + int nsg_k = k / GroupK; + int nsg_n = n / GroupN; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{nsg_n * nsg_k * SgSize}; + auto B_d = in.B; + auto S_d = in.scale; + int ldb = in.ldb; + int ldbn = in.ldb * blocksize; + auto deq_kernel = [&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_idx_n = g_idx / nsg_k; + int g_idx_k = g_idx % nsg_k; + int g_n = g_idx_n * GroupN; + int g_k = g_idx_k * GroupK; + auto sptr = S_d + g_k / blocksize + g_n * ldb; + auto bptr = B_d + (g_k + g_n * ldbn) / 2; + auto dbptr = outptr + g_k * n + g_n; + float tmp[TileN]; + bool high4 = sg_id % 2 != 0; + for (int in = 0; in < TileN; in++) { + float scale = sptr[sg_id * TileK / blocksize + in * ldb]; + uint8_t srcu8 = *(bptr + (sg_id * TileK + in * ldbn) / 2); + tmp[in] = high4 ? 
static_cast((srcu8 & 0xf0)) * scale + : static_cast((srcu8 & 0x0f) << 4) * scale; + } + + float tmpT[TileN]; + for (int in = 0; in < TileN; in++) { + for (int is = 0; is < SgSize; is++) { + auto shlv = sg.shuffle(tmp[in], is); + if (sg_id == in) { + tmpT[is] = shlv; + } + } + } + for (int in = 0; in < TileN; in++) { + dbptr[sg_id + in * n] = tmpT[in]; + } + }); + }; + return q->submit(deq_kernel); + } +#endif + + static inline sycl::event gemv(const AType* A, const Param& paramB, CType* C, int n, int k, int blocksize, + sycl::queue* q) { + auto B = paramB.B; + auto B_scale = paramB.scale; + int ldb = paramB.ldb; + int constexpr SgSize = 16; + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{n * SgSize}; + + auto ev = q->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[cl::reqd_work_group_size( + 1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { + sycl::half2 tmpAcc = {0.f, 0.f}; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) << 4), + static_cast((tmps8[ikk / 2] & 0xf0))}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { + CType tmpAcc = 0.f; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += + CType(aptr[sg_id * TileK + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) << 4) * scale; + tmpAcc += + CType(aptr[sg_id * TileK + ikk + 1]) * static_cast((tmps8[ikk / 2] & 0xf0)) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; + } + } + }); + }); + return ev; + } +}; +} // namespace sycl_prologue_b +} // namespace bestla +#endif diff --git a/bestla/bestla/sycl/sycl_utils.h b/bestla/bestla/sycl/sycl_utils.h new file mode 100644 index 000000000..2cdf01626 --- /dev/null +++ b/bestla/bestla/sycl/sycl_utils.h @@ -0,0 +1,110 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "sycl_device.h"
+#include "bestla_utils.h"
+
+namespace bestla {
+namespace sycl_utils {
+
+struct sycl_deleter {
+  sycl::queue* queue_;
+  sycl_deleter(sycl::queue* _q) : queue_(_q) {}
+  template <typename T>
+  void operator()(T* obj) const {
+    if (obj) {
+      sycl::free(obj, *queue_);
+    }
+  }
+};
+
+template <typename _T>
+struct sycl_vector {
+  sycl_vector(uint64_t _size = 0, sycl::queue* _q = nullptr) : size_(_size) {
+    if (_q && _size) {
+      resize(_size, _q);
+    }
+  }
+
+  void resize(uint64_t _size, sycl::queue* _q) {
+    size_ = _size;
+    _T* tmp = sycl::malloc_device<_T>(_size, *_q);
+    ptr_ = std::shared_ptr<_T>(tmp, sycl_deleter(_q));
+  }
+
+  inline uint64_t size() { return size_; }
+
+  inline _T* data() { return ptr_.get(); }
+
+  std::shared_ptr<_T> ptr_;
+  uint64_t size_;
+};
+
+template <typename T>
+__inline__ std::vector<T> sycl2host(const T* syclptr, size_t elecount, sycl::queue* q) {
+  std::vector<T> tmp(elecount);
+  q->memcpy(tmp.data(), syclptr, elecount * sizeof(T)).wait();
+  return tmp;
+}
+
+class event_helper {
+ public:
+  static float elapsed_time(sycl::event& evt) {
+    float t = 0.f;
+    const auto startKernExecutionTimePoint = evt.get_profiling_info<sycl::info::event_profiling::command_submit>();
+    const auto endKernExecutionTimePoint = evt.get_profiling_info<sycl::info::event_profiling::command_end>();
+    t = (endKernExecutionTimePoint - startKernExecutionTimePoint) / 1e6;
+    return t;
+  }
+
+  static float execute_time(sycl::event& evt) {
+    float t = 0.f;
+    const auto startKernExecutionTimePoint = evt.get_profiling_info<sycl::info::event_profiling::command_start>();
+    const auto endKernExecutionTimePoint = evt.get_profiling_info<sycl::info::event_profiling::command_end>();
+    t = (endKernExecutionTimePoint - startKernExecutionTimePoint) / 1e6;
+    return t;
+  }
+};
+template <class GemmCoreT>
+class nd_item_helper {
+ public:
+  const sycl::nd_item<2> it;
+  const sycl::sub_group sg;
+  nd_item_helper(sycl::nd_item<2>& _it) : it(_it), sg(it.get_sub_group()) {}
+
+  constexpr inline void local_barrier() const { it.barrier(sycl::access::fence_space::local_space); }
+
+  constexpr inline int sg_group_id() const { return sg.get_group_id()[0]; }
+
+  constexpr inline int wg_idx_m() const { return it.get_group(0); }
+  constexpr inline int wg_size_m() const { return GemmCoreT::WgM * GemmCoreT::TileM; }
+  constexpr inline int wg_g_m() const { return wg_idx_m() * wg_size_m(); }
+
+  constexpr inline int wg_idx_n() const { return it.get_group(1); }
+  constexpr inline int wg_size_n() const { return GemmCoreT::WgN * GemmCoreT::TileN; }
+  constexpr inline int wg_g_n() const { return wg_idx_n() * wg_size_n(); }
+
+  constexpr inline int sg_idx_m() const { return sg_group_id() / GemmCoreT::SgNStride; }
+  constexpr inline int sg_g_m() const { return wg_g_m() + sg_idx_m() * GemmCoreT::TileM; }
+
+  constexpr inline int sg_idx_n() const { return sg_group_id() % GemmCoreT::SgNStride; }
+  constexpr inline int sg_g_n() const { return wg_g_n() + sg_idx_n() * GemmCoreT::SgSize * GemmCoreT::TileN; }
+
+  constexpr inline int sg_id() const { return sg.get_local_id()[0]; }
+  constexpr inline int item_g_m() const { return sg_g_m(); }
+  constexpr inline int item_g_n() const { return sg_g_n() + sg_id() * GemmCoreT::TileN; }
+};
+
+}  // namespace sycl_utils
+}  // namespace bestla
diff --git
a/bestla/bestla/sycl/sycl_wrapper.h b/bestla/bestla/sycl/sycl_wrapper.h new file mode 100644 index 000000000..29dd84997 --- /dev/null +++ b/bestla/bestla/sycl/sycl_wrapper.h @@ -0,0 +1,216 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#ifdef BTLA_SYCL +#include + +#include "bestla_utils.h" +#include "sycl_utils.h" +#include "sycl_device.h" +#include "sycl_gemm.h" +#include "sycl_epilogue.h" +#include "sycl_prologue_a.h" +#include "sycl_prologue_b.h" + +namespace bestla { +namespace sycl_wrapper { +template
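The patch text ends here, mid-way through sycl_wrapper.h. To close, a hedged end-to-end sketch of how the pieces above compose: device buffers via sycl_utils::sycl_vector, the int4 dequant prologue, event_helper profiling, and readback. The include paths, the sizes, the float ScaleT argument, and the dequant_s4 template argument are assumptions, since the flattened patch dropped some template argument lists; this is an illustration, not the patch's own test code.

#include <cstdint>
#include "sycl/sycl_device.h"      // assumed include paths
#include "sycl/sycl_utils.h"
#include "sycl/sycl_gemm.h"
#include "sycl/sycl_prologue_b.h"

using Core = bestla::sycl_gemm::xve::DefaultSGemmCore;
using ProB = bestla::sycl_prologue_b::WeightS4Trans<Core, float>;  // ScaleT assumed float

void dequant_demo(bestla::sycl_device::SyclDevice& dev, int n, int k, int blocksize) {
  sycl::queue* q = dev.getQueue();  // must be a profiling queue for the timing below
  // B packs two 4-bit weights per byte; one scale per blocksize-chunk of k.
  // The kernel's tiling assumes blocksize is a multiple of its TileK (32 here).
  bestla::sycl_utils::sycl_vector<uint8_t> B(size_t(n) * k / 2, q);
  bestla::sycl_utils::sycl_vector<float> scale(size_t(n) * k / blocksize, q);
  bestla::sycl_utils::sycl_vector<float> out(size_t(n) * k, q);
  ProB::Param pb{B.data(), scale.data(), k / blocksize};  // ldb = scales per column
  auto ev = ProB::dequant_s4<bestla::sycl_prologue_b::KernelConfigTrans>(
      n, k, blocksize, pb, out.data(), q);
  ev.wait();
  float ms = bestla::sycl_utils::event_helper::execute_time(ev);   // kernel time in ms
  auto host = bestla::sycl_utils::sycl2host(out.data(), out.size(), q);  // blocking copy
  (void)ms;
  (void)host;
}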