Skip to content

Commit

Permalink
Refine
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 committed Apr 12, 2024
1 parent 2ffa077 commit d97c62a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 25 deletions.
10 changes: 4 additions & 6 deletions clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1400,15 +1400,13 @@ void MapNames::setExplicitNamespaceMap() {
{"CUBLAS_DIAG_NON_UNIT", "oneapi::mkl::diag::nonunit"},
{"CUBLAS_DIAG_UNIT", "oneapi::mkl::diag::unit"},
{"CUBLAS_DEFAULT_MATH", getDpctNamespace() + "blas::math_mode::_default"},
{"CUBLAS_TENSOR_OP_MATH",
getDpctNamespace() + "blas::math_mode::_tensor_op"},
{"CUBLAS_TENSOR_OP_MATH", getDpctNamespace() + "blas::math_mode::_tf32"},
{"CUBLAS_PEDANTIC_MATH",
getDpctNamespace() + "blas::math_mode::_pedantic"},
getDpctNamespace() + "blas::math_mode::_default"},
{"CUBLAS_TF32_TENSOR_OP_MATH",
getDpctNamespace() + "blas::math_mode::_tf32_tensor_op"},
getDpctNamespace() + "blas::math_mode::_tf32"},
{"CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION",
getDpctNamespace() +
"blas::math_mode::_disallow_reduced_precision_reduction"},
getDpctNamespace() + "blas::math_mode::_default"},
};

ClassFieldMap = {};
Expand Down
33 changes: 14 additions & 19 deletions clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,23 +289,19 @@ inline void matrix_mem_copy(void *to_ptr, const void *from_ptr,

enum class math_mode : int {
_default,
_tensor_op,
_pedantic,
_tf32_tensor_op,
_disallow_reduced_precision_reduction
_tf32,
};
enum class compute_type : int {
_16f,
_16f_pedantic,
_16f_standard,
_32f,
_32f_pedantic,
_32f_fast_16f,
_32f_standard,
_32f_fast_16bf,
_32f_fast_tf32,
_64f,
_64f_pedantic,
_64f_standard,
_32i,
_32i_pedantic,
_32i_standard,
};

class descriptor {
Expand Down Expand Up @@ -1522,19 +1518,18 @@ namespace detail {
inline library_data_t compute_type_to_library_data_t(compute_type ct) {
switch (ct) {
case compute_type::_16f:
case compute_type::_16f_pedantic:
case compute_type::_16f_standard:
return library_data_t::real_half;
case compute_type::_32f:
case compute_type::_32f_pedantic:
case compute_type::_32f_fast_16f:
case compute_type::_32f_standard:
case compute_type::_32f_fast_16bf:
case compute_type::_32f_fast_tf32:
return library_data_t::real_float;
case compute_type::_64f:
case compute_type::_64f_pedantic:
case compute_type::_64f_standard:
return library_data_t::real_double;
case compute_type::_32i:
case compute_type::_32i_pedantic:
case compute_type::_32i_standard:
return library_data_t::real_int32;
default:
throw std::runtime_error("conversion is not supported.");
Expand All @@ -1548,10 +1543,10 @@ deduce_compute_mode(std::optional<compute_type> ct, math_mode mm) {
using Ty = typename DataType<T>::T2;
if (ct) {
switch (ct.value()) {
case compute_type::_16f_pedantic:
case compute_type::_32f_pedantic:
case compute_type::_64f_pedantic:
case compute_type::_32i_pedantic:
case compute_type::_16f_standard:
case compute_type::_32f_standard:
case compute_type::_64f_standard:
case compute_type::_32i_standard:
return oneapi::mkl::blas::compute_mode::standard;
case compute_type::_32f:
if constexpr (std::is_same_v<Ty, std::complex<float>> ||
Expand All @@ -1566,7 +1561,7 @@ deduce_compute_mode(std::optional<compute_type> ct, math_mode mm) {
[[fallthrough]];
}
}
if (mm == math_mode::_tf32_tensor_op)
if (mm == math_mode::_tf32)
return oneapi::mkl::blas::compute_mode::float_to_tf32;
return oneapi::mkl::blas::compute_mode::unset;
}
Expand Down

0 comments on commit d97c62a

Please sign in to comment.