From 86c8dffe9079e887379bec291d520cea99bbe6df Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Tue, 9 Apr 2024 10:13:24 +0800 Subject: [PATCH] [SYCLomatic #1844] Adjust blas_utils_parameter_wrapper_buf to test more cases (#667) Signed-off-by: Jiang, Zhiwei --- .../src/blas_utils_parameter_wrapper_buf.cpp | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/help_function/src/blas_utils_parameter_wrapper_buf.cpp b/help_function/src/blas_utils_parameter_wrapper_buf.cpp index 50852b898ad0..5e670646fa90 100644 --- a/help_function/src/blas_utils_parameter_wrapper_buf.cpp +++ b/help_function/src/blas_utils_parameter_wrapper_buf.cpp @@ -111,17 +111,14 @@ void test_iamax2() { void test_rotg1() { dpct::queue_ptr handle; handle = &dpct::get_out_of_order_queue(); - float *a = (float *)dpct::dpct_malloc(sizeof(float) * 1); - float *b = (float *)dpct::dpct_malloc(sizeof(float) * 1); - float *c = (float *)dpct::dpct_malloc(sizeof(float) * 1); - float *s = (float *)dpct::dpct_malloc(sizeof(float) * 1); - dpct::get_host_ptr(a)[0] = 1; - dpct::get_host_ptr(b)[0] = 1.73205; + float *four_args = (float *)dpct::dpct_malloc(sizeof(float) * 4); + dpct::get_host_ptr(four_args)[0] = 1; + dpct::get_host_ptr(four_args)[1] = 1.73205; [&]() { - dpct::blas::wrapper_float_inout a_m(*handle, a); - dpct::blas::wrapper_float_inout b_m(*handle, b); - dpct::blas::wrapper_float_out c_m(*handle, c); - dpct::blas::wrapper_float_out s_m(*handle, s); + dpct::blas::wrapper_float_inout a_m(*handle, four_args); + dpct::blas::wrapper_float_inout b_m(*handle, four_args + 1); + dpct::blas::wrapper_float_out c_m(*handle, four_args + 2); + dpct::blas::wrapper_float_out s_m(*handle, four_args + 3); oneapi::mkl::blas::column_major::rotg( *handle, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer(a_m.get_ptr())), @@ -131,22 +128,19 @@ void test_rotg1() { }(); dpct::get_current_device().queues_wait_and_throw(); handle = nullptr; - if (std::abs(dpct::get_host_ptr(a)[0] - 2.0f) < 0.01 && - std::abs(dpct::get_host_ptr(b)[0] - 2.0f) < 0.01 && - std::abs(dpct::get_host_ptr(c)[0] - 0.5f) < 0.01 && - std::abs(dpct::get_host_ptr(s)[0] - 0.866025f) < 0.01) { + if (std::abs(dpct::get_host_ptr(four_args)[0] - 2.0f) < 0.01 && + std::abs(dpct::get_host_ptr(four_args)[1] - 2.0f) < 0.01 && + std::abs(dpct::get_host_ptr(four_args)[2] - 0.5f) < 0.01 && + std::abs(dpct::get_host_ptr(four_args)[3] - 0.866025f) < 0.01) { printf("test_rotg1 pass\n"); } else { printf("test_rotg1 fail:\n"); - printf("%f,%f,%f,%f\n", dpct::get_host_ptr(a)[0], - dpct::get_host_ptr(b)[0], dpct::get_host_ptr(c)[0], - dpct::get_host_ptr(s)[0]); + printf("%f,%f,%f,%f\n", dpct::get_host_ptr(four_args)[0], + dpct::get_host_ptr(four_args)[1], dpct::get_host_ptr(four_args)[2], + dpct::get_host_ptr(four_args)[3]); pass = false; } - dpct::dpct_free(a); - dpct::dpct_free(b); - dpct::dpct_free(c); - dpct::dpct_free(s); + dpct::dpct_free(four_args); } void test_rotg2() {