Skip to content

Commit

Permalink
[SYCLomatic] Refine the migration of cublas API with cublasComputeType_t argument (#1826)
Browse files Browse the repository at this point in the history


Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 authored Apr 23, 2024
1 parent a8716a3 commit c30a9b5
Show file tree
Hide file tree
Showing 31 changed files with 1,010 additions and 1,567 deletions.
1,497 changes: 524 additions & 973 deletions clang/lib/DPCT/APINamesCUBLAS.inc

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions clang/lib/DPCT/APINames_cuBLAS.inc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ ENTRY(cublasSetMatrixAsync, cublasSetMatrixAsync, true, NO_FLAG, P4, "DPCT1018/D
ENTRY(cublasGetMatrixAsync, cublasGetMatrixAsync, true, NO_FLAG, P4, "DPCT1018/DPCT1020")
ENTRY(cublasSetAtomicsMode, cublasSetAtomicsMode, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
ENTRY(cublasGetAtomicsMode, cublasGetAtomicsMode, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
// NOTE(review): cublasSetMathMode and cublasGetMathMode each appear twice below, once with
// "DPCT1026/DPCT1027" and once with "Successful" as the last argument. This listing is a
// rendered diff without +/- markers, so presumably the first pair is the removed version and
// the second pair is the added one -- confirm against the actual file before relying on it.
ENTRY(cublasSetMathMode, cublasSetMathMode, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
ENTRY(cublasGetMathMode, cublasGetMathMode, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
ENTRY(cublasSetMathMode, cublasSetMathMode, true, NO_FLAG, P4, "Successful")
ENTRY(cublasGetMathMode, cublasGetMathMode, true, NO_FLAG, P4, "Successful")
ENTRY(cublasLoggerConfigure, cublasLoggerConfigure, false, NO_FLAG, P4, "comment")
ENTRY(cublasGetLoggerCallback, cublasGetLoggerCallback, false, NO_FLAG, P4, "comment")
ENTRY(cublasSetLoggerCallback, cublasSetLoggerCallback, false, NO_FLAG, P4, "comment")
Expand Down
16 changes: 8 additions & 8 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3628,11 +3628,13 @@ void BLASEnumsRule::registerMatcher(MatchFinder &MF) {
"CUBLAS_GEMM_.*)|(CUBLAS_POINTER_MODE.*)"))))
.bind("BLASStatusConstants"),
this);
MF.addMatcher(declRefExpr(to(enumConstantDecl(matchesName(
"(CUBLAS_OP.*)|(CUBLAS_SIDE.*)|(CUBLAS_FILL_"
"MODE.*)|(CUBLAS_DIAG.*)"))))
.bind("BLASNamedValueConstants"),
this);
MF.addMatcher(
declRefExpr(to(enumConstantDecl(matchesName(
"(CUBLAS_OP.*)|(CUBLAS_SIDE.*)|(CUBLAS_FILL_"
"MODE.*)|(CUBLAS_DIAG.*)|(CUBLAS_.*_MATH)|CUBLAS_MATH_"
"DISALLOW_REDUCED_PRECISION_REDUCTION"))))
.bind("BLASNamedValueConstants"),
this);
}

void BLASEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
Expand Down Expand Up @@ -5229,9 +5231,7 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
} else if (FuncName == "cublasGetPointerMode_v2" ||
FuncName == "cublasSetPointerMode_v2" ||
FuncName == "cublasGetAtomicsMode" ||
FuncName == "cublasSetAtomicsMode" ||
FuncName == "cublasGetMathMode" ||
FuncName == "cublasSetMathMode") {
FuncName == "cublasSetAtomicsMode") {
std::string Msg = "this functionality is redundant in SYCL.";
if (IsAssigned) {
report(CE->getBeginLoc(), Diagnostics::FUNC_CALL_REMOVED_0, false,
Expand Down
197 changes: 82 additions & 115 deletions clang/lib/DPCT/MapNames.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion clang/lib/DPCT/MapNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ class MapNames {
static const MapTy MacrosMap;
static std::unordered_map<std::string, MacroMigrationRule> MacroRuleMap;
static std::unordered_map<std::string, MetaRuleObject &> HeaderRuleMap;
static MapTy BLASEnumsMap;
static MapTy SPBLASEnumsMap;
static const MapTy BLASEnumsMap;
static std::map<std::string, MapNames::BLASFuncReplInfo> BLASFuncReplInfoMap;
static const std::map<std::string, MapNames::BLASFuncComplexReplInfo>
BLASFuncComplexReplInfoMap;
Expand Down
62 changes: 31 additions & 31 deletions clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,20 @@ inline void matrix_mem_copy(void *to_ptr, const void *from_ptr,
}

/// Math mode value; in this listing it is held by `descriptor` (member `_mm`) and passed to
/// `deduce_compute_mode` as the `mm` argument.
// NOTE(review): this listing is a rendered diff without +/- markers. `_default`/`_tf32` look
// like the pre-change spellings of `mm_default`/`mm_tf32` (the other hunks switch every use
// from the former to the latter) -- confirm that only one pair exists in the real header.
enum class math_mode : int {
_default,
_tf32,
// Initial value of `descriptor::_mm` in the visible code.
mm_default,
// When no compute type decides the result, `deduce_compute_mode` returns
// `oneapi::mkl::blas::compute_mode::float_to_tf32` for this value.
mm_tf32,
};
/// Compute type value; in this listing it is read by `compute_type_to_library_data_t` and by
/// `deduce_compute_mode`.
// NOTE(review): this listing is a rendered diff without +/- markers. The underscore-prefixed
// enumerators (`_16f` ... `_32i_standard`) look like the pre-change spellings of the ten names
// that follow them (`f16` ... `i32_standard`) -- confirm that only one set exists in the real
// header.
enum class compute_type : int {
_16f,
_16f_standard,
_32f,
_32f_standard,
_32f_fast_16bf,
_32f_fast_tf32,
_64f,
_64f_standard,
_32i,
_32i_standard,
// f16 / f16_standard: mapped to `library_data_t::real_half`.
f16,
// Every `*_standard` value makes `deduce_compute_mode` return `compute_mode::standard`.
f16_standard,
// f32 family: mapped to `library_data_t::real_float`. For plain `f32`,
// `deduce_compute_mode` returns `compute_mode::complex_3m` when the deduced element type is
// `std::complex<float>` or `std::complex<double>`.
f32,
f32_standard,
// `deduce_compute_mode` returns `compute_mode::float_to_bf16` for this value.
f32_fast_bf16,
// `deduce_compute_mode` returns `compute_mode::float_to_tf32` for this value.
f32_fast_tf32,
// f64 / f64_standard: mapped to `library_data_t::real_double`.
f64,
f64_standard,
// i32 / i32_standard: mapped to `library_data_t::real_int32`.
i32,
i32_standard,
};

class descriptor {
Expand All @@ -319,7 +319,7 @@ class descriptor {

private:
queue_ptr _queue_ptr = &dpct::get_default_queue();
math_mode _mm = math_mode::_default;
math_mode _mm = math_mode::mm_default;
static inline queue_ptr _saved_queue_ptr = &dpct::get_default_queue();
};

Expand Down Expand Up @@ -1517,19 +1517,19 @@ namespace blas {
namespace detail {
inline library_data_t compute_type_to_library_data_t(compute_type ct) {
switch (ct) {
case compute_type::_16f:
case compute_type::_16f_standard:
case compute_type::f16:
case compute_type::f16_standard:
return library_data_t::real_half;
case compute_type::_32f:
case compute_type::_32f_standard:
case compute_type::_32f_fast_16bf:
case compute_type::_32f_fast_tf32:
case compute_type::f32:
case compute_type::f32_standard:
case compute_type::f32_fast_bf16:
case compute_type::f32_fast_tf32:
return library_data_t::real_float;
case compute_type::_64f:
case compute_type::_64f_standard:
case compute_type::f64:
case compute_type::f64_standard:
return library_data_t::real_double;
case compute_type::_32i:
case compute_type::_32i_standard:
case compute_type::i32:
case compute_type::i32_standard:
return library_data_t::real_int32;
default:
throw std::runtime_error("conversion is not supported.");
Expand All @@ -1543,25 +1543,25 @@ deduce_compute_mode(std::optional<compute_type> ct, math_mode mm) {
using Ty = typename DataType<T>::T2;
if (ct) {
switch (ct.value()) {
case compute_type::_16f_standard:
case compute_type::_32f_standard:
case compute_type::_64f_standard:
case compute_type::_32i_standard:
case compute_type::f16_standard:
case compute_type::f32_standard:
case compute_type::f64_standard:
case compute_type::i32_standard:
return oneapi::mkl::blas::compute_mode::standard;
case compute_type::_32f:
case compute_type::f32:
if constexpr (std::is_same_v<Ty, std::complex<float>> ||
std::is_same_v<Ty, std::complex<double>>)
return oneapi::mkl::blas::compute_mode::complex_3m;
break;
case compute_type::_32f_fast_16bf:
case compute_type::f32_fast_bf16:
return oneapi::mkl::blas::compute_mode::float_to_bf16;
case compute_type::_32f_fast_tf32:
case compute_type::f32_fast_tf32:
return oneapi::mkl::blas::compute_mode::float_to_tf32;
default:
[[fallthrough]];
}
}
if (mm == math_mode::_tf32)
if (mm == math_mode::mm_tf32)
return oneapi::mkl::blas::compute_mode::float_to_tf32;
return oneapi::mkl::blas::compute_mode::unset;
}
Expand Down
Loading

0 comments on commit c30a9b5

Please sign in to comment.