Skip to content

Commit

Permalink
[SYCLomatic oneapi-src#1844] Adjust blas_utils_parameter_wrapper_buf …
Browse files Browse the repository at this point in the history
…to test more cases (oneapi-src#667)

Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 authored Apr 9, 2024
1 parent fb64662 commit 86c8dff
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions help_function/src/blas_utils_parameter_wrapper_buf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,14 @@ void test_iamax2() {
void test_rotg1() {
dpct::queue_ptr handle;
handle = &dpct::get_out_of_order_queue();
float *a = (float *)dpct::dpct_malloc(sizeof(float) * 1);
float *b = (float *)dpct::dpct_malloc(sizeof(float) * 1);
float *c = (float *)dpct::dpct_malloc(sizeof(float) * 1);
float *s = (float *)dpct::dpct_malloc(sizeof(float) * 1);
dpct::get_host_ptr<float>(a)[0] = 1;
dpct::get_host_ptr<float>(b)[0] = 1.73205;
float *four_args = (float *)dpct::dpct_malloc(sizeof(float) * 4);
dpct::get_host_ptr<float>(four_args)[0] = 1;
dpct::get_host_ptr<float>(four_args)[1] = 1.73205;
[&]() {
dpct::blas::wrapper_float_inout a_m(*handle, a);
dpct::blas::wrapper_float_inout b_m(*handle, b);
dpct::blas::wrapper_float_out c_m(*handle, c);
dpct::blas::wrapper_float_out s_m(*handle, s);
dpct::blas::wrapper_float_inout a_m(*handle, four_args);
dpct::blas::wrapper_float_inout b_m(*handle, four_args + 1);
dpct::blas::wrapper_float_out c_m(*handle, four_args + 2);
dpct::blas::wrapper_float_out s_m(*handle, four_args + 3);
oneapi::mkl::blas::column_major::rotg(
*handle,
dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer<float>(a_m.get_ptr())),
Expand All @@ -131,22 +128,19 @@ void test_rotg1() {
}();
dpct::get_current_device().queues_wait_and_throw();
handle = nullptr;
if (std::abs(dpct::get_host_ptr<float>(a)[0] - 2.0f) < 0.01 &&
std::abs(dpct::get_host_ptr<float>(b)[0] - 2.0f) < 0.01 &&
std::abs(dpct::get_host_ptr<float>(c)[0] - 0.5f) < 0.01 &&
std::abs(dpct::get_host_ptr<float>(s)[0] - 0.866025f) < 0.01) {
if (std::abs(dpct::get_host_ptr<float>(four_args)[0] - 2.0f) < 0.01 &&
std::abs(dpct::get_host_ptr<float>(four_args)[1] - 2.0f) < 0.01 &&
std::abs(dpct::get_host_ptr<float>(four_args)[2] - 0.5f) < 0.01 &&
std::abs(dpct::get_host_ptr<float>(four_args)[3] - 0.866025f) < 0.01) {
printf("test_rotg1 pass\n");
} else {
printf("test_rotg1 fail:\n");
printf("%f,%f,%f,%f\n", dpct::get_host_ptr<float>(a)[0],
dpct::get_host_ptr<float>(b)[0], dpct::get_host_ptr<float>(c)[0],
dpct::get_host_ptr<float>(s)[0]);
printf("%f,%f,%f,%f\n", dpct::get_host_ptr<float>(four_args)[0],
dpct::get_host_ptr<float>(four_args)[1], dpct::get_host_ptr<float>(four_args)[2],
dpct::get_host_ptr<float>(four_args)[3]);
pass = false;
}
dpct::dpct_free(a);
dpct::dpct_free(b);
dpct::dpct_free(c);
dpct::dpct_free(s);
dpct::dpct_free(four_args);
}

void test_rotg2() {
Expand Down

0 comments on commit 86c8dff

Please sign in to comment.