Skip to content

Commit

Permalink
Adding tests for device_pointer inside a permutation_iterator
Browse files Browse the repository at this point in the history
Signed-off-by: Dan Hoeflinger <[email protected]>
  • Loading branch information
danhoeflinger committed Aug 19, 2024
1 parent 9104c45 commit c23d65e
Showing 1 changed file with 55 additions and 4 deletions.
59 changes: 55 additions & 4 deletions help_function/src/onedpl_test_device_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,33 @@
#include <iostream>

template<typename String, typename _T1, typename _T2>
int ASSERT_EQUAL(String msg, _T1&& X, _T2&& Y) {
int ASSERT_EQUAL(String msg, _T1&& X, _T2&& Y, bool skip_pass_msg = false) {
if(X!=Y) {
std::cout << "FAIL: " << msg << " - (" << X << "," << Y << ")" << std::endl;
return 1;
}
else {
else if (!skip_pass_msg){
std::cout << "PASS: " << msg << std::endl;
}
return 0;
}

template <typename String, typename _T1, typename _T2>
int
ASSERT_EQUAL_N(String msg, _T1&& X, _T2&& Y, ::std::size_t n)
{
int failed_tests = 0;
for (size_t i = 0; i < n; i++)
{
failed_tests += ASSERT_EQUAL(msg, *X, *Y, true);
X++;
Y++;
}
if (failed_tests == 0)
{
std::cout << "PASS: " << msg << std::endl;
return 0;
}
return failed_tests;
}

int test_device_ptr_manipulation(void)
Expand Down Expand Up @@ -115,9 +133,42 @@ int test_device_ptr_iteration(void)
return failing_tests;
}

int
test_permutation_iterator()
{
int failing_tests = 0;
typedef size_t T;
#ifdef DPCT_USM_LEVEL_NONE
sycl::buffer<T, 1> data(sycl::range<1>(1024));

dpct::device_pointer<T> begin(data, 0);
dpct::device_pointer<T> end(data, 1024);

sycl::buffer<T, 1> data_res(sycl::range<1>(1024));
dpct::device_pointer<T> begin_res(data_res, 0);
dpct::device_pointer<T> end_res(data_res, 1024);
#else
dpct::device_pointer<T> data(1024*sizeof(T));
dpct::device_pointer<T> begin(data);
dpct::device_pointer<T> end(data + 1024);

dpct::device_pointer<T> data_res(1024*sizeof(T));
dpct::device_pointer<T> begin_res(data_res);
dpct::device_pointer<T> end_res(data_res + 1024);
#endif
auto policy = oneapi::dpl::execution::make_device_policy(dpct::get_default_queue());
std::fill(begin, end, T(1));
std::fill(begin_res, end_res, T(99));
auto perm = oneapi::dpl::make_permutation_iterator(begin, oneapi::dpl::counting_iterator(0));
std::copy(policy, perm, perm + 1024, begin_res);
return ASSERT_EQUAL_N("device_ptr in permutation_iterator", begin_res, dpct::make_constant_iterator(T(1)), 1024);
}

int main() {
int failed_tests = test_device_ptr_manipulation();
int failed_tests = 0;
failed_tests += test_device_ptr_manipulation();
failed_tests += test_device_ptr_iteration();
failed_tests += test_permutation_iterator();

std::cout << std::endl << failed_tests << " failing test(s) detected." << std::endl;
if (failed_tests == 0) {
Expand Down

0 comments on commit c23d65e

Please sign in to comment.