Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
[BesTLA] The initial SYCL support (#229)
Browse files Browse the repository at this point in the history
* sycl init

* add helper

* add epilogue base

* launcher done

* dequant code

* add s4sgemm

* add sgemv

* add trans B support

* add hgemm

* finish half gemms

* add dequant kernels

* keep TILEK code

* enable all cases

* fix perf on MTL

* update gemv

* update fp16 performance

* add half for getweight

* add tail process for gemm and epilogue

* remove sycl sources when disabled

* protect mha from unsupported compilers
  • Loading branch information
ThanatosShinji authored Apr 25, 2024
1 parent 3736b27 commit dfdfb0f
Show file tree
Hide file tree
Showing 22 changed files with 2,590 additions and 30 deletions.
12 changes: 12 additions & 0 deletions CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@
"cacheVariables": {
"BTLA_UT_OPENMP": "OFF"
}
},
{
"name": "x64-release-sycl",
"displayName": "x64 Release SYCL",
"description": "x64 SYCL",
"inherits": "x64-debug",
"cacheVariables": {
"CMAKE_CXX_COMPILER": "icx-cl",
"CMAKE_C_COMPILER": "icx-cl",
"CMAKE_BUILD_TYPE": "Release",
"BTLA_UT_ALL": "ON"
}
}
]
}
37 changes: 30 additions & 7 deletions bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)

option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF)
# Opt-in switch for SYCL support. The help text previously duplicated the OpenMP
# option's description, which misreported the feature in cmake-gui/ccmake.
option(BTLA_SYCL "Enable SYCL support (defines BTLA_SYCL and links IntelSYCL::SYCL_CXX)" OFF)

option(BTLA_UT_ALL "Enable all unit tests" OFF)
option(BTLA_UT_DEBUG "Enable debug unit tests" OFF)
Expand All @@ -21,14 +22,26 @@ option(BTLA_UT_NOASAN "Disable sanitize" OFF)
option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF)





add_library(${PROJECT_NAME} INTERFACE)
add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
target_include_directories(
${PROJECT_NAME} INTERFACE
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
)

# SYCL support is opt-in. Both variables expand to nothing unless BTLA_SYCL is ON,
# so the unit-test and benchmark targets can reference them unconditionally.
set(sycl_headers)
set(sycl_libs)
if(BTLA_SYCL)
# NOTE(review): cmake/sycl.cmake is not visible here; presumably it makes the
# IntelSYCL::SYCL_CXX target used below available - confirm.
include(cmake/sycl.cmake)
# Collect the SYCL headers so they can be listed alongside the other headers in the executables.
file(GLOB sycl_headers ${PROJECT_NAME}/sycl/*.h ${PROJECT_NAME}/sycl/*.hpp)
# Directory-scoped definition: targets in this directory are compiled with BTLA_SYCL defined.
add_compile_definitions(BTLA_SYCL)
list(APPEND sycl_libs IntelSYCL::SYCL_CXX)
#add_link_options(-fsycl-targets=spir64 -Xsycl-target-backend "-options -ze-opt-large-register-file")
endif(BTLA_SYCL)

if(BTLA_ENABLE_OPENMP)
message(STATUS "BesTLA enable OpenMP ThreadPool")
Expand Down Expand Up @@ -69,12 +82,20 @@ function(add_ut_flag UT_OPTION)
endif()
endfunction()

set(benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp)
# list(APPEND benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/sycl_benchmark.cpp)


if(UT_BUILD)
file(GLOB srcs ${PROJECT_NAME}/ut/*.cc ${PROJECT_NAME}/ut/*.cpp) #compile everything even run parts of UTs
list(REMOVE_ITEM srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp)
file(GLOB sycl_srcs ${PROJECT_NAME}/ut/sycl*)
if(NOT BTLA_SYCL)
list(REMOVE_ITEM srcs ${sycl_srcs})
endif()
list(REMOVE_ITEM srcs ${benchmark_srcs})
file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h)
include_directories(${PROJECT_NAME})
add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers})
add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${sycl_headers} ${ut_headers})
if(BTLA_UT_OPENMP)
include(FindOpenMP)
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP)
Expand All @@ -98,14 +119,16 @@ if(UT_BUILD)
add_ut_flag(BTLA_UT_KERNEL_INTRIN)
add_ut_flag(BTLA_UT_KERNEL_JIT)
add_ut_flag(BTLA_UT_KERNEL_WRAPPER)
target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME})
if(BTLA_SYCL)
add_compile_definitions(BTLA_UT_SYCL)
endif()
target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME} ${sycl_libs})
endif(UT_BUILD)

if(BTLA_UT_BENCHMARK)
file(GLOB srcs ${PROJECT_NAME}/ut/bestla_benchmark.cpp) #compile everything even run parts of UTs
file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h)
include_directories(${PROJECT_NAME})
add_executable(${PROJECT_NAME}_benchmark ${srcs} ${headers} ${ut_headers})
add_executable(${PROJECT_NAME}_benchmark ${benchmark_srcs} ${headers} ${ut_headers})
if(BTLA_UT_OPENMP)
include(FindOpenMP)
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP)
Expand All @@ -114,5 +137,5 @@ if(BTLA_UT_BENCHMARK)
if(NOT WIN32)
target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread)
endif()
target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME})
target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME} ${sycl_libs})
endif(BTLA_UT_BENCHMARK)
28 changes: 28 additions & 0 deletions bestla/CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,34 @@
"description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)",
"inherits": "x64-release",
"cacheVariables": { "BTLA_UT_ALL": "ON" }
},
{
"name": "x64-debug-sycl",
"displayName": "x64 Debug SYCL",
"description": "x64 Debug SYCL",
"inherits": "windows-base",
"architecture": {
"value": "x64",
"strategy": "external"
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"BTLA_UT_DEBUG": "ON",
"BTLA_UT_ALL": "OFF",
"BTLA_SYCL": "ON",
"BTLA_UT_BENCHMARK": "ON",
"CMAKE_CXX_COMPILER": "icx",
"CMAKE_C_COMPILER": "icx"
}
},
{
"name": "x64-release-sycl",
"displayName": "x64 Release for SYCL",
"description": "x64 SYCL",
"inherits": "x64-debug-sycl",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release"
}
}
]
}
14 changes: 7 additions & 7 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
#define CompileAMXINT8() (CompileAMX())
#endif

#ifdef _MSC_VER
#if defined(_MSC_VER) && !defined(__INTEL_LLVM_COMPILER)
#define CompileAVX512F() _MSC_VER && (_MSC_VER >= 1911)
#define CompileAVX2() _MSC_VER && (_MSC_VER >= 1900)
#define CompileAMX() 0
Expand All @@ -80,12 +80,12 @@
#define CompileAMXINT8() 0
#endif

#ifdef __clang_major__
#define CompileAVX512F() (__clang_major__ >= 4)
#define CompileAVX2() (__clang_major__ >= 3)
#define CompileAMX() (__clang_major__ >= 11)
#define CompileBF16() (__clang_major__ >= 11)
#define CompileFP16() (__clang_major__ >= 16)
// Intel LLVM-based compiler with _MSC_VER defined: AVX512F/AVX2 support follows the
// ISA feature macros, while AMX, BF16 and FP16 are reported as unavailable.
// The feature tests are resolved here with nested #if blocks so that `defined`
// never appears as the result of macro expansion inside a later #if (undefined
// behavior per the C and C++ standards), and so that each macro expands to a
// single token that stays correct under `!`, `&&` and `||`.
#if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER)
#if defined(__AVX512F__)
#define CompileAVX512F() 1
#else
#define CompileAVX512F() 0
#endif
#if defined(__AVX2__) && defined(__F16C__) && defined(__FMA__)
#define CompileAVX2() 1
#else
#define CompileAVX2() 0
#endif
#define CompileAMX() 0
#define CompileBF16() 0
#define CompileFP16() 0
#define CompileAMXBF16() (CompileAMX())
#define CompileAMXINT8() (CompileAMX())
#endif
Expand Down
12 changes: 6 additions & 6 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ namespace bestla {
namespace kernel {
namespace avx2 {
#if CompileAVX2()
#ifdef __GNUC__
#if defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "fma", "f16c")
#else
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
#endif

template <bool LowBits>
static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) {
auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr));
Expand Down Expand Up @@ -74,7 +74,7 @@ inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) {

inline __m128i ymm_cvtepi32_epi16(__m256i src) {
__m128i tmp;
#ifdef __GNUC__
#if defined(__GNUC__) || defined(__clang_major__)
for (size_t i = 0; i < 8; i++) {
(reinterpret_cast<int16_t*>(&tmp))[i] = (reinterpret_cast<int32_t*>(&src))[i];
}
Expand Down Expand Up @@ -443,7 +443,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
e_revert = _mm256_srli_epi32(e_revert, mantissabit);
if constexpr (WITH_SCALE && std::is_same_v<_S_T, utils::f8>) {
auto scale = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(sptr + j / _PACK_ROW)));
if constexpr (_PACK_ROW == 2) scale = _mm256_permutexvar_epi32(packrow2_permute_idx, scale);
// _mm256_permutevar8x32_epi32 takes (source, index), the reverse of _mm256_permutexvar_epi32's (index, source).
if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_epi32(scale, packrow2_permute_idx);
e_revert = _mm256_add_epi32(e_revert, scale);
}
e_revert = _mm256_sub_epi32(e_revert, e_revert_shift);
Expand All @@ -454,7 +454,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
fp_v = _mm256_or_ps(fp_v, _mm256_castsi256_ps(mantissa_revert));
if constexpr (WITH_SCALE && std::is_same_v<_S_T, float>) {
auto scale = _mm256_loadu_ps(sptr + j / _PACK_ROW);
if constexpr (_PACK_ROW == 2) scale = _mm256_permutexvar_ps(packrow2_permute_idx, scale);
if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_ps(scale, packrow2_permute_idx);
fp_v = _mm256_mul_ps(fp_v, scale);
}
if constexpr (std::is_same_v<_DST_T, float>) {
Expand Down
6 changes: 6 additions & 0 deletions bestla/bestla/kernel_avx512_bf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr,
}
return BTLA_CODE::Success;
#endif
#if CompileAVX512F()
return avx512f::bf16_cvt_fp32_2D_write_back(src_ptr, dst_ptr, row, col, src_step, dst_step, zeropadding);
#endif
return BTLA_CODE::NotSupport;
}

static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col,
Expand Down Expand Up @@ -83,7 +86,10 @@ static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void
}
return BTLA_CODE::Success;
#endif
#if CompileAVX512F()
return avx512f::fp32_cvt_bf16_2D_write_back(raw_srcptr, raw_dstptr, row, col, srcstride, dststride, zeropadding);
#endif
return BTLA_CODE::NotSupport;
}
#if CompileBF16()
#pragma GCC pop_options
Expand Down
Loading

0 comments on commit dfdfb0f

Please sign in to comment.