Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[BesTLA] The initial SYCL support #229

Merged
merged 25 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@
"cacheVariables": {
"BTLA_UT_OPENMP": "OFF"
}
},
{
"name": "x64-release-sycl",
"displayName": "x64 Release SYCL",
"description": "x64 SYCL",
"inherits": "x64-debug",
"cacheVariables": {
"CMAKE_CXX_COMPILER": "icx-cl",
"CMAKE_C_COMPILER": "icx-cl",
"CMAKE_BUILD_TYPE": "Release",
"BTLA_UT_ALL": "ON"
}
}
]
}
37 changes: 30 additions & 7 deletions bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)

option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF)
option(BTLA_SYCL "Compile OpenMP thread pool if OMP can be found" OFF)

option(BTLA_UT_ALL "Enable all unit tests" OFF)
option(BTLA_UT_DEBUG "Enable debug unit tests" OFF)
Expand All @@ -21,14 +22,26 @@ option(BTLA_UT_NOASAN "Disable sanitize" OFF)
option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF)





add_library(${PROJECT_NAME} INTERFACE)
add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
target_include_directories(
${PROJECT_NAME} INTERFACE
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
)

set(sycl_headers)
set(sycl_libs)
if(BTLA_SYCL)
include(cmake/sycl.cmake)
file(GLOB sycl_headers ${PROJECT_NAME}/sycl/*.h ${PROJECT_NAME}/sycl/*.hpp)
add_compile_definitions(BTLA_SYCL)
list(APPEND sycl_libs IntelSYCL::SYCL_CXX)
#add_link_options(-fsycl-targets=spir64 -Xsycl-target-backend "-options -ze-opt-large-register-file")
endif(BTLA_SYCL)

if(BTLA_ENABLE_OPENMP)
message(STATUS "BesTLA enable OpenMP ThreadPool")
Expand Down Expand Up @@ -69,12 +82,20 @@ function(add_ut_flag UT_OPTION)
endif()
endfunction()

set(benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp)
# list(APPEND benchmark_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/sycl_benchmark.cpp)


if(UT_BUILD)
file(GLOB srcs ${PROJECT_NAME}/ut/*.cc ${PROJECT_NAME}/ut/*.cpp) #compile everything even run parts of UTs
list(REMOVE_ITEM srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp)
file(GLOB sycl_srcs ${PROJECT_NAME}/ut/sycl*)
if(NOT BTLA_SYCL)
list(REMOVE_ITEM srcs ${sycl_srcs})
endif()
list(REMOVE_ITEM srcs ${benchmark_srcs})
file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h)
include_directories(${PROJECT_NAME})
add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers})
add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${sycl_headers} ${ut_headers})
if(BTLA_UT_OPENMP)
include(FindOpenMP)
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP)
Expand All @@ -98,14 +119,16 @@ if(UT_BUILD)
add_ut_flag(BTLA_UT_KERNEL_INTRIN)
add_ut_flag(BTLA_UT_KERNEL_JIT)
add_ut_flag(BTLA_UT_KERNEL_WRAPPER)
target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME})
if(BTLA_SYCL)
add_compile_definitions(BTLA_UT_SYCL)
endif()
target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME} ${sycl_libs})
endif(UT_BUILD)

if(BTLA_UT_BENCHMARK)
file(GLOB srcs ${PROJECT_NAME}/ut/bestla_benchmark.cpp) #compile everything even run parts of UTs
file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h)
include_directories(${PROJECT_NAME})
add_executable(${PROJECT_NAME}_benchmark ${srcs} ${headers} ${ut_headers})
add_executable(${PROJECT_NAME}_benchmark ${benchmark_srcs} ${headers} ${ut_headers})
if(BTLA_UT_OPENMP)
include(FindOpenMP)
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP)
Expand All @@ -114,5 +137,5 @@ if(BTLA_UT_BENCHMARK)
if(NOT WIN32)
target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread)
endif()
target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME})
target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME} ${sycl_libs})
endif(BTLA_UT_BENCHMARK)
28 changes: 28 additions & 0 deletions bestla/CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,34 @@
"description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)",
"inherits": "x64-release",
"cacheVariables": { "BTLA_UT_ALL": "ON" }
},
{
"name": "x64-debug-sycl",
"displayName": "x64 Debug SYCL",
"description": "x64 Debug SYCL",
"inherits": "windows-base",
"architecture": {
"value": "x64",
"strategy": "external"
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"BTLA_UT_DEBUG": "ON",
"BTLA_UT_ALL": "OFF",
"BTLA_SYCL": "ON",
"BTLA_UT_BENCHMARK": "ON",
"CMAKE_CXX_COMPILER": "icx",
"CMAKE_C_COMPILER": "icx"
}
},
{
"name": "x64-release-sycl",
"displayName": "x64 Release for SYCL",
"description": "x64 SYCL",
"inherits": "x64-debug-sycl",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release"
}
}
]
}
14 changes: 7 additions & 7 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
#define CompileAMXINT8() (CompileAMX())
#endif

#ifdef _MSC_VER
#if defined(_MSC_VER) && !defined(__INTEL_LLVM_COMPILER)
#define CompileAVX512F() _MSC_VER && (_MSC_VER >= 1911)
#define CompileAVX2() _MSC_VER && (_MSC_VER >= 1900)
#define CompileAMX() 0
Expand All @@ -80,12 +80,12 @@
#define CompileAMXINT8() 0
#endif

#ifdef __clang_major__
#define CompileAVX512F() (__clang_major__ >= 4)
#define CompileAVX2() (__clang_major__ >= 3)
#define CompileAMX() (__clang_major__ >= 11)
#define CompileBF16() (__clang_major__ >= 11)
#define CompileFP16() (__clang_major__ >= 16)
#if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER)
#define CompileAVX512F() defined(__AVX512F__)
#define CompileAVX2() defined(__AVX2__) && defined(__F16C__) && defined(__FMA__)
#define CompileAMX() 0
#define CompileBF16() 0
#define CompileFP16() 0
#define CompileAMXBF16() (CompileAMX())
#define CompileAMXINT8() (CompileAMX())
#endif
Expand Down
12 changes: 6 additions & 6 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ namespace bestla {
namespace kernel {
namespace avx2 {
#if CompileAVX2()
#ifdef __GNUC__
#if defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "fma", "f16c")
#else
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
#endif

template <bool LowBits>
static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) {
auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr));
Expand Down Expand Up @@ -74,7 +74,7 @@ inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) {

inline __m128i ymm_cvtepi32_epi16(__m256i src) {
__m128i tmp;
#ifdef __GNUC__
#if defined(__GNUC__) || defined(__clang_major__)
for (size_t i = 0; i < 8; i++) {
(reinterpret_cast<int16_t*>(&tmp))[i] = (reinterpret_cast<int32_t*>(&src))[i];
}
Expand Down Expand Up @@ -443,7 +443,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
e_revert = _mm256_srli_epi32(e_revert, mantissabit);
if constexpr (WITH_SCALE && std::is_same_v<_S_T, utils::f8>) {
auto scale = _mm256_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(sptr + j / _PACK_ROW)));
if constexpr (_PACK_ROW == 2) scale = _mm256_permutexvar_epi32(packrow2_permute_idx, scale);
if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_epi32(packrow2_permute_idx, scale);
e_revert = _mm256_add_epi32(e_revert, scale);
}
e_revert = _mm256_sub_epi32(e_revert, e_revert_shift);
Expand All @@ -454,7 +454,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
fp_v = _mm256_or_ps(fp_v, _mm256_castsi256_ps(mantissa_revert));
if constexpr (WITH_SCALE && std::is_same_v<_S_T, float>) {
auto scale = _mm256_loadu_ps(sptr + j / _PACK_ROW);
if constexpr (_PACK_ROW == 2) scale = _mm256_permutexvar_ps(packrow2_permute_idx, scale);
if constexpr (_PACK_ROW == 2) scale = _mm256_permutevar8x32_ps(scale, packrow2_permute_idx);
fp_v = _mm256_mul_ps(fp_v, scale);
}
if constexpr (std::is_same_v<_DST_T, float>) {
Expand Down
6 changes: 6 additions & 0 deletions bestla/bestla/kernel_avx512_bf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr,
}
return BTLA_CODE::Success;
#endif
#if CompileAVX512F()
return avx512f::bf16_cvt_fp32_2D_write_back(src_ptr, dst_ptr, row, col, src_step, dst_step, zeropadding);
#endif
return BTLA_CODE::NotSupport;
}

static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col,
Expand Down Expand Up @@ -83,7 +86,10 @@ static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void
}
return BTLA_CODE::Success;
#endif
#if CompileAVX512F()
return avx512f::fp32_cvt_bf16_2D_write_back(raw_srcptr, raw_dstptr, row, col, srcstride, dststride, zeropadding);
#endif
return BTLA_CODE::NotSupport;
}
#if CompileBF16()
#pragma GCC pop_options
Expand Down
Loading
Loading