Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 committed Mar 1, 2024
1 parent b89dfc4 commit 22a520d
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 88 deletions.
16 changes: 8 additions & 8 deletions clang/lib/DPCT/APINamesCUBLAS.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ WARNING_FACTORY_ENTRY(
"cublasIsamax_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIsamax_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamax",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1542,7 +1542,7 @@ WARNING_FACTORY_ENTRY(
"cublasIdamax_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIdamax_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamax",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1555,7 +1555,7 @@ WARNING_FACTORY_ENTRY(
"cublasIcamax_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIcamax_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamax",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1568,7 +1568,7 @@ WARNING_FACTORY_ENTRY(
"cublasIzamax_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIzamax_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamax",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1582,7 +1582,7 @@ WARNING_FACTORY_ENTRY(
"cublasIsamin_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIsamin_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamin",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1595,7 +1595,7 @@ WARNING_FACTORY_ENTRY(
"cublasIdamin_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIdamin_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamin",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1608,7 +1608,7 @@ WARNING_FACTORY_ENTRY(
"cublasIcamin_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIcamin_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamin",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand All @@ -1621,7 +1621,7 @@ WARNING_FACTORY_ENTRY(
"cublasIzamin_v2",
LAMBDA_FACTORY_ENTRY(
"cublasIzamin_v2",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int_t", "res",
DECL(MapNames::getDpctNamespace() + "blas::out_mem_int64_int_t", "res",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(4)),
CALL("oneapi::mkl::blas::column_major::iamin",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(1),
Expand Down
98 changes: 55 additions & 43 deletions clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@

namespace dpct {
namespace blas {
namespace detail {
template <typename target_t, typename source_t, bool has_input>
class scalar_memory_base_t {
// Transfer-direction selector used as the `io` template parameter of
// mem_base_t / mem_t.
// In the same-type mem_t specialization (USM path, i.e. when
// DPCT_USM_LEVEL_NONE is not defined) the visible code uses it as follows:
//   - constructor: copies source -> target unless io is `out`;
//   - destructor:  copies target -> source unless io is `in`.
// Both copies are additionally guarded by a host-only check on the source
// pointer.
enum class mem_inout { in, out, inout };
template <typename target_t, typename source_t, mem_inout io> class mem_base_t {
public:
scalar_memory_base_t(sycl::queue q, source_t *source,
mem_base_t(sycl::queue q, source_t *source,
#ifdef DPCT_USM_LEVEL_NONE
sycl::buffer<target_t> target
sycl::buffer<target_t> target
#else
target_t *target
target_t *target
#endif
)
)
: _q(q), _source(source), _target(target),
_source_attribute(dpct::detail::get_pointer_attribute(_q, _source)) {
}
Expand All @@ -52,20 +51,18 @@ class scalar_memory_base_t {
dpct::detail::pointer_access_attribute _source_attribute;
};

template <typename target_t, typename source_t, bool has_input>
class scalar_memory_t
: public scalar_memory_base_t<target_t, source_t, has_input> {
static_assert(
!has_input &&
"input is not supported if target_t and source_t are not same.");
using base_t = scalar_memory_base_t<target_t, source_t, has_input>;
template <typename target_t, typename source_t, mem_inout io>
class mem_t : public mem_base_t<target_t, source_t, io> {
// Compile-time guard for the primary template (target_t differs from
// source_t): only the `out` direction is accepted here.
// The message is passed as static_assert's second argument so the compiler
// prints it as the diagnostic text; previously it was folded into the
// condition with `&&`, where it only acted as an always-true operand.
static_assert(io == mem_inout::out,
              "Only mem_inout::out is supported if "
              "target_t and source_t are not same.");
using base_t = mem_base_t<target_t, source_t, io>;
using base_t::_q;
using base_t::_source;
using base_t::_source_attribute;
using base_t::_target;

public:
scalar_memory_t(sycl::queue q, source_t *source)
mem_t(sycl::queue q, source_t *source)
: base_t(q, source,
#ifdef DPCT_USM_LEVEL_NONE
sycl::buffer<target_t>(sycl::range<1>(1))
Expand All @@ -75,7 +72,7 @@ class scalar_memory_t
) {
}

~scalar_memory_t() {
~mem_t() {
#ifdef DPCT_USM_LEVEL_NONE
source_t temp = static_cast<source_t>(_target.get_host_access()[0]);
if (_source_attribute ==
Expand All @@ -98,36 +95,48 @@ class scalar_memory_t
}
};

template <typename target_t, bool has_input>
class scalar_memory_t<target_t, target_t, has_input>
: public scalar_memory_base_t<target_t, target_t, has_input> {
using base_t = scalar_memory_base_t<target_t, target_t, has_input>;
template <typename target_t, mem_inout io>
class mem_t<target_t, target_t, io>
: public mem_base_t<target_t, target_t, io> {
using base_t = mem_base_t<target_t, target_t, io>;
using base_t::_q;
using base_t::_source;
using base_t::_source_attribute;
using base_t::_target;
size_t _ele_num;
#ifndef DPCT_USM_LEVEL_NONE
bool _need_free = true;
#endif

public:
scalar_memory_t(sycl::queue q, target_t *source)
: base_t(q, source,
dpct::detail::get_pointer_attribute(q, source) !=
dpct::detail::pointer_access_attribute::host_only
?
#ifdef DPCT_USM_LEVEL_NONE
dpct::get_buffer<target_t>(source)
sycl::buffer<target_t>
#else
source
target_t *
#endif
:
construct_member_variable_target(sycl::queue q, target_t *source,
size_t ele_num) {
#ifdef DPCT_USM_LEVEL_NONE
sycl::buffer<target_t>(source, sycl::range<1>(1))
if (dpct::detail::get_pointer_attribute(q, source) !=
dpct::detail::pointer_access_attribute::host_only) {
target_t *host_ptr = dpct::get_host_ptr<target_t>(source);
return sycl::buffer<target_t>(host_ptr, sycl::range<1>(ele_num));
} else {
return sycl::buffer<target_t>(source, sycl::range<1>(ele_num));
}
#else
sycl::malloc_shared<target_t>(1, q)
if (dpct::detail::get_pointer_attribute(q, source) !=
dpct::detail::pointer_access_attribute::host_only) {
return source;
} else {
return sycl::malloc_shared<target_t>(ele_num, q);
}
#endif
) {
}

public:
mem_t(sycl::queue q, target_t *source, size_t ele_num = 1)
: base_t(q, source, construct_member_variable_target(q, source, ele_num)),
_ele_num(ele_num) {
if (_source_attribute !=
dpct::detail::pointer_access_attribute::host_only) {
#ifndef DPCT_USM_LEVEL_NONE
Expand All @@ -137,33 +146,36 @@ class scalar_memory_t<target_t, target_t, has_input>
}

#ifndef DPCT_USM_LEVEL_NONE
if constexpr (has_input) {
if constexpr (io != mem_inout::out) {
if (_source_attribute ==
dpct::detail::pointer_access_attribute::host_only) {
*_target = *_source;
_q.memcpy(_target, _source, sizeof(target_t) * _ele_num).wait();
}
}
#endif
}

~scalar_memory_t() {
~mem_t() {
#ifndef DPCT_USM_LEVEL_NONE
if (!_need_free) {
return;
}
if (_source_attribute ==
dpct::detail::pointer_access_attribute::host_only) {
_q.wait();
*_source = *_target;
if constexpr (io != mem_inout::in) {
if (_source_attribute ==
dpct::detail::pointer_access_attribute::host_only) {
_q.wait();
_q.memcpy(_source, _target, sizeof(target_t) * _ele_num).wait();
}
}
sycl::free(_target, _q);
#endif
}
};
} // namespace detail
using out_mem_int_t = detail::scalar_memory_t<std::int64_t, int, false>;
using out_mem_int64_t =
detail::scalar_memory_t<std::int64_t, std::int64_t, false>;
// Convenience aliases of mem_t<target_t, source_t, io>.
// Naming: <direction>_mem_<target type>[_<source type>]_t; the source type is
// omitted from the name when it equals the target type.

// Output only; target_t = std::int64_t, source_t = int (differing types, so
// this selects the primary mem_t template).
using out_mem_int64_int_t = mem_t<std::int64_t, int, mem_inout::out>;
// Output only; std::int64_t on both sides.
using out_mem_int64_t = mem_t<std::int64_t, std::int64_t, mem_inout::out>;
// Output only; float on both sides.
using out_mem_float_t = mem_t<float, float, mem_inout::out>;
// Input and output; float on both sides.
using inout_mem_float_t = mem_t<float, float, mem_inout::inout>;
// Input only; float on both sides.
using in_mem_float_t = mem_t<float, float, mem_inout::in>;

class descriptor {
public:
Expand Down
16 changes: 8 additions & 8 deletions clang/test/dpct/cublas-usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,50 +83,50 @@ int main() {
//level 1

//CHECK:a = [&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamax(handle->get_queue(), N, x_S, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
a = cublasIsamax(handle, N, x_S, N, result);
//CHECK:[&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamax(handle->get_queue(), N, x_D, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
cublasIdamax(handle, N, x_D, N, result);
//CHECK:a = [&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamax(handle->get_queue(), N, (std::complex<float>*)x_C, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
a = cublasIcamax(handle, N, x_C, N, result);
//CHECK:[&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamax(handle->get_queue(), N, (std::complex<double>*)x_Z, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
cublasIzamax(handle, N, x_Z, N, result);

//CHECK:a = [&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamin(handle->get_queue(), N, x_S, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
a = cublasIsamin(handle, N, x_S, N, result);
//CHECK:[&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamin(handle->get_queue(), N, x_D, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
cublasIdamin(handle, N, x_D, N, result);
//CHECK:a = [&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamin(handle->get_queue(), N, (std::complex<float>*)x_C, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
a = cublasIcamin(handle, N, x_C, N, result);
//CHECK:[&]() {
//CHECK-NEXT:dpct::blas::out_mem_int_t res(handle->get_queue(), result);
//CHECK-NEXT:dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
//CHECK-NEXT:oneapi::mkl::blas::column_major::iamin(handle->get_queue(), N, (std::complex<double>*)x_Z, N, res.get_memory(), oneapi::mkl::index_base::one);
//CHECK-NEXT:return 0;
//CHECK-NEXT:}();
Expand Down
16 changes: 8 additions & 8 deletions clang/test/dpct/cublasIsamax_etc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,25 @@ int main() {
//level1
//cublasI<t>amax
// CHECK: status = [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<float>(x_S)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
// CHECK: [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<float>(x_S)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
status = cublasIsamax(handle, n, x_S, incx, result);
cublasIsamax(handle, n, x_S, incx, result);

// CHECK: status = [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<double>(x_D)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
// CHECK: [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<double>(x_D)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
Expand All @@ -65,25 +65,25 @@ int main() {

//cublasI<t>amin
// CHECK: status = [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<float>(x_S)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
// CHECK: [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<float>(x_S)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
status = cublasIsamin(handle, n, x_S, incx, result);
cublasIsamin(handle, n, x_S, incx, result);

// CHECK: status = [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<double>(x_D)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
// CHECK: [&]() {
// CHECK-NEXT: dpct::blas::out_mem_int_t res(handle->get_queue(), result);
// CHECK-NEXT: dpct::blas::out_mem_int64_int_t res(handle->get_queue(), result);
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(handle->get_queue(), n, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<double>(x_D)), incx, res.get_memory(), oneapi::mkl::index_base::one);
// CHECK-NEXT: return 0;
// CHECK-NEXT: }();
Expand Down
Loading

0 comments on commit 22a520d

Please sign in to comment.