From c27a6b6ba8efca4e5137d8b74ff3b24bad4f3f23 Mon Sep 17 00:00:00 2001 From: Bryce Allen Date: Mon, 16 Jan 2023 15:38:03 -0500 Subject: [PATCH 1/2] update sycl_ext_complex For some reason the gtensor specified AssignN kernel names were missing the second sycl_cplx::complex type arg, and failed to compile. Using only the to_kernel types fixes the issue, and also makes the names more compact and easier to read. I don't think they are actually necessary to avoid duplication, the kernel types alone should be unique. --- include/gtensor/assign.h | 14 +- include/gtensor/backend_sycl.h | 8 +- include/gtensor/complex.h | 2 +- include/gtensor/sycl_ext_complex.hpp | 1029 +++++++++----------------- 4 files changed, 346 insertions(+), 707 deletions(-) diff --git a/include/gtensor/assign.h b/include/gtensor/assign.h index d8f56398..4b113159 100644 --- a/include/gtensor/assign.h +++ b/include/gtensor/assign.h @@ -401,7 +401,7 @@ struct assigner<1, space::device> auto e = q.submit([&](sycl::handler& cgh) { using ltype = decltype(k_lhs); using rtype = decltype(k_rhs); - using kname = gt::backend::sycl::Assign1; + using kname = gt::backend::sycl::Assign1; cgh.parallel_for(range, [=](sycl::item<1> item) { auto i = item.get_id(); k_lhs(i) = k_rhs(i); @@ -427,7 +427,7 @@ struct assigner<2, space::device> auto e = q.submit([&](sycl::handler& cgh) { using ltype = decltype(k_lhs); using rtype = decltype(k_rhs); - using kname = gt::backend::sycl::Assign2; + using kname = gt::backend::sycl::Assign2; cgh.parallel_for(range, [=](sycl::item<2> item) { auto i = item.get_id(1); auto j = item.get_id(0); @@ -454,7 +454,7 @@ struct assigner<3, space::device> auto e = q.submit([&](sycl::handler& cgh) { using ltype = decltype(k_lhs); using rtype = decltype(k_rhs); - using kname = gt::backend::sycl::Assign3; + using kname = gt::backend::sycl::Assign3; cgh.parallel_for(range, [=](sycl::item<3> item) { auto i = item.get_id(2); auto j = item.get_id(1); @@ -493,16 +493,16 @@ struct assigner 
q.copy(&k_rhs, d_rhs_p, 1).wait(); auto e = q.submit([&](sycl::handler& cgh) { - using kname = gt::backend::sycl::AssignN; - cgh.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) { + using kname = gt::backend::sycl::AssignN; + cgh.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) { auto idx = unravel(i, strides); index_expression(k_lhs, idx) = index_expression(*d_rhs_p, idx); }); }); } else { auto e = q.submit([&](sycl::handler& cgh) { - using kname = gt::backend::sycl::AssignN; - cgh.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) { + using kname = gt::backend::sycl::AssignN; + cgh.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) { auto idx = unravel(i, strides); index_expression(k_lhs, idx) = index_expression(k_rhs, idx); }); diff --git a/include/gtensor/backend_sycl.h b/include/gtensor/backend_sycl.h index f23a8827..8ddedf5e 100644 --- a/include/gtensor/backend_sycl.h +++ b/include/gtensor/backend_sycl.h @@ -24,13 +24,13 @@ namespace sycl { // kernel name templates -template +template class Assign1; -template +template class Assign2; -template +template class Assign3; -template +template class AssignN; template diff --git a/include/gtensor/complex.h b/include/gtensor/complex.h index a20c91ec..65991bba 100644 --- a/include/gtensor/complex.h +++ b/include/gtensor/complex.h @@ -29,7 +29,7 @@ using complex = thrust::complex; // TODO: this will hopefully be standardized soon and be sycl::complex template -using complex = gt::sycl_cplx::complex; +using complex = gt::sycl_cplx::complex; #else // fallback to std::complex, e.g. 
for host backend diff --git a/include/gtensor/sycl_ext_complex.hpp b/include/gtensor/sycl_ext_complex.hpp index 0a9e08be..fb732fd1 100644 --- a/include/gtensor/sycl_ext_complex.hpp +++ b/include/gtensor/sycl_ext_complex.hpp @@ -271,6 +271,7 @@ template complex tanh (const complex&); _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD +namespace cplex::detail { template struct __numeric_type { static void __test(...); static sycl::half __test(sycl::half); @@ -328,8 +329,9 @@ template class __promote_imp<_A1, void, void, true> { template class __promote : public __promote_imp<_A1, _A2, _A3> {}; +} -template class complex; +template class complex; template struct is_gencomplex @@ -345,11 +347,7 @@ struct is_genfloat std::is_same_v<_Tp, sycl::half>> {}; template -complex<_Tp> operator*(const complex<_Tp> &__z, const complex<_Tp> &__w); -template -complex<_Tp> operator/(const complex<_Tp> &__x, const complex<_Tp> &__y); - -template class complex { +class complex<_Tp, typename std::enable_if::value>::type> { public: typedef _Tp value_type; @@ -358,628 +356,344 @@ template class complex { value_type __im_; public: - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex( - const value_type &__re = value_type(), - const value_type &__im = value_type()) - : __re_(__re), __im_(__im) {} - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(const complex<_Xp> &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex( - const std::complex<_Xp> &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY operator std::complex<_Xp>() { - return std::complex<_Xp>(__re_, __im_); - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr value_type real() const { - return __re_; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr value_type imag() const { - return __im_; - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY void real(value_type __re) { __re_ = __re; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY void imag(value_type 
__im) { __im_ = __im; } + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(value_type __re = value_type(), value_type __im = value_type()) : __re_(__re), __im_(__im) { - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(const value_type &__re) { - __re_ = __re; - __im_ = value_type(); - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator+=(const value_type &__re) { - __re_ += __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator-=(const value_type &__re) { - __re_ -= __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator*=(const value_type &__re) { - __re_ *= __re; - __im_ *= __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator/=(const value_type &__re) { - __re_ /= __re; - __im_ /= __re; - return *this; - } - - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(const complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator=(const std::complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator+=(const complex<_Xp> &__c) { - __re_ += __c.real(); - __im_ += __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator-=(const complex<_Xp> &__c) { - __re_ -= __c.real(); - __im_ -= __c.imag(); - return *this; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator*=(const complex<_Xp> &__c) { - *this = *this * complex(__c.real(), __c.imag()); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator/=(const complex<_Xp> &__c) { - *this = *this / complex(__c.real(), __c.imag()); - return *this; - } -}; - -template <> class complex; -template <> class complex; -template <> class complex { - sycl::half __re_; - sycl::half __im_; + template + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(const 
complex<_Xp> &__c) : __re_(__c.real()), __im_(__c.imag()) { -public: - typedef sycl::half value_type; - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex( - sycl::half __re = sycl::half{}, sycl::half __im = sycl::half{}) - : __re_(__re), __im_(__im) {} - _SYCL_EXT_CPLX_INLINE_VISIBILITY - explicit constexpr complex(const complex &__c); - _SYCL_EXT_CPLX_INLINE_VISIBILITY - explicit constexpr complex(const complex &__c); - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr operator std::complex() { - return std::complex(__re_, __im_); } - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr sycl::half real() const { - return __re_; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr sycl::half imag() const { - return __im_; - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY void real(value_type __re) { __re_ = __re; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY void imag(value_type __im) { __im_ = __im; } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(sycl::half __re) { - __re_ = __re; - __im_ = value_type(); - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator+=(sycl::half __re) { - __re_ += __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator-=(sycl::half __re) { - __re_ -= __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator*=(sycl::half __re) { - __re_ *= __re; - __im_ *= __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator/=(sycl::half __re) { - __re_ /= __re; - __im_ /= __re; - return *this; - } + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(const std::complex &__c) : __re_(__c.real()), __im_(__c.imag()) { - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(const complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator=(const std::complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - 
operator+=(const complex<_Xp> &__c) { - __re_ += __c.real(); - __im_ += __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator-=(const complex<_Xp> &__c) { - __re_ -= __c.real(); - __im_ -= __c.imag(); - return *this; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator*=(const complex<_Xp> &__c) { - *this = *this * complex(__c.real(), __c.imag()); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator/=(const complex<_Xp> &__c) { - *this = *this / complex(__c.real(), __c.imag()); - return *this; - } -}; -template <> class complex { - float __re_; - float __im_; - -public: - typedef float value_type; - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(float __re = 0.0f, - float __im = 0.0f) - : __re_(__re), __im_(__im) {} - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr complex(const complex &__c); - _SYCL_EXT_CPLX_INLINE_VISIBILITY - explicit constexpr complex(const complex &__c); - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr complex(const std::complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr operator std::complex() { - return std::complex(__re_, __im_); + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr operator std::complex() const { + return std::complex(__re_, __im_); } - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr float real() const { + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr value_type real() const { return __re_; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr float imag() const { + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr value_type imag() const { return __im_; } _SYCL_EXT_CPLX_INLINE_VISIBILITY void real(value_type __re) { __re_ = __re; } _SYCL_EXT_CPLX_INLINE_VISIBILITY void imag(value_type __im) { __im_ = __im; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(float __re) { + template + _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Xp> &operator=(value_type __re) { __re_ = __re; __im_ = value_type(); return 
*this; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator+=(float __re) { - __re_ += __re; - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator+=(complex &__c, value_type __re) { + __c.__re_ += __re; + return __c; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator-=(float __re) { - __re_ -= __re; - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator-=(complex &__c, value_type __re) { + __c.__re_ -= __re; + return __c; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator*=(float __re) { - __re_ *= __re; - __im_ *= __re; - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator*=(complex &__c, value_type __re) { + __c.__re_ *= __re; + __c.__im_ *= __re; + return __c; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator/=(float __re) { - __re_ /= __re; - __im_ /= __re; - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator/=(complex &__c, value_type __re) { + __c.__re_ /= __re; + __c.__im_ /= __re; + return __c; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(const complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator=(const std::complex<_Xp> &__c) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Xp> &operator=(const complex<_Xp> &__c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator+=(const complex<_Xp> &__c) { - __re_ += __c.real(); - __im_ += __c.imag(); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator+=(complex &__x, const complex<_Xp> &__y) { + __x.__re_ += __y.real(); + __x.__im_ += __y.imag(); + return __x; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator-=(const complex<_Xp> &__c) { - __re_ -= __c.real(); - __im_ -= __c.imag(); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator-=(complex &__x, const 
complex<_Xp> &__y) { + __x.__re_ -= __y.real(); + __x.__im_ -= __y.imag(); + return __x; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator*=(const complex<_Xp> &__c) { - *this = *this * complex(__c.real(), __c.imag()); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator*=(complex &__x, const complex<_Xp> &__y) { + __x = __x * complex(__y.real(), __y.imag()); + return __x; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator/=(const complex<_Xp> &__c) { - *this = *this / complex(__c.real(), __c.imag()); - return *this; - } -}; - -template <> class complex { - double __re_; - double __im_; - -public: - typedef double value_type; - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(double __re = 0.0, - double __im = 0.0) - : __re_(__re), __im_(__im) {} - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr complex(const complex &__c); - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr complex(const complex &__c); - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr complex(const std::complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - _SYCL_EXT_CPLX_INLINE_VISIBILITY - constexpr operator std::complex() { - return std::complex(__re_, __im_); - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr double real() const { - return __re_; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr double imag() const { - return __im_; - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY void real(value_type __re) { __re_ = __re; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY void imag(value_type __im) { __im_ = __im; } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(double __re) { - __re_ = __re; - __im_ = value_type(); - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator+=(double __re) { - __re_ += __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator-=(double __re) { - __re_ -= __re; - return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator*=(double __re) { - __re_ *= __re; - __im_ *= __re; - 
return *this; - } - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator/=(double __re) { - __re_ /= __re; - __im_ /= __re; - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + &operator/=(complex &__x, const complex<_Xp> &__y) { + __x = __x / complex(__y.real(), __y.imag()); + return __x; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(const complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator=(const std::complex<_Xp> &__c) { - __re_ = __c.real(); - __im_ = __c.imag(); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator+(const complex &__x, const complex &__y) { + complex __t(__x); + __t += __y; + return __t; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator+=(const complex<_Xp> &__c) { - __re_ += __c.real(); - __im_ += __c.imag(); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator+(const complex &__x, value_type __y) { + complex __t(__x); + __t += __y; + return __t; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator-=(const complex<_Xp> &__c) { - __re_ -= __c.real(); - __im_ -= __c.imag(); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator+(value_type __x, const complex &__y) { + complex __t(__y); + __t += __x; + return __t; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator*=(const complex<_Xp> &__c) { - *this = *this * complex(__c.real(), __c.imag()); - return *this; - } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex & - operator/=(const complex<_Xp> &__c) { - *this = *this / complex(__c.real(), __c.imag()); - return *this; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator+(const complex &__x) { + return __x; } -}; - -inline constexpr complex::complex(const complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - -inline constexpr complex::complex(const complex &__c) - : 
__re_(__c.real()), __im_(__c.imag()) {} - -inline constexpr complex::complex(const complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - -inline constexpr complex::complex(const complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - -inline constexpr complex::complex(const complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - -inline constexpr complex::complex(const complex &__c) - : __re_(__c.real()), __im_(__c.imag()) {} - -// 26.3.6 operators: - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator+(const complex<_Tp> &__x, const complex<_Tp> &__y) { - complex<_Tp> __t(__x); - __t += __y; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> operator+(const complex<_Tp> &__x, - const _Tp &__y) { - complex<_Tp> __t(__x); - __t += __y; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator+(const _Tp &__x, const complex<_Tp> &__y) { - complex<_Tp> __t(__y); - __t += __x; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator-(const complex<_Tp> &__x, const complex<_Tp> &__y) { - complex<_Tp> __t(__x); - __t -= __y; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> operator-(const complex<_Tp> &__x, - const _Tp &__y) { - complex<_Tp> __t(__x); - __t -= __y; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator-(const _Tp &__x, const complex<_Tp> &__y) { - complex<_Tp> __t(-__y); - __t += __x; - return __t; -} -template -complex<_Tp> operator*(const complex<_Tp> &__z, const complex<_Tp> &__w) { - _Tp __a = __z.real(); - _Tp __b = __z.imag(); - _Tp __c = __w.real(); - _Tp __d = __w.imag(); - _Tp __ac = __a * __c; - _Tp __bd = __b * __d; - _Tp __ad = __a * __d; - _Tp __bc = __b * __c; - _Tp __x = __ac - __bd; - _Tp __y = __ad + __bc; - if (sycl::isnan(__x) && sycl::isnan(__y)) { - bool __recalc = false; - if (sycl::isinf(__a) || sycl::isinf(__b)) { - __a = sycl::copysign(sycl::isinf(__a) ? 
_Tp(1) : _Tp(0), __a); - __b = sycl::copysign(sycl::isinf(__b) ? _Tp(1) : _Tp(0), __b); - if (sycl::isnan(__c)) - __c = sycl::copysign(_Tp(0), __c); - if (sycl::isnan(__d)) - __d = sycl::copysign(_Tp(0), __d); - __recalc = true; + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator-(const complex &__x, const complex &__y) { + complex __t(__x); + __t -= __y; + return __t; + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator-(const complex &__x, value_type __y) { + complex __t(__x); + __t -= __y; + return __t; + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator-(value_type __x, const complex &__y) { + complex __t(-__y); + __t += __x; + return __t; + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator-(const complex &__x) { + return complex(-__x.__re_, -__x.__im_); + } + + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator*(const complex &__z, const complex &__w) { + value_type __a = __z.__re_; + value_type __b = __z.__im_; + value_type __c = __w.__re_; + value_type __d = __w.__im_; + value_type __ac = __a * __c; + value_type __bd = __b * __d; + value_type __ad = __a * __d; + value_type __bc = __b * __c; + value_type __x = __ac - __bd; + value_type __y = __ad + __bc; + if (sycl::isnan(__x) && sycl::isnan(__y)) { + bool __recalc = false; + if (sycl::isinf(__a) || sycl::isinf(__b)) { + __a = sycl::copysign(sycl::isinf(__a) ? value_type(1) : value_type(0), __a); + __b = sycl::copysign(sycl::isinf(__b) ? value_type(1) : value_type(0), __b); + if (sycl::isnan(__c)) + __c = sycl::copysign(value_type(0), __c); + if (sycl::isnan(__d)) + __d = sycl::copysign(value_type(0), __d); + __recalc = true; + } + if (sycl::isinf(__c) || sycl::isinf(__d)) { + __c = sycl::copysign(sycl::isinf(__c) ? value_type(1) : value_type(0), __c); + __d = sycl::copysign(sycl::isinf(__d) ? 
value_type(1) : value_type(0), __d); + if (sycl::isnan(__a)) + __a = sycl::copysign(value_type(0), __a); + if (sycl::isnan(__b)) + __b = sycl::copysign(value_type(0), __b); + __recalc = true; + } + if (!__recalc && (sycl::isinf(__ac) || sycl::isinf(__bd) || + sycl::isinf(__ad) || sycl::isinf(__bc))) { + if (sycl::isnan(__a)) + __a = sycl::copysign(value_type(0), __a); + if (sycl::isnan(__b)) + __b = sycl::copysign(value_type(0), __b); + if (sycl::isnan(__c)) + __c = sycl::copysign(value_type(0), __c); + if (sycl::isnan(__d)) + __d = sycl::copysign(value_type(0), __d); + __recalc = true; + } + if (__recalc) { + __x = value_type(INFINITY) * (__a * __c - __b * __d); + __y = value_type(INFINITY) * (__a * __d + __b * __c); + } } - if (sycl::isinf(__c) || sycl::isinf(__d)) { - __c = sycl::copysign(sycl::isinf(__c) ? _Tp(1) : _Tp(0), __c); - __d = sycl::copysign(sycl::isinf(__d) ? _Tp(1) : _Tp(0), __d); - if (sycl::isnan(__a)) - __a = sycl::copysign(_Tp(0), __a); - if (sycl::isnan(__b)) - __b = sycl::copysign(_Tp(0), __b); - __recalc = true; + return complex(__x, __y); + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator*(const complex &__x, value_type __y) { + complex __t(__x); + __t *= __y; + return __t; + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator*(value_type __x, const complex &__y) { + complex __t(__y); + __t *= __x; + return __t; + } + + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator/(const complex &__z, const complex &__w) { + int __ilogbw = 0; + value_type __a = __z.__re_; + value_type __b = __z.__im_; + value_type __c = __w.__re_; + value_type __d = __w.__im_; + value_type __logbw = sycl::logb(sycl::fmax(sycl::fabs(__c), sycl::fabs(__d))); + if (sycl::isfinite(__logbw)) { + __ilogbw = static_cast(__logbw); + __c = sycl::ldexp(__c, -__ilogbw); + __d = sycl::ldexp(__d, -__ilogbw); } - if (!__recalc && (sycl::isinf(__ac) || sycl::isinf(__bd) || - sycl::isinf(__ad) || sycl::isinf(__bc))) { - if (sycl::isnan(__a)) - __a = 
sycl::copysign(_Tp(0), __a); - if (sycl::isnan(__b)) - __b = sycl::copysign(_Tp(0), __b); - if (sycl::isnan(__c)) - __c = sycl::copysign(_Tp(0), __c); - if (sycl::isnan(__d)) - __d = sycl::copysign(_Tp(0), __d); - __recalc = true; - } - if (__recalc) { - __x = _Tp(INFINITY) * (__a * __c - __b * __d); - __y = _Tp(INFINITY) * (__a * __d + __b * __c); + value_type __denom = __c * __c + __d * __d; + value_type __x = sycl::ldexp((__a * __c + __b * __d) / __denom, -__ilogbw); + value_type __y = sycl::ldexp((__b * __c - __a * __d) / __denom, -__ilogbw); + if (sycl::isnan(__x) && sycl::isnan(__y)) { + if ((__denom == value_type(0)) && (!sycl::isnan(__a) || !sycl::isnan(__b))) { + __x = sycl::copysign(value_type(INFINITY), __c) * __a; + __y = sycl::copysign(value_type(INFINITY), __c) * __b; + } else if ((sycl::isinf(__a) || sycl::isinf(__b)) && sycl::isfinite(__c) && + sycl::isfinite(__d)) { + __a = sycl::copysign(sycl::isinf(__a) ? value_type(1) : value_type(0), __a); + __b = sycl::copysign(sycl::isinf(__b) ? value_type(1) : value_type(0), __b); + __x = value_type(INFINITY) * (__a * __c + __b * __d); + __y = value_type(INFINITY) * (__b * __c - __a * __d); + } else if (sycl::isinf(__logbw) && __logbw > value_type(0) && + sycl::isfinite(__a) && sycl::isfinite(__b)) { + __c = sycl::copysign(sycl::isinf(__c) ? value_type(1) : value_type(0), __c); + __d = sycl::copysign(sycl::isinf(__d) ? 
value_type(1) : value_type(0), __d); + __x = value_type(0) * (__a * __c + __b * __d); + __y = value_type(0) * (__b * __c - __a * __d); + } } + return complex(__x, __y); + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator/(const complex &__x, value_type __y) { + return complex(__x.__re_ / __y, __x.__im_ / __y); + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex + operator/(value_type __x, const complex &__y) { + complex __t(__x); + __t /= __y; + return __t; + } + + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend constexpr bool + operator==(const complex &__x, const complex &__y) { + return __x.__re_ == __y.__re_ && __x.__im_ == __y.__im_; + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend constexpr bool + operator==(const complex &__x, value_type __y) { + return __x.__re_ == __y && __x.__im_ == 0; + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend constexpr bool + operator==(value_type __x, const complex &__y) { + return __x == __y.__re_ && 0 == __y.__im_; + } + + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend constexpr bool + operator!=(const complex &__x, const complex &__y) { + return !(__x == __y); + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend constexpr bool + operator!=(const complex &__x, value_type __y) { + return !(__x == __y); + } + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend constexpr bool + operator!=(value_type __x, const complex &__y) { + return !(__x == __y); + } + + template + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend std::basic_istream<_CharT, _Traits> + &operator>>(std::basic_istream<_CharT, _Traits> &__is, complex &__x) { + if (__is.good()) { + ws(__is); + if (__is.peek() == _CharT('(')) { + __is.get(); + value_type __r; + __is >> __r; + if (!__is.fail()) { + ws(__is); + _CharT __c = __is.peek(); + if (__c == _CharT(',')) { + __is.get(); + value_type __i; + __is >> __i; + if (!__is.fail()) { + ws(__is); + __c = __is.peek(); + if (__c == _CharT(')')) { + __is.get(); + __x = complex(__r, __i); + } else + __is.setstate(__is.failbit); + } else + 
__is.setstate(__is.failbit); + } else if (__c == _CharT(')')) { + __is.get(); + __x = complex(__r, value_type(0)); + } else + __is.setstate(__is.failbit); + } else + __is.setstate(__is.failbit); + } else { + value_type __r; + __is >> __r; + if (!__is.fail()) + __x = complex(__r, value_type(0)); + else + __is.setstate(__is.failbit); + } + } else + __is.setstate(__is.failbit); + return __is; } - return complex<_Tp>(__x, __y); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> operator*(const complex<_Tp> &__x, - const _Tp &__y) { - complex<_Tp> __t(__x); - __t *= __y; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator*(const _Tp &__x, const complex<_Tp> &__y) { - complex<_Tp> __t(__y); - __t *= __x; - return __t; -} -template -complex<_Tp> operator/(const complex<_Tp> &__z, const complex<_Tp> &__w) { - int __ilogbw = 0; - _Tp __a = __z.real(); - _Tp __b = __z.imag(); - _Tp __c = __w.real(); - _Tp __d = __w.imag(); - _Tp __logbw = sycl::logb(sycl::fmax(sycl::fabs(__c), sycl::fabs(__d))); - if (sycl::isfinite(__logbw)) { - __ilogbw = static_cast(__logbw); - __c = sycl::ldexp(__c, -__ilogbw); - __d = sycl::ldexp(__d, -__ilogbw); + template + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend std::basic_ostream<_CharT, _Traits> + &operator<<(std::basic_ostream<_CharT, _Traits> &__os, const complex &__x) { + std::basic_ostringstream<_CharT, _Traits> __s; + __s.flags(__os.flags()); + __s.imbue(__os.getloc()); + __s.precision(__os.precision()); + __s << '(' << __x.__re_ << ',' << __x.__im_ << ')'; + return __os << __s.str(); } - _Tp __denom = __c * __c + __d * __d; - _Tp __x = sycl::ldexp((__a * __c + __b * __d) / __denom, -__ilogbw); - _Tp __y = sycl::ldexp((__b * __c - __a * __d) / __denom, -__ilogbw); - if (sycl::isnan(__x) && sycl::isnan(__y)) { - if ((__denom == _Tp(0)) && (!sycl::isnan(__a) || !sycl::isnan(__b))) { - __x = sycl::copysign(_Tp(INFINITY), __c) * __a; - __y = sycl::copysign(_Tp(INFINITY), __c) * __b; - } else if 
((sycl::isinf(__a) || sycl::isinf(__b)) && sycl::isfinite(__c) && - sycl::isfinite(__d)) { - __a = sycl::copysign(sycl::isinf(__a) ? _Tp(1) : _Tp(0), __a); - __b = sycl::copysign(sycl::isinf(__b) ? _Tp(1) : _Tp(0), __b); - __x = _Tp(INFINITY) * (__a * __c + __b * __d); - __y = _Tp(INFINITY) * (__b * __c - __a * __d); - } else if (sycl::isinf(__logbw) && __logbw > _Tp(0) && - sycl::isfinite(__a) && sycl::isfinite(__b)) { - __c = sycl::copysign(sycl::isinf(__c) ? _Tp(1) : _Tp(0), __c); - __d = sycl::copysign(sycl::isinf(__d) ? _Tp(1) : _Tp(0), __d); - __x = _Tp(0) * (__a * __c + __b * __d); - __y = _Tp(0) * (__b * __c - __a * __d); - } - } - return complex<_Tp>(__x, __y); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> operator/(const complex<_Tp> &__x, - const _Tp &__y) { - return complex<_Tp>(__x.real() / __y, __x.imag() / __y); -} -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator/(const _Tp &__x, const complex<_Tp> &__y) { - complex<_Tp> __t(__x); - __t /= __y; - return __t; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator+(const complex<_Tp> &__x) { - return __x; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> -operator-(const complex<_Tp> &__x) { - return complex<_Tp>(-__x.real(), -__x.imag()); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool -operator==(const complex<_Tp> &__x, const complex<_Tp> &__y) { - return __x.real() == __y.real() && __x.imag() == __y.imag(); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool -operator==(const complex<_Tp> &__x, const _Tp &__y) { - return __x.real() == __y && __x.imag() == 0; -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool -operator==(const _Tp &__x, const complex<_Tp> &__y) { - return __x == __y.real() && 0 == __y.imag(); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool -operator!=(const complex<_Tp> &__x, const complex<_Tp> &__y) { - return !(__x == __y); -} - -template 
-_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool -operator!=(const complex<_Tp> &__x, const _Tp &__y) { - return !(__x == __y); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool -operator!=(const _Tp &__x, const complex<_Tp> &__y) { - return !(__x == __y); -} - -// 26.3.7 values: + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend const sycl::stream + &operator<<(const sycl::stream &__ss, const complex &_x) { + return __ss << "(" << _x.__re_ << "," << _x.__im_ << ")"; + } +}; +namespace cplex::detail { template ::value, - bool = std::is_floating_point<_Tp>::value> + bool = is_genfloat<_Tp>::value> struct __libcpp_complex_overload_traits {}; // Integral Types @@ -993,68 +707,60 @@ template struct __libcpp_complex_overload_traits<_Tp, false, true> { typedef _Tp _ValueType; typedef complex<_Tp> _ComplexType; }; +} // real -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr _Tp real(const complex<_Tp> &__c) { return __c.real(); } template _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr - typename __libcpp_complex_overload_traits<_Tp>::_ValueType + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType real(_Tp __re) { return __re; } // imag -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr _Tp imag(const complex<_Tp> &__c) { return __c.imag(); } template _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr - typename __libcpp_complex_overload_traits<_Tp>::_ValueType + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType imag(_Tp) { return 0; } // abs -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY _Tp abs(const complex<_Tp> &__c) { return sycl::hypot(__c.real(), __c.imag()); } // arg -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY _Tp arg(const complex<_Tp> &__c) { return sycl::atan2(__c.imag(), __c.real()); } template _SYCL_EXT_CPLX_INLINE_VISIBILITY - typename std::enable_if::value || - std::is_same<_Tp, double>::value, - double>::type - arg(_Tp __re) { - return 
sycl::atan2(0., __re); -} - -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY - typename std::enable_if::value, float>::type - arg(_Tp __re) { - return sycl::atan2(0.F, __re); +typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType arg(_Tp __re) { + typedef typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; + return sycl::atan2<_ValueType>(0, __re); } // norm -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY _Tp norm(const complex<_Tp> &__c) { if (sycl::isinf(__c.real())) return sycl::fabs(__c.real()); @@ -1065,31 +771,31 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY _Tp norm(const complex<_Tp> &__c) { template _SYCL_EXT_CPLX_INLINE_VISIBILITY - typename __libcpp_complex_overload_traits<_Tp>::_ValueType + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType norm(_Tp __re) { - typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; + typedef typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; return static_cast<_ValueType>(__re) * __re; } // conj -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> conj(const complex<_Tp> &__c) { return complex<_Tp>(__c.real(), -__c.imag()); } template _SYCL_EXT_CPLX_INLINE_VISIBILITY - typename __libcpp_complex_overload_traits<_Tp>::_ComplexType + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType conj(_Tp __re) { typedef - typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; return _ComplexType(__re); } // proj -template +template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> proj(const complex<_Tp> &__c) { complex<_Tp> __r = __c; if (sycl::isinf(__c.real()) || sycl::isinf(__c.imag())) @@ -1098,29 +804,23 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> proj(const complex<_Tp> &__c) { } template -_SYCL_EXT_CPLX_INLINE_VISIBILITY 
typename std::enable_if< - std::is_floating_point<_Tp>::value, - typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type -proj(_Tp __re) { - if (sycl::isinf(__re)) - __re = sycl::fabs(__re); - return complex<_Tp>(__re); -} +_SYCL_EXT_CPLX_INLINE_VISIBILITY + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType + proj(_Tp __re) { + typedef typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + + if constexpr(!std::is_integral_v<_Tp>) { + if (sycl::isinf(__re)) + __re = sycl::fabs(__re); + } -template -_SYCL_EXT_CPLX_INLINE_VISIBILITY typename std::enable_if< - std::is_integral<_Tp>::value, - typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type -proj(_Tp __re) { - typedef - typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; return _ComplexType(__re); } // polar template ::value>> -complex<_Tp> polar(const _Tp &__rho, const _Tp &__theta = _Tp()) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> polar(const _Tp &__rho, const _Tp &__theta = _Tp()) { if (sycl::isnan(__rho) || sycl::signbit(__rho)) return complex<_Tp>(_Tp(NAN), _Tp(NAN)); if (sycl::isnan(__theta)) { @@ -1159,7 +859,7 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> log10(const complex<_Tp> &__x) { // sqrt template ::value>> -complex<_Tp> sqrt(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> sqrt(const complex<_Tp> &__x) { if (sycl::isinf(__x.imag())) return complex<_Tp>(_Tp(INFINITY), __x.imag()); if (sycl::isinf(__x.real())) { @@ -1176,7 +876,7 @@ complex<_Tp> sqrt(const complex<_Tp> &__x) { // exp template ::value>> -complex<_Tp> exp(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> exp(const complex<_Tp> &__x) { _Tp __i = __x.imag(); if (__i == 0) { return complex<_Tp>(sycl::exp(__x.real()), @@ -1206,9 +906,9 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> pow(const complex<_Tp> &__x, template ::value>> -_SYCL_EXT_CPLX_INLINE_VISIBILITY 
complex::type> +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex::type> pow(const complex<_Tp> &__x, const complex<_Up> &__y) { - typedef complex::type> result_type; + typedef complex::type> result_type; return pow(result_type(__x), result_type(__y)); } @@ -1216,9 +916,9 @@ template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY typename std::enable_if::value, - complex::type>>::type + complex::type>>::type pow(const complex<_Tp> &__x, const _Up &__y) { - typedef complex::type> result_type; + typedef complex::type> result_type; return pow(result_type(__x), result_type(__y)); } @@ -1226,12 +926,13 @@ template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY typename std::enable_if::value, - complex::type>>::type + complex::type>>::type pow(const _Tp &__x, const complex<_Up> &__y) { - typedef complex::type> result_type; + typedef complex::type> result_type; return pow(result_type(__x), result_type(__y)); } +namespace cplex::detail { // __sqr, computes pow(x, 2) template ::value>> @@ -1239,11 +940,12 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> __sqr(const complex<_Tp> &__x) { return complex<_Tp>((__x.real() - __x.imag()) * (__x.real() + __x.imag()), _Tp(2) * __x.real() * __x.imag()); } +} // asinh template ::value>> -complex<_Tp> asinh(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> asinh(const complex<_Tp> &__x) { const _Tp __pi(sycl::atan2(+0., -0.)); if (sycl::isinf(__x.real())) { if (sycl::isnan(__x.imag())) @@ -1263,7 +965,7 @@ complex<_Tp> asinh(const complex<_Tp> &__x) { if (sycl::isinf(__x.imag())) return complex<_Tp>(sycl::copysign(__x.imag(), __x.real()), sycl::copysign(__pi / _Tp(2), __x.imag())); - complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1))); + complex<_Tp> __z = log(__x + sqrt(cplex::detail::__sqr(__x) + _Tp(1))); return complex<_Tp>(sycl::copysign(__z.real(), __x.real()), sycl::copysign(__z.imag(), __x.imag())); } @@ -1271,7 +973,7 @@ complex<_Tp> asinh(const complex<_Tp> &__x) { // acosh template ::value>> -complex<_Tp> 
acosh(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acosh(const complex<_Tp> &__x) { const _Tp __pi(sycl::atan2(+0., -0.)); if (sycl::isinf(__x.real())) { if (sycl::isnan(__x.imag())) @@ -1296,7 +998,7 @@ complex<_Tp> acosh(const complex<_Tp> &__x) { if (sycl::isinf(__x.imag())) return complex<_Tp>(sycl::fabs(__x.imag()), sycl::copysign(__pi / _Tp(2), __x.imag())); - complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); + complex<_Tp> __z = log(__x + sqrt(cplex::detail::__sqr(__x) - _Tp(1))); return complex<_Tp>(sycl::copysign(__z.real(), _Tp(0)), sycl::copysign(__z.imag(), __x.imag())); } @@ -1304,7 +1006,7 @@ complex<_Tp> acosh(const complex<_Tp> &__x) { // atanh template ::value>> -complex<_Tp> atanh(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> atanh(const complex<_Tp> &__x) { const _Tp __pi(sycl::atan2(+0., -0.)); if (sycl::isinf(__x.imag())) { return complex<_Tp>(sycl::copysign(_Tp(0), __x.real()), @@ -1334,7 +1036,7 @@ complex<_Tp> atanh(const complex<_Tp> &__x) { // sinh template ::value>> -complex<_Tp> sinh(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> sinh(const complex<_Tp> &__x) { if (sycl::isinf(__x.real()) && !sycl::isfinite(__x.imag())) return complex<_Tp>(__x.real(), _Tp(NAN)); if (__x.real() == 0 && !sycl::isfinite(__x.imag())) @@ -1348,7 +1050,7 @@ complex<_Tp> sinh(const complex<_Tp> &__x) { // cosh template ::value>> -complex<_Tp> cosh(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> cosh(const complex<_Tp> &__x) { if (sycl::isinf(__x.real()) && !sycl::isfinite(__x.imag())) return complex<_Tp>(sycl::fabs(__x.real()), _Tp(NAN)); if (__x.real() == 0 && !sycl::isfinite(__x.imag())) @@ -1364,7 +1066,7 @@ complex<_Tp> cosh(const complex<_Tp> &__x) { // tanh template ::value>> -complex<_Tp> tanh(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> tanh(const complex<_Tp> &__x) { if (sycl::isinf(__x.real())) { if 
(!sycl::isfinite(__x.imag())) return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()), _Tp(0)); @@ -1386,7 +1088,7 @@ complex<_Tp> tanh(const complex<_Tp> &__x) { // asin template ::value>> -complex<_Tp> asin(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> asin(const complex<_Tp> &__x) { complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real())); return complex<_Tp>(__z.imag(), -__z.real()); } @@ -1394,7 +1096,7 @@ complex<_Tp> asin(const complex<_Tp> &__x) { // acos template ::value>> -complex<_Tp> acos(const complex<_Tp> &__x) { +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acos(const complex<_Tp> &__x) { const _Tp __pi(sycl::atan2(+0., -0.)); if (sycl::isinf(__x.real())) { if (sycl::isnan(__x.imag())) @@ -1419,7 +1121,7 @@ complex<_Tp> acos(const complex<_Tp> &__x) { return complex<_Tp>(__pi / _Tp(2), -__x.imag()); if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); - complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); + complex<_Tp> __z = log(__x + sqrt(cplex::detail::__sqr(__x) - _Tp(1))); if (sycl::signbit(__x.imag())) return complex<_Tp>(sycl::fabs(__z.imag()), sycl::fabs(__z.real())); return complex<_Tp>(sycl::fabs(__z.imag()), -sycl::fabs(__z.real())); @@ -1456,69 +1158,6 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> tan(const complex<_Tp> &__x) { return complex<_Tp>(__z.imag(), -__z.real()); } -template -std::basic_istream<_CharT, _Traits> & -operator>>(std::basic_istream<_CharT, _Traits> &__is, complex<_Tp> &__x) { - if (__is.good()) { - ws(__is); - if (__is.peek() == _CharT('(')) { - __is.get(); - _Tp __r; - __is >> __r; - if (!__is.fail()) { - ws(__is); - _CharT __c = __is.peek(); - if (__c == _CharT(',')) { - __is.get(); - _Tp __i; - __is >> __i; - if (!__is.fail()) { - ws(__is); - __c = __is.peek(); - if (__c == _CharT(')')) { - __is.get(); - __x = complex<_Tp>(__r, __i); - } else - __is.setstate(__is.failbit); - } else - __is.setstate(__is.failbit); - 
} else if (__c == _CharT(')')) { - __is.get(); - __x = complex<_Tp>(__r, _Tp(0)); - } else - __is.setstate(__is.failbit); - } else - __is.setstate(__is.failbit); - } else { - _Tp __r; - __is >> __r; - if (!__is.fail()) - __x = complex<_Tp>(__r, _Tp(0)); - else - __is.setstate(__is.failbit); - } - } else - __is.setstate(__is.failbit); - return __is; -} - -template -std::basic_ostream<_CharT, _Traits> & -operator<<(std::basic_ostream<_CharT, _Traits> &__os, const complex<_Tp> &__x) { - std::basic_ostringstream<_CharT, _Traits> __s; - __s.flags(__os.flags()); - __s.imbue(__os.getloc()); - __s.precision(__os.precision()); - __s << '(' << __x.real() << ',' << __x.imag() << ')'; - return __os << __s.str(); -} - -template ::value>> -_SYCL_EXT_CPLX_INLINE_VISIBILITY const sycl::stream & -operator<<(const sycl::stream &__ss, const complex<_Tp> &_x) { - return __ss << "(" << _x.real() << "," << _x.imag() << ")"; -} - _SYCL_EXT_CPLX_END_NAMESPACE_STD #undef _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD From 00f934099874424d041916488c9aaa58194ab713 Mon Sep 17 00:00:00 2001 From: Bryce Allen Date: Mon, 19 Jun 2023 10:02:59 -0400 Subject: [PATCH 2/2] sycl: new sycl cplx implementation --- include/gtensor/sycl_ext_complex.hpp | 1345 +++++++++++++++++++++++--- 1 file changed, 1194 insertions(+), 151 deletions(-) diff --git a/include/gtensor/sycl_ext_complex.hpp b/include/gtensor/sycl_ext_complex.hpp index fb732fd1..91ffde32 100644 --- a/include/gtensor/sycl_ext_complex.hpp +++ b/include/gtensor/sycl_ext_complex.hpp @@ -254,6 +254,22 @@ template complex tanh (const complex&); #define _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD namespace _SYCL_CPLX_NAMESPACE { #define _SYCL_EXT_CPLX_END_NAMESPACE_STD } + +#ifndef _SYCL_MARRAY_NAMESPACE +#ifdef __HIPSYCL__ +#define _SYCL_MARRAY_NAMESPACE hipsycl::sycl +#else +#define _SYCL_MARRAY_NAMESPACE sycl +#endif +#endif + +#define _SYCL_MARRAY_BEGIN_NAMESPACE namespace _SYCL_MARRAY_NAMESPACE { +#define _SYCL_MARRAY_END_NAMESPACE } + +#if defined(__FAST_MATH__) ||
defined(_M_FP_FAST) +#define _SYCL_EXT_CPLX_FAST_MATH +#endif + #define _SYCL_EXT_CPLX_INLINE_VISIBILITY \ [[gnu::always_inline]] [[clang::always_inline]] inline @@ -329,9 +345,54 @@ template class __promote_imp<_A1, void, void, true> { template class __promote : public __promote_imp<_A1, _A2, _A3> {}; + +// Define our own fast-math aware wrappers for these routines, because +// some compilers are not able to perform the appropriate optimization +// without this extra help. +template +_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool isnan(const T a) { +#ifdef _SYCL_EXT_CPLX_FAST_MATH + return false; +#else + return sycl::isnan(a); +#endif } -template class complex; +template +_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool isfinite(const T a) { +#ifdef _SYCL_EXT_CPLX_FAST_MATH + return true; +#else + return sycl::isfinite(a); +#endif +} + +template +_SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr bool isinf(const T a) { +#ifdef _SYCL_EXT_CPLX_FAST_MATH + return false; +#else + return sycl::isinf(a); +#endif +} + +// To ensure loop unrolling is done when processing dimensions. 
+template +void loop_impl(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} + +template void loop(F &&f) { + loop_impl(std::make_index_sequence{}, std::forward(f)); +} + +} // namespace cplex::detail + +//////////////////////////////////////////////////////////////////////////////// +// COMPLEX IMPLEMENTATION +//////////////////////////////////////////////////////////////////////////////// + +template class complex; template struct is_gencomplex @@ -339,12 +400,16 @@ struct is_gencomplex std::is_same_v<_Tp, complex> || std::is_same_v<_Tp, complex> || std::is_same_v<_Tp, complex>> {}; +template +inline constexpr bool is_gencomplex_v = is_gencomplex<_Tp>::value; template struct is_genfloat : std::integral_constant || std::is_same_v<_Tp, float> || std::is_same_v<_Tp, sycl::half>> {}; +template +inline constexpr bool is_genfloat_v = is_genfloat<_Tp>::value; template class complex<_Tp, typename std::enable_if::value>::type> { @@ -356,21 +421,24 @@ class complex<_Tp, typename std::enable_if::value>::type> { value_type __im_; public: - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(value_type __re = value_type(), value_type __im = value_type()) : __re_(__re), __im_(__im) { - - } + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex( + value_type __re = value_type(), value_type __im = value_type()) + : __re_(__re), __im_(__im) {} template - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(const complex<_Xp> &__c) : __re_(__c.real()), __im_(__c.imag()) { - - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(const std::complex &__c) : __re_(__c.real()), __im_(__c.imag()) { - - } - - _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr operator std::complex() const { - return std::complex(__re_, __im_); + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex(const complex<_Xp> &__c) + : __re_(__c.real()), __im_(__c.imag()) {} + + template ::value>> + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr complex( + const std::complex<_Xp> &__c) + : 
__re_(static_cast(__c.real())), + __im_(static_cast(__c.imag())) {} + + template ::value>> + _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr + operator std::complex<_Xp>() const { + return std::complex<_Xp>(static_cast<_Xp>(__re_), static_cast<_Xp>(__im_)); } _SYCL_EXT_CPLX_INLINE_VISIBILITY constexpr value_type real() const { @@ -383,64 +451,63 @@ class complex<_Tp, typename std::enable_if::value>::type> { _SYCL_EXT_CPLX_INLINE_VISIBILITY void real(value_type __re) { __re_ = __re; } _SYCL_EXT_CPLX_INLINE_VISIBILITY void imag(value_type __im) { __im_ = __im; } - template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Xp> &operator=(value_type __re) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(value_type __re) { __re_ = __re; __im_ = value_type(); return *this; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator+=(complex &__c, value_type __re) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator+=(complex &__c, value_type __re) { __c.__re_ += __re; return __c; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator-=(complex &__c, value_type __re) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator-=(complex &__c, value_type __re) { __c.__re_ -= __re; return __c; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator*=(complex &__c, value_type __re) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator*=(complex &__c, value_type __re) { __c.__re_ *= __re; __c.__im_ *= __re; return __c; } - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator/=(complex &__c, value_type __re) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator/=(complex &__c, value_type __re) { __c.__re_ /= __re; __c.__im_ /= __re; return __c; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Xp> &operator=(const complex<_Xp> &__c) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY complex &operator=(const complex<_Xp> &__c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } template - 
_SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator+=(complex &__x, const complex<_Xp> &__y) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator+=(complex &__x, const complex<_Xp> &__y) { __x.__re_ += __y.real(); __x.__im_ += __y.imag(); return __x; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator-=(complex &__x, const complex<_Xp> &__y) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator-=(complex &__x, const complex<_Xp> &__y) { __x.__re_ -= __y.real(); __x.__im_ -= __y.imag(); return __x; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator*=(complex &__x, const complex<_Xp> &__y) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator*=(complex &__x, const complex<_Xp> &__y) { __x = __x * complex(__y.real(), __y.imag()); return __x; } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex - &operator/=(complex &__x, const complex<_Xp> &__y) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex & + operator/=(complex &__x, const complex<_Xp> &__y) { __x = __x / complex(__y.real(), __y.imag()); return __x; } @@ -503,35 +570,40 @@ class complex<_Tp, typename std::enable_if::value>::type> { value_type __bc = __b * __c; value_type __x = __ac - __bd; value_type __y = __ad + __bc; - if (sycl::isnan(__x) && sycl::isnan(__y)) { + if (cplex::detail::isnan(__x) && cplex::detail::isnan(__y)) { bool __recalc = false; - if (sycl::isinf(__a) || sycl::isinf(__b)) { - __a = sycl::copysign(sycl::isinf(__a) ? value_type(1) : value_type(0), __a); - __b = sycl::copysign(sycl::isinf(__b) ? value_type(1) : value_type(0), __b); - if (sycl::isnan(__c)) + if (cplex::detail::isinf(__a) || cplex::detail::isinf(__b)) { + __a = sycl::copysign( + cplex::detail::isinf(__a) ? value_type(1) : value_type(0), __a); + __b = sycl::copysign( + cplex::detail::isinf(__b) ? 
value_type(1) : value_type(0), __b); + if (cplex::detail::isnan(__c)) __c = sycl::copysign(value_type(0), __c); - if (sycl::isnan(__d)) + if (cplex::detail::isnan(__d)) __d = sycl::copysign(value_type(0), __d); __recalc = true; } - if (sycl::isinf(__c) || sycl::isinf(__d)) { - __c = sycl::copysign(sycl::isinf(__c) ? value_type(1) : value_type(0), __c); - __d = sycl::copysign(sycl::isinf(__d) ? value_type(1) : value_type(0), __d); - if (sycl::isnan(__a)) + if (cplex::detail::isinf(__c) || cplex::detail::isinf(__d)) { + __c = sycl::copysign( + cplex::detail::isinf(__c) ? value_type(1) : value_type(0), __c); + __d = sycl::copysign( + cplex::detail::isinf(__d) ? value_type(1) : value_type(0), __d); + if (cplex::detail::isnan(__a)) __a = sycl::copysign(value_type(0), __a); - if (sycl::isnan(__b)) + if (cplex::detail::isnan(__b)) __b = sycl::copysign(value_type(0), __b); __recalc = true; } - if (!__recalc && (sycl::isinf(__ac) || sycl::isinf(__bd) || - sycl::isinf(__ad) || sycl::isinf(__bc))) { - if (sycl::isnan(__a)) + if (!__recalc && + (cplex::detail::isinf(__ac) || cplex::detail::isinf(__bd) || + cplex::detail::isinf(__ad) || cplex::detail::isinf(__bc))) { + if (cplex::detail::isnan(__a)) __a = sycl::copysign(value_type(0), __a); - if (sycl::isnan(__b)) + if (cplex::detail::isnan(__b)) __b = sycl::copysign(value_type(0), __b); - if (sycl::isnan(__c)) + if (cplex::detail::isnan(__c)) __c = sycl::copysign(value_type(0), __c); - if (sycl::isnan(__d)) + if (cplex::detail::isnan(__d)) __d = sycl::copysign(value_type(0), __d); __recalc = true; } @@ -557,13 +629,28 @@ class complex<_Tp, typename std::enable_if::value>::type> { _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex operator/(const complex &__z, const complex &__w) { +#if defined(_SYCL_EXT_CPLX_FAST_MATH) + // This implementation is around 20% faster for single precision, 5% for + // double, at the expense of larger error in some cases, because no scaling + // is done. 
+ value_type __a = __z.__re_; + value_type __b = __z.__im_; + value_type __c = __w.__re_; + value_type __d = __w.__im_; + value_type __r = __a * __c + __b * __d; + value_type __n = __c * __c + __d * __d; // |w|^2: squared modulus of the divisor + value_type __x = __r / __n; + value_type __y = (__b * __c - __a * __d) / __n; + return complex(__x, __y); +#else int __ilogbw = 0; value_type __a = __z.__re_; value_type __b = __z.__im_; value_type __c = __w.__re_; value_type __d = __w.__im_; - value_type __logbw = sycl::logb(sycl::fmax(sycl::fabs(__c), sycl::fabs(__d))); - if (sycl::isfinite(__logbw)) { + value_type __logbw = + sycl::logb(sycl::fmax(sycl::fabs(__c), sycl::fabs(__d))); + if (cplex::detail::isfinite(__logbw)) { __ilogbw = static_cast(__logbw); __c = sycl::ldexp(__c, -__ilogbw); __d = sycl::ldexp(__d, -__ilogbw); @@ -571,25 +658,31 @@ class complex<_Tp, typename std::enable_if::value>::type> { value_type __denom = __c * __c + __d * __d; value_type __x = sycl::ldexp((__a * __c + __b * __d) / __denom, -__ilogbw); value_type __y = sycl::ldexp((__b * __c - __a * __d) / __denom, -__ilogbw); - if (sycl::isnan(__x) && sycl::isnan(__y)) { - if ((__denom == value_type(0)) && (!sycl::isnan(__a) || !sycl::isnan(__b))) { + if (cplex::detail::isnan(__x) && cplex::detail::isnan(__y)) { + if ((__denom == value_type(0)) && + (!cplex::detail::isnan(__a) || !cplex::detail::isnan(__b))) { __x = sycl::copysign(value_type(INFINITY), __c) * __a; __y = sycl::copysign(value_type(INFINITY), __c) * __b; - } else if ((sycl::isinf(__a) || sycl::isinf(__b)) && sycl::isfinite(__c) && - sycl::isfinite(__d)) { - __a = sycl::copysign(sycl::isinf(__a) ? value_type(1) : value_type(0), __a); - __b = sycl::copysign(sycl::isinf(__b) ? value_type(1) : value_type(0), __b); + } else if ((cplex::detail::isinf(__a) || cplex::detail::isinf(__b)) && + cplex::detail::isfinite(__c) && cplex::detail::isfinite(__d)) { + __a = sycl::copysign( + cplex::detail::isinf(__a) ?
value_type(1) : value_type(0), __a); + __b = sycl::copysign( + cplex::detail::isinf(__b) ? value_type(1) : value_type(0), __b); __x = value_type(INFINITY) * (__a * __c + __b * __d); __y = value_type(INFINITY) * (__b * __c - __a * __d); - } else if (sycl::isinf(__logbw) && __logbw > value_type(0) && - sycl::isfinite(__a) && sycl::isfinite(__b)) { - __c = sycl::copysign(sycl::isinf(__c) ? value_type(1) : value_type(0), __c); - __d = sycl::copysign(sycl::isinf(__d) ? value_type(1) : value_type(0), __d); + } else if (cplex::detail::isinf(__logbw) && __logbw > value_type(0) && + cplex::detail::isfinite(__a) && cplex::detail::isfinite(__b)) { + __c = sycl::copysign( + cplex::detail::isinf(__c) ? value_type(1) : value_type(0), __c); + __d = sycl::copysign( + cplex::detail::isinf(__d) ? value_type(1) : value_type(0), __d); __x = value_type(0) * (__a * __c + __b * __d); __y = value_type(0) * (__b * __c - __a * __d); } } return complex(__x, __y); +#endif } _SYCL_EXT_CPLX_INLINE_VISIBILITY friend complex operator/(const complex &__x, value_type __y) { @@ -629,8 +722,9 @@ class complex<_Tp, typename std::enable_if::value>::type> { } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend std::basic_istream<_CharT, _Traits> - &operator>>(std::basic_istream<_CharT, _Traits> &__is, complex &__x) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend std::basic_istream<_CharT, _Traits> & + operator>>(std::basic_istream<_CharT, _Traits> &__is, + complex &__x) { if (__is.good()) { ws(__is); if (__is.peek() == _CharT('(')) { @@ -675,8 +769,9 @@ class complex<_Tp, typename std::enable_if::value>::type> { } template - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend std::basic_ostream<_CharT, _Traits> - &operator<<(std::basic_ostream<_CharT, _Traits> &__os, const complex &__x) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend std::basic_ostream<_CharT, _Traits> & + operator<<(std::basic_ostream<_CharT, _Traits> &__os, + const complex &__x) { std::basic_ostringstream<_CharT, _Traits> __s; __s.flags(__os.flags()); 
__s.imbue(__os.getloc()); @@ -685,8 +780,8 @@ class complex<_Tp, typename std::enable_if::value>::type> { return __os << __s.str(); } - _SYCL_EXT_CPLX_INLINE_VISIBILITY friend const sycl::stream - &operator<<(const sycl::stream &__ss, const complex &_x) { + _SYCL_EXT_CPLX_INLINE_VISIBILITY friend const sycl::stream & + operator<<(const sycl::stream &__ss, const complex &_x) { return __ss << "(" << _x.__re_ << "," << _x.__im_ << ")"; } }; @@ -707,7 +802,7 @@ template struct __libcpp_complex_overload_traits<_Tp, false, true> { typedef _Tp _ValueType; typedef complex<_Tp> _ComplexType; }; -} +} // namespace cplex::detail // real @@ -753,18 +848,21 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY _Tp arg(const complex<_Tp> &__c) { template _SYCL_EXT_CPLX_INLINE_VISIBILITY -typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType arg(_Tp __re) { - typedef typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; - return sycl::atan2<_ValueType>(0, __re); + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType + arg(_Tp __re) { + typedef + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType + _ValueType; + return sycl::atan2(static_cast<_ValueType>(0), static_cast<_ValueType>(__re)); } // norm template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY _Tp norm(const complex<_Tp> &__c) { - if (sycl::isinf(__c.real())) + if (cplex::detail::isinf(__c.real())) return sycl::fabs(__c.real()); - if (sycl::isinf(__c.imag())) + if (cplex::detail::isinf(__c.imag())) return sycl::fabs(__c.imag()); return __c.real() * __c.real() + __c.imag() * __c.imag(); } @@ -773,7 +871,9 @@ template _SYCL_EXT_CPLX_INLINE_VISIBILITY typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType norm(_Tp __re) { - typedef typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; + typedef + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ValueType + _ValueType; return 
static_cast<_ValueType>(__re) * __re; } @@ -788,8 +888,8 @@ template _SYCL_EXT_CPLX_INLINE_VISIBILITY typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType conj(_Tp __re) { - typedef - typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + typedef typename cplex::detail::__libcpp_complex_overload_traits< + _Tp>::_ComplexType _ComplexType; return _ComplexType(__re); } @@ -798,19 +898,20 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> proj(const complex<_Tp> &__c) { complex<_Tp> __r = __c; - if (sycl::isinf(__c.real()) || sycl::isinf(__c.imag())) + if (cplex::detail::isinf(__c.real()) || cplex::detail::isinf(__c.imag())) __r = complex<_Tp>(INFINITY, sycl::copysign(_Tp(0), __c.imag())); return __r; } template _SYCL_EXT_CPLX_INLINE_VISIBILITY - typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType - proj(_Tp __re) { - typedef typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + typename cplex::detail::__libcpp_complex_overload_traits<_Tp>::_ComplexType + proj(_Tp __re) { + typedef typename cplex::detail::__libcpp_complex_overload_traits< + _Tp>::_ComplexType _ComplexType; - if constexpr(!std::is_integral_v<_Tp>) { - if (sycl::isinf(__re)) + if constexpr (!std::is_integral_v<_Tp>) { + if (cplex::detail::isinf(__re)) __re = sycl::fabs(__re); } @@ -820,24 +921,25 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY // polar template ::value>> -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> polar(const _Tp &__rho, const _Tp &__theta = _Tp()) { - if (sycl::isnan(__rho) || sycl::signbit(__rho)) +_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> +polar(const _Tp &__rho, const _Tp &__theta = _Tp()) { + if (cplex::detail::isnan(__rho) || sycl::signbit(__rho)) return complex<_Tp>(_Tp(NAN), _Tp(NAN)); - if (sycl::isnan(__theta)) { - if (sycl::isinf(__rho)) + if (cplex::detail::isnan(__theta)) { + if (cplex::detail::isinf(__rho)) 
return complex<_Tp>(__rho, __theta); return complex<_Tp>(__theta, __theta); } - if (sycl::isinf(__theta)) { - if (sycl::isinf(__rho)) + if (cplex::detail::isinf(__theta)) { + if (cplex::detail::isinf(__rho)) return complex<_Tp>(__rho, _Tp(NAN)); return complex<_Tp>(_Tp(NAN), _Tp(NAN)); } _Tp __x = __rho * sycl::cos(__theta); - if (sycl::isnan(__x)) + if (cplex::detail::isnan(__x)) __x = 0; _Tp __y = __rho * sycl::sin(__theta); - if (sycl::isnan(__y)) + if (cplex::detail::isnan(__y)) __y = 0; return complex<_Tp>(__x, __y); } @@ -860,14 +962,14 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> log10(const complex<_Tp> &__x) { template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> sqrt(const complex<_Tp> &__x) { - if (sycl::isinf(__x.imag())) + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(_Tp(INFINITY), __x.imag()); - if (sycl::isinf(__x.real())) { + if (cplex::detail::isinf(__x.real())) { if (__x.real() > _Tp(0)) - return complex<_Tp>(__x.real(), sycl::isnan(__x.imag()) + return complex<_Tp>(__x.real(), cplex::detail::isnan(__x.imag()) ? __x.imag() : sycl::copysign(_Tp(0), __x.imag())); - return complex<_Tp>(sycl::isnan(__x.imag()) ? __x.imag() : _Tp(0), + return complex<_Tp>(cplex::detail::isnan(__x.imag()) ? 
__x.imag() : _Tp(0), sycl::copysign(__x.real(), __x.imag())); } return polar(sycl::sqrt(abs(__x)), arg(__x) / _Tp(2)); @@ -882,12 +984,12 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> exp(const complex<_Tp> &__x) { return complex<_Tp>(sycl::exp(__x.real()), sycl::copysign(_Tp(0), __x.imag())); } - if (sycl::isinf(__x.real())) { + if (cplex::detail::isinf(__x.real())) { if (__x.real() < _Tp(0)) { - if (!sycl::isfinite(__i)) + if (!cplex::detail::isfinite(__i)) __i = _Tp(1); - } else if (__i == 0 || !sycl::isfinite(__i)) { - if (sycl::isinf(__i)) + } else if (__i == 0 || !cplex::detail::isfinite(__i)) { + if (cplex::detail::isinf(__i)) __i = _Tp(NAN); return complex<_Tp>(__x.real(), __i); } @@ -906,29 +1008,33 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> pow(const complex<_Tp> &__x, template ::value>> -_SYCL_EXT_CPLX_INLINE_VISIBILITY complex::type> -pow(const complex<_Tp> &__x, const complex<_Up> &__y) { - typedef complex::type> result_type; +_SYCL_EXT_CPLX_INLINE_VISIBILITY + complex::type> + pow(const complex<_Tp> &__x, const complex<_Up> &__y) { + typedef complex::type> + result_type; return pow(result_type(__x), result_type(__y)); } template ::value>> -_SYCL_EXT_CPLX_INLINE_VISIBILITY - typename std::enable_if::value, - complex::type>>::type - pow(const complex<_Tp> &__x, const _Up &__y) { - typedef complex::type> result_type; +_SYCL_EXT_CPLX_INLINE_VISIBILITY typename std::enable_if< + is_genfloat<_Up>::value, + complex::type>>::type +pow(const complex<_Tp> &__x, const _Up &__y) { + typedef complex::type> + result_type; return pow(result_type(__x), result_type(__y)); } template ::value>> -_SYCL_EXT_CPLX_INLINE_VISIBILITY - typename std::enable_if::value, - complex::type>>::type - pow(const _Tp &__x, const complex<_Up> &__y) { - typedef complex::type> result_type; +_SYCL_EXT_CPLX_INLINE_VISIBILITY typename std::enable_if< + is_genfloat<_Up>::value, + complex::type>>::type +pow(const _Tp &__x, const complex<_Up> &__y) { + typedef complex::type> + 
result_type; return pow(result_type(__x), result_type(__y)); } @@ -940,29 +1046,29 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> __sqr(const complex<_Tp> &__x) { return complex<_Tp>((__x.real() - __x.imag()) * (__x.real() + __x.imag()), _Tp(2) * __x.real() * __x.imag()); } -} +} // namespace cplex::detail // asinh template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> asinh(const complex<_Tp> &__x) { - const _Tp __pi(sycl::atan2(+0., -0.)); - if (sycl::isinf(__x.real())) { - if (sycl::isnan(__x.imag())) + const _Tp __pi(sycl::atan2(_Tp(+0.), _Tp(-0.))); + if (cplex::detail::isinf(__x.real())) { + if (cplex::detail::isnan(__x.imag())) return __x; - if (sycl::isinf(__x.imag())) + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(__x.real(), sycl::copysign(__pi * _Tp(0.25), __x.imag())); return complex<_Tp>(__x.real(), sycl::copysign(_Tp(0), __x.imag())); } - if (sycl::isnan(__x.real())) { - if (sycl::isinf(__x.imag())) + if (cplex::detail::isnan(__x.real())) { + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(__x.imag(), __x.real()); if (__x.imag() == 0) return __x; return complex<_Tp>(__x.real(), __x.real()); } - if (sycl::isinf(__x.imag())) + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(sycl::copysign(__x.imag(), __x.real()), sycl::copysign(__pi / _Tp(2), __x.imag())); complex<_Tp> __z = log(__x + sqrt(cplex::detail::__sqr(__x) + _Tp(1))); @@ -974,11 +1080,11 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> asinh(const complex<_Tp> &__x) { template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acosh(const complex<_Tp> &__x) { - const _Tp __pi(sycl::atan2(+0., -0.)); - if (sycl::isinf(__x.real())) { - if (sycl::isnan(__x.imag())) + const _Tp __pi(sycl::atan2(_Tp(+0.), _Tp(-0.))); + if (cplex::detail::isinf(__x.real())) { + if (cplex::detail::isnan(__x.imag())) return complex<_Tp>(sycl::fabs(__x.real()), __x.imag()); - if (sycl::isinf(__x.imag())) { + if (cplex::detail::isinf(__x.imag())) { if (__x.real() > 0) 
return complex<_Tp>(__x.real(), sycl::copysign(__pi * _Tp(0.25), __x.imag())); @@ -990,12 +1096,12 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acosh(const complex<_Tp> &__x) { return complex<_Tp>(-__x.real(), sycl::copysign(__pi, __x.imag())); return complex<_Tp>(__x.real(), sycl::copysign(_Tp(0), __x.imag())); } - if (sycl::isnan(__x.real())) { - if (sycl::isinf(__x.imag())) + if (cplex::detail::isnan(__x.real())) { + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(sycl::fabs(__x.imag()), __x.real()); return complex<_Tp>(__x.real(), __x.real()); } - if (sycl::isinf(__x.imag())) + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(sycl::fabs(__x.imag()), sycl::copysign(__pi / _Tp(2), __x.imag())); complex<_Tp> __z = log(__x + sqrt(cplex::detail::__sqr(__x) - _Tp(1))); @@ -1007,20 +1113,20 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acosh(const complex<_Tp> &__x) { template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> atanh(const complex<_Tp> &__x) { - const _Tp __pi(sycl::atan2(+0., -0.)); - if (sycl::isinf(__x.imag())) { + const _Tp __pi(sycl::atan2(_Tp(+0.), _Tp(-0.))); + if (cplex::detail::isinf(__x.imag())) { return complex<_Tp>(sycl::copysign(_Tp(0), __x.real()), sycl::copysign(__pi / _Tp(2), __x.imag())); } - if (sycl::isnan(__x.imag())) { - if (sycl::isinf(__x.real()) || __x.real() == 0) + if (cplex::detail::isnan(__x.imag())) { + if (cplex::detail::isinf(__x.real()) || __x.real() == 0) return complex<_Tp>(sycl::copysign(_Tp(0), __x.real()), __x.imag()); return complex<_Tp>(__x.imag(), __x.imag()); } - if (sycl::isnan(__x.real())) { + if (cplex::detail::isnan(__x.real())) { return complex<_Tp>(__x.real(), __x.real()); } - if (sycl::isinf(__x.real())) { + if (cplex::detail::isinf(__x.real())) { return complex<_Tp>(sycl::copysign(_Tp(0), __x.real()), sycl::copysign(__pi / _Tp(2), __x.imag())); } @@ -1037,11 +1143,11 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> atanh(const complex<_Tp> &__x) { template ::value>> 
_SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> sinh(const complex<_Tp> &__x) { - if (sycl::isinf(__x.real()) && !sycl::isfinite(__x.imag())) + if (cplex::detail::isinf(__x.real()) && !cplex::detail::isfinite(__x.imag())) return complex<_Tp>(__x.real(), _Tp(NAN)); - if (__x.real() == 0 && !sycl::isfinite(__x.imag())) + if (__x.real() == 0 && !cplex::detail::isfinite(__x.imag())) return complex<_Tp>(__x.real(), _Tp(NAN)); - if (__x.imag() == 0 && !sycl::isfinite(__x.real())) + if (__x.imag() == 0 && !cplex::detail::isfinite(__x.real())) return __x; return complex<_Tp>(sycl::sinh(__x.real()) * sycl::cos(__x.imag()), sycl::cosh(__x.real()) * sycl::sin(__x.imag())); @@ -1051,13 +1157,13 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> sinh(const complex<_Tp> &__x) { template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> cosh(const complex<_Tp> &__x) { - if (sycl::isinf(__x.real()) && !sycl::isfinite(__x.imag())) + if (cplex::detail::isinf(__x.real()) && !cplex::detail::isfinite(__x.imag())) return complex<_Tp>(sycl::fabs(__x.real()), _Tp(NAN)); - if (__x.real() == 0 && !sycl::isfinite(__x.imag())) + if (__x.real() == 0 && !cplex::detail::isfinite(__x.imag())) return complex<_Tp>(_Tp(NAN), __x.real()); if (__x.real() == 0 && __x.imag() == 0) return complex<_Tp>(_Tp(1), __x.imag()); - if (__x.imag() == 0 && !sycl::isfinite(__x.real())) + if (__x.imag() == 0 && !cplex::detail::isfinite(__x.real())) return complex<_Tp>(sycl::fabs(__x.real()), __x.imag()); return complex<_Tp>(sycl::cosh(__x.real()) * sycl::cos(__x.imag()), sycl::sinh(__x.real()) * sycl::sin(__x.imag())); @@ -1067,19 +1173,19 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> cosh(const complex<_Tp> &__x) { template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> tanh(const complex<_Tp> &__x) { - if (sycl::isinf(__x.real())) { - if (!sycl::isfinite(__x.imag())) + if (cplex::detail::isinf(__x.real())) { + if (!cplex::detail::isfinite(__x.imag())) return complex<_Tp>(sycl::copysign(_Tp(1), 
__x.real()), _Tp(0)); return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()), sycl::copysign(_Tp(0), sycl::sin(_Tp(2) * __x.imag()))); } - if (sycl::isnan(__x.real()) && __x.imag() == 0) + if (cplex::detail::isnan(__x.real()) && __x.imag() == 0) return __x; _Tp __2r(_Tp(2) * __x.real()); _Tp __2i(_Tp(2) * __x.imag()); _Tp __d(sycl::cosh(__2r) + sycl::cos(__2i)); _Tp __2rsh(sycl::sinh(__2r)); - if (sycl::isinf(__2rsh) && sycl::isinf(__d)) + if (cplex::detail::isinf(__2rsh) && cplex::detail::isinf(__d)) return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1), __2i > _Tp(0) ? _Tp(0) : _Tp(-0.)); return complex<_Tp>(__2rsh / __d, sycl::sin(__2i) / __d); @@ -1097,11 +1203,11 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> asin(const complex<_Tp> &__x) { template ::value>> _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acos(const complex<_Tp> &__x) { - const _Tp __pi(sycl::atan2(+0., -0.)); - if (sycl::isinf(__x.real())) { - if (sycl::isnan(__x.imag())) + const _Tp __pi(sycl::atan2(_Tp(+0.), _Tp(-0.))); + if (cplex::detail::isinf(__x.real())) { + if (cplex::detail::isnan(__x.imag())) return complex<_Tp>(__x.imag(), __x.real()); - if (sycl::isinf(__x.imag())) { + if (cplex::detail::isinf(__x.imag())) { if (__x.real() < _Tp(0)) return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag()); return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag()); @@ -1112,14 +1218,14 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> acos(const complex<_Tp> &__x) { return complex<_Tp>(_Tp(0), sycl::signbit(__x.imag()) ? 
__x.real() : -__x.real()); } - if (sycl::isnan(__x.real())) { - if (sycl::isinf(__x.imag())) + if (cplex::detail::isnan(__x.real())) { + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(__x.real(), -__x.imag()); return complex<_Tp>(__x.real(), __x.real()); } - if (sycl::isinf(__x.imag())) + if (cplex::detail::isinf(__x.imag())) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); - if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) + if (__x.real() == 0 && (__x.imag() == 0 || cplex::detail::isnan(__x.imag()))) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); complex<_Tp> __z = log(__x + sqrt(cplex::detail::__sqr(__x) - _Tp(1))); if (sycl::signbit(__x.imag())) @@ -1160,6 +1266,943 @@ _SYCL_EXT_CPLX_INLINE_VISIBILITY complex<_Tp> tan(const complex<_Tp> &__x) { _SYCL_EXT_CPLX_END_NAMESPACE_STD +//////////////////////////////////////////////////////////////////////////////// +// MARRAY IMPLEMENTATION +//////////////////////////////////////////////////////////////////////////////// + +_SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD + +template struct is_mgencomplex : std::false_type {}; + +template +struct is_mgencomplex> + : std::integral_constant> {}; + +template +inline constexpr bool is_mgencomplex_v = is_mgencomplex::value; + +_SYCL_EXT_CPLX_END_NAMESPACE_STD + +_SYCL_MARRAY_BEGIN_NAMESPACE + +// marray of complex class specialisation +template +class marray, NumElements> { +private: + using ComplexDataT = sycl::ext::cplx::complex; + +public: + using value_type = ComplexDataT; + using reference = ComplexDataT &; + using const_reference = const ComplexDataT &; + using iterator = ComplexDataT *; + using const_iterator = const ComplexDataT *; + +private: + value_type MData[NumElements]; + +public: + constexpr marray() : MData{} {}; + + explicit constexpr marray(const ComplexDataT &arg) { + for (size_t i = 0; i < NumElements; ++i) + MData[i] = arg; + } + + template + constexpr marray(const ArgTN &...args) : MData{args...} {}; + + constexpr marray(const marray &rhs) 
= default; + constexpr marray(marray &&rhs) = default; + + // Available only when: NumElements == 1 + template > + operator ComplexDataT() const { + return MData[0]; + } + + static constexpr std::size_t size() noexcept { return NumElements; } + + marray real() const { + sycl::marray rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = MData[i].real(); + } + + return rtn; + } + + marray imag() const { + sycl::marray rtn; + + for (std::size_t i = 0; i < NumElements; ++i) { + rtn[i] = MData[i].imag(); + } + + return rtn; + } + + // subscript operator + reference operator[](std::size_t i) { return MData[i]; } + const_reference operator[](std::size_t i) const { return MData[i]; } + + marray &operator=(const marray &rhs) = default; + marray &operator=(const ComplexDataT &rhs) { + for (std::size_t i = 0; i < NumElements; ++i) + MData[i] = rhs; + + return *this; + } + + // iterator functions + iterator begin() { return MData; } + const_iterator begin() const { return MData; } + + iterator end() { return MData + NumElements; } + const_iterator end() const { return MData + NumElements; } + + // OP is: +, -, *, / +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs[i]; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const ComplexDataT &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs op rhs[i]; \ + \ + return rtn; \ + } + + OP(+) + OP(-) + OP(*) + OP(/) + +#undef OP + + // OP is: % + friend marray operator%(const marray &lhs, const marray &rhs) = delete; + friend marray operator%(const marray &lhs, const ComplexDataT &rhs) = delete; + friend marray operator%(const 
ComplexDataT &lhs, const marray &rhs) = delete; + + // OP is: +=, -=, *=, /= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs[i]; \ + \ + return lhs; \ + } \ + \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs; \ + \ + return lhs; \ + } \ + friend marray &operator op(ComplexDataT &lhs, const marray &rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) \ + lhs[i] op rhs; \ + \ + return lhs; \ + } + + OP(+=) + OP(-=) + OP(*=) + OP(/=) + +#undef OP + + // OP is: %= + friend marray &operator%=(marray &lhs, const marray &rhs) = delete; + friend marray &operator%=(marray &lhs, const ComplexDataT &rhs) = delete; + friend marray &operator%=(ComplexDataT &lhs, const marray &rhs) = delete; + +// OP is: ++, -- +#define OP(op) \ + friend marray operator op(marray &lhs, int) = delete; \ + friend marray &operator op(marray &rhs) = delete; + + OP(++) + OP(--) + +#undef OP + +// OP is: unary +, unary - +#define OP(op) \ + friend marray operator op( \ + const marray &rhs) { \ + marray rtn; \ + \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = op rhs[i]; \ + } \ + \ + return rtn; \ + } + + OP(+) + OP(-) + +#undef OP + +// OP is: &, |, ^ +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) = \ + delete; + + OP(&) + OP(|) + OP(^) + +#undef OP + +// OP is: &=, |=, ^= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray &operator op(ComplexDataT &lhs, const marray &rhs) = delete; + + OP(&=) + OP(|=) + OP(^=) + +#undef OP + +// OP is: &&, || +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ 
+ friend marray operator op( \ + const marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray operator op(const ComplexDataT &lhs, \ + const marray &rhs) = delete; + + OP(&&) + OP(||) + +#undef OP + +// OP is: <<, >> +#define OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) = delete; \ + friend marray operator op(const marray &lhs, const ComplexDataT &rhs) = \ + delete; \ + friend marray operator op(const ComplexDataT &lhs, const marray &rhs) = \ + delete; + + OP(<<) + OP(>>) + +#undef OP + +// OP is: <<=, >>= +#define OP(op) \ + friend marray &operator op(marray &lhs, const marray &rhs) = delete; \ + friend marray &operator op(marray &lhs, const ComplexDataT &rhs) = delete; + + OP(<<=) + OP(>>=) + +#undef OP + + // OP is: ==, != +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs[i]; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, \ + const ComplexDataT &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs[i] op rhs; \ + \ + return rtn; \ + } \ + \ + friend marray operator op(const ComplexDataT &lhs, \ + const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = lhs op rhs[i]; \ + \ + return rtn; \ + } + + OP(==) + OP(!=) + +#undef OP + + // OP is: <, >, <=, >= +#define OP(op) \ + friend marray operator op(const marray &lhs, \ + const marray &rhs) = delete; \ + friend marray operator op( \ + const marray &lhs, const ComplexDataT &rhs) = delete; \ + friend marray operator op(const ComplexDataT &lhs, \ + const marray &rhs) = delete; + + OP(<); + OP(>); + OP(<=); + OP(>=); + +#undef OP + + friend marray operator~(const marray &v) = delete; + + friend marray operator!(const marray &v) = delete; +}; + +_SYCL_MARRAY_END_NAMESPACE + +_SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD + +// Math marray overloads + 
+#define MATH_OP_ONE_PARAM(math_func, rtn_type, arg_type) \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i]); \ + \ + return rtn; \ + } + +MATH_OP_ONE_PARAM(abs, T, complex); +MATH_OP_ONE_PARAM(acos, complex, complex); +MATH_OP_ONE_PARAM(asin, complex, complex); +MATH_OP_ONE_PARAM(atan, complex, complex); +MATH_OP_ONE_PARAM(acosh, complex, complex); +MATH_OP_ONE_PARAM(asinh, complex, complex); +MATH_OP_ONE_PARAM(atanh, complex, complex); +MATH_OP_ONE_PARAM(arg, T, complex); +MATH_OP_ONE_PARAM(conj, complex, complex); +MATH_OP_ONE_PARAM(cos, complex, complex); +MATH_OP_ONE_PARAM(cosh, complex, complex); +MATH_OP_ONE_PARAM(exp, complex, complex); +MATH_OP_ONE_PARAM(log, complex, complex); +MATH_OP_ONE_PARAM(log10, complex, complex); +MATH_OP_ONE_PARAM(norm, T, complex); +MATH_OP_ONE_PARAM(proj, complex, complex); +MATH_OP_ONE_PARAM(proj, complex, T); +MATH_OP_ONE_PARAM(sin, complex, complex); +MATH_OP_ONE_PARAM(sinh, complex, complex); +MATH_OP_ONE_PARAM(sqrt, complex, complex); +MATH_OP_ONE_PARAM(tan, complex, complex); +MATH_OP_ONE_PARAM(tanh, complex, complex); + +#undef MATH_OP_ONE_PARAM + +#define MATH_OP_TWO_PARAM(math_func, rtn_type, arg_type1, arg_type2) \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i], y[i]); \ + \ + return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const sycl::marray &x, \ + const arg_type2 &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = sycl::ext::cplx::math_func(x[i], y); \ + \ 
+ return rtn; \ + } \ + \ + template ::value || \ + is_gencomplex::value>> \ + _SYCL_EXT_CPLX_INLINE_VISIBILITY sycl::marray \ + math_func(const arg_type1 &x, \ + const sycl::marray &y) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) \ + rtn[i] = math_func(x, y[i]); \ + \ + return rtn; \ + } + +MATH_OP_TWO_PARAM(pow, complex, complex, T); +MATH_OP_TWO_PARAM(pow, complex, complex, complex); +MATH_OP_TWO_PARAM(pow, complex, T, complex); + +#undef MATH_OP_TWO_PARAM + +// Special definition as polar requires default argument + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, + const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho[i], theta[i]); + + return rtn; +} + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const sycl::marray &rho, const T &theta = 0) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho[i], theta); + + return rtn; +} + +template ::value>> +_SYCL_EXT_CPLX_INLINE_VISIBILITY + sycl::marray, NumElements> + polar(const T &rho, const sycl::marray &theta) { + sycl::marray, NumElements> rtn; + for (std::size_t i = 0; i < NumElements; ++i) + rtn[i] = sycl::ext::cplx::polar(rho, theta[i]); + + return rtn; +} + +//////////////////////////////////////////////////////////////////////////////// +// GROUP ALGORITMHS +//////////////////////////////////////////////////////////////////////////////// + +namespace cplex::detail { + +/// Helper traits to check if the type is a sycl::plus +template +struct is_plus + : std::integral_constant>> { +}; +template +inline constexpr bool is_plus_v = is_plus::value; + +/// Helper traits to check if the type is a sycl:multiplies +template +struct is_multiplies + : std::integral_constant< + bool, std::is_same_v>> {}; +template +inline 
constexpr bool is_multiplies_v = is_multiplies::value; + +/// Wrapper trait to check if the binary operation is supported +template +struct is_binary_op_supported + : std::integral_constant::value || + detail::is_multiplies::value)> { +}; +template +inline constexpr bool is_binary_op_supported_v = + is_binary_op_supported::value; + +/// Helper functions to get the init for sycl::plus binary operation when the +/// type is a gencomplex +template +std::enable_if_t<(sycl::ext::cplx::is_gencomplex_v && + detail::is_plus_v), + T> +get_init() { + return T{0, 0}; +} +/// Helper functions to get the init for sycl::multiply binary operation when +/// the type is a gencomplex +template +std::enable_if_t<(sycl::ext::cplx::is_gencomplex_v && + detail::is_multiplies::value), + T> +get_init() { + return T{1, 0}; +} +/// Helper functions to get the init for sycl::plus binary operation when the +/// type is a mgencomplex +template +std::enable_if_t< + (is_mgencomplex_v && detail::is_plus::value), T> +get_init() { + using Complex = typename T::value_type; + + T result; + std::fill(result.begin(), result.end(), Complex{0, 0}); + return result; +} +/// Helper functions to get the init for sycl::multiply binary operation when +/// the type is a mgencomplex +template +std::enable_if_t< + (is_mgencomplex_v && detail::is_multiplies::value), T> +get_init() { + using Complex = typename T::value_type; + + T result; + std::fill(result.begin(), result.end(), Complex{1, 0}); + return result; +} + +} // namespace cplex::detail + +/* REDUCE_OVER_GROUP'S OVERLOADS */ + +/// Complex specialization +template > && is_genfloat_v && + is_genfloat_v && + cplex::detail::is_binary_op_supported_v>> +complex reduce_over_group(Group g, complex x, complex init, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + complex result; + + result.real(sycl::reduce_over_group(g, x.real(), init.real(), binary_op)); + result.imag(sycl::reduce_over_group(g, x.imag(), init.imag(), binary_op)); + + return 
result; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray specialization +template > && is_gencomplex_v && + is_gencomplex_v && + cplex::detail::is_binary_op_supported_v>> +sycl::marray reduce_over_group(Group g, sycl::marray x, + sycl::marray init, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + sycl::marray result; + + cplex::detail::loop([&](size_t s) { + result[s] = reduce_over_group(g, x[s], init[s], binary_op); + }); + + return result; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray and Complex specialization +template > && + (is_gencomplex_v || is_mgencomplex_v)&&cplex::detail:: + is_binary_op_supported_v>> +T reduce_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + auto init = cplex::detail::get_init(); + + return reduce_over_group(g, x, init, binary_op); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/* JOINT_REDUCE'S OVERLOADS */ + +/// Marray and Complex specialization +template > && + sycl::detail::is_pointer::value && + (is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v || is_mgencomplex_v)&&cplex:: + detail::is_binary_op_supported_v>> +T joint_reduce(Group g, Ptr first, Ptr last, T init, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + auto partial = cplex::detail::get_init(); + + sycl::detail::for_each( + g, first, last, + [&](const typename sycl::detail::remove_pointer::type &x) { + partial = binary_op(partial, x); + }); + + return reduce_over_group(g, partial, init, binary_op); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray and Complex specialization +template > && + 
sycl::detail::is_pointer::value && + (is_gencomplex_v> || + is_mgencomplex_v>)&&cplex:: + detail::is_binary_op_supported_v>> +typename sycl::detail::remove_pointer_t +joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + using T = typename sycl::detail::remove_pointer_t; + + auto init = cplex::detail::get_init(); + + return joint_reduce(g, first, last, init, binary_op); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/* INCLUSIVE_SCAN_OVER_GROUP'S OVERLOADS */ + +/// Complex specialization +template > && is_genfloat_v && + is_genfloat_v && + cplex::detail::is_binary_op_supported_v>> +complex inclusive_scan_over_group(Group g, complex x, + BinaryOperation binary_op, + complex init) { +#ifdef __SYCL_DEVICE_ONLY__ + complex result; + + result.real( + sycl::inclusive_scan_over_group(g, x.real(), binary_op, init.real())); + result.imag( + sycl::inclusive_scan_over_group(g, x.imag(), binary_op, init.imag())); + + return result; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray specialization +template > && is_gencomplex_v && + is_gencomplex_v && + cplex::detail::is_binary_op_supported_v>> +sycl::marray inclusive_scan_over_group(Group g, sycl::marray x, + BinaryOperation binary_op, + sycl::marray init) { +#ifdef __SYCL_DEVICE_ONLY__ + sycl::marray result; + + cplex::detail::loop([&](size_t s) { + result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]); + }); + + return result; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray and Complex specialization +template > && + (is_gencomplex_v || is_mgencomplex_v)&&cplex::detail:: + is_binary_op_supported_v>> +T inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { 
+#ifdef __SYCL_DEVICE_ONLY__ + auto init = cplex::detail::get_init(); + + return inclusive_scan_over_group(g, x, binary_op, init); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/* JOINT_INCLUSIVE_SCAN'S OVERLOADS */ + +/// Complex specialization +template > && + sycl::detail::is_pointer::value && + sycl::detail::is_pointer::value && + (is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v || + is_mgencomplex_v)&&cplex:: + detail::is_binary_op_supported_v>> +OutPtr joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, + BinaryOperation binary_op, T init) { +#ifdef __SYCL_DEVICE_ONLY__ + std::ptrdiff_t offset = g.get_local_linear_id(); + std::ptrdiff_t stride = g.get_local_linear_range(); + std::ptrdiff_t N = last - first; + + auto roundup = [=](const std::ptrdiff_t &v, + const std::ptrdiff_t &divisor) -> std::ptrdiff_t { + return ((v + divisor - 1) / divisor) * divisor; + }; + + typename std::remove_const_t> + x; + typename sycl::detail::remove_pointer_t carry = init; + + for (std::ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) { + std::ptrdiff_t i = chunk + offset; + + if (i < N) + x = first[i]; + + typename sycl::detail::remove_pointer_t out = + inclusive_scan_over_group(g, x, binary_op, carry); + + if (i < N) + result[i] = out; + + carry = sycl::group_broadcast(g, out, stride - 1); + } + return result + N; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Complex specialization +template < + typename Group, typename InPtr, typename OutPtr, class BinaryOperation, + typename = std::enable_if_t< + sycl::is_group_v> && + sycl::detail::is_pointer::value && + sycl::detail::is_pointer::value && + (is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v> || + is_mgencomplex_v< + 
sycl::detail::remove_pointer_t>)&&cplex:: + detail::is_binary_op_supported_v>> +OutPtr joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + using T = typename sycl::detail::remove_pointer_t; + + auto init = cplex::detail::get_init(); + + return joint_inclusive_scan(g, first, last, result, binary_op, init); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/* EXCLUSIVE_SCAN_OVER_GROUP'S OVERLOADS */ + +/// Complex specialization +template > && is_genfloat_v && + is_genfloat_v && + cplex::detail::is_binary_op_supported_v>> +complex exclusive_scan_over_group(Group g, complex x, complex init, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + complex result; + + result.real( + sycl::exclusive_scan_over_group(g, x.real(), init.real(), binary_op)); + result.imag( + sycl::exclusive_scan_over_group(g, x.imag(), init.imag(), binary_op)); + + return result; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray specialization +template > && is_gencomplex_v && + is_gencomplex_v && + cplex::detail::is_binary_op_supported_v>> +sycl::marray exclusive_scan_over_group(Group g, sycl::marray x, + sycl::marray init, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + sycl::marray result; + + cplex::detail::loop([&](size_t s) { + result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op); + }); + + return result; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Marray and Complex specialization +template > && + (is_gencomplex_v || is_mgencomplex_v)&&cplex::detail:: + is_binary_op_supported_v>> +T exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + auto init = 
cplex::detail::get_init(); + + return exclusive_scan_over_group(g, x, init, binary_op); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/* JOINT_EXCLUSIVE_SCAN'S OVERLOADS */ + +/// Complex specialization +template > && + sycl::detail::is_pointer::value && + sycl::detail::is_pointer::value && + (is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v || + is_mgencomplex_v)&& + // + cplex::detail::is_binary_op_supported_v>> +OutPtr joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, + T init, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + std::ptrdiff_t offset = g.get_local_linear_id(); + std::ptrdiff_t stride = g.get_local_linear_range(); + std::ptrdiff_t N = last - first; + + auto roundup = [=](const std::ptrdiff_t &v, + const std::ptrdiff_t &divisor) -> std::ptrdiff_t { + return ((v + divisor - 1) / divisor) * divisor; + }; + + typename std::remove_const_t> + x; + typename sycl::detail::remove_pointer_t carry = init; + + for (std::ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) { + std::ptrdiff_t i = chunk + offset; + if (i < N) + x = first[i]; + + typename sycl::detail::remove_pointer_t out = + exclusive_scan_over_group(g, x, carry, binary_op); + + if (i < N) + result[i] = out; + + carry = sycl::group_broadcast(g, binary_op(out, x), stride - 1); + } + return result + N; +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +/// Complex specialization +template < + typename Group, typename InPtr, typename OutPtr, class BinaryOperation, + typename = std::enable_if_t< + sycl::is_group_v> && + sycl::detail::is_pointer::value && + sycl::detail::is_pointer::value && + (is_gencomplex_v> || + is_mgencomplex_v>)&&(is_gencomplex_v> || + is_mgencomplex_v< + sycl::detail::remove_pointer_t>)&&cplex:: + 
detail::is_binary_op_supported_v>> +OutPtr joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, + BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + using T = typename sycl::detail::remove_pointer_t; + + auto init = cplex::detail::get_init(); + + return joint_exclusive_scan(g, first, last, result, init, binary_op); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Group algorithms are not supported on host."); +#endif +} + +_SYCL_EXT_CPLX_END_NAMESPACE_STD + +#undef _SYCL_MARRAY_BEGIN_NAMESPACE +#undef _SYCL_MARRAY_END_NAMESPACE + #undef _SYCL_EXT_CPLX_BEGIN_NAMESPACE_STD #undef _SYCL_EXT_CPLX_END_NAMESPACE_STD #undef _SYCL_EXT_CPLX_INLINE_VISIBILITY