Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 1-element vec ambiguities #38

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions include/simsycl/sycl/vec.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "../detail/check.hh"
#include "../detail/utils.hh"

#include <concepts>
#include <cstdint>
#include <cstdlib>
#include <type_traits>
Expand Down Expand Up @@ -69,6 +70,11 @@ template<typename T, typename DataT, int NumElements>
concept VecCompatible
= vec_like_num_elements<DataT, T>::value == 1 || vec_like_num_elements<DataT, T>::value == NumElements;

template<typename From, typename To>
concept implicitly_convertible = requires { std::is_convertible_v<From, To>; };

template<typename From, typename To>
concept explicitly_convertible = requires { static_cast<To>(std::declval<From>()); };

template<int... Is>
struct no_repeat_indices;
Expand Down Expand Up @@ -229,8 +235,9 @@ class swizzled_vec {
swizzled_vec &operator=(const swizzled_vec &) = delete;
swizzled_vec &operator=(swizzled_vec &&) = delete;

swizzled_vec &operator=(const value_type &rhs)
requires(allow_assign)
template<typename T>
swizzled_vec &operator=(const T &rhs)
requires(allow_assign && std::convertible_to<T, value_type>)
Comment on lines +238 to +240
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
template<typename T>
swizzled_vec &operator=(const T &rhs)
requires(allow_assign && std::convertible_to<T, value_type>)
template<std::convertible_to<value_type> T>
swizzled_vec &operator=(const T &rhs)
requires(allow_assign)

{
for(size_t i = 0; i < num_elements; ++i) { m_elems[indices[i]] = rhs; }
return *this;
Expand Down Expand Up @@ -261,6 +268,13 @@ class swizzled_vec {
return m_elems[indices[0]];
}

template<typename T>
explicit operator T() const
requires(num_elements == 1 && detail::explicitly_convertible<value_type, T>)
{
return m_elems[indices[0]];
}

static constexpr size_t byte_size() noexcept { return sycl::vec<value_type, num_elements>::byte_size(); }

static constexpr size_t size() noexcept { return num_elements; }
Expand Down Expand Up @@ -515,7 +529,7 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {

vec() = default;

explicit constexpr vec(const DataT &arg) {
explicit(num_elements > 1) constexpr vec(const DataT &arg) {
for(int i = 0; i < NumElements; ++i) { m_elems[i] = arg; }
}

Expand All @@ -528,7 +542,10 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
vec(const vec &) = default;
vec &operator=(const vec &rhs) = default;

vec &operator=(const DataT &rhs) {
template<typename T>
vec &operator=(const T &rhs)
requires(std::convertible_to<T, DataT>)
{
Comment on lines +545 to +548
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
template<typename T>
vec &operator=(const T &rhs)
requires(std::convertible_to<T, DataT>)
{
template<std::convertible_to<DataT> T>
vec &operator=(const T &rhs)
{

for(int i = 0; i < NumElements; ++i) { m_elems[i] = rhs; }
return *this;
}
Expand All @@ -548,6 +565,13 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
return m_elems[0];
}

template<typename T>
explicit operator T() const
requires(NumElements == 1 && detail::explicitly_convertible<DataT, T>)
{
return m_elems[0];
}

static constexpr size_t byte_size() noexcept { return sizeof m_elems; }

static constexpr size_t size() noexcept { return NumElements; }
Expand Down Expand Up @@ -736,15 +760,17 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs.m_elems[i]; } \
return result; \
} \
friend vec operator op(const vec &lhs, const DataT &rhs) \
requires(enable_if) \
template<typename T> \
friend vec operator op(const vec &lhs, const T &rhs) \
requires(enable_if && std::convertible_to<T, DataT>) \
Comment on lines +763 to +765
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
template<typename T> \
friend vec operator op(const vec &lhs, const T &rhs) \
requires(enable_if && std::convertible_to<T, DataT>) \
template<std::convertible_to<DataT> T> \
friend vec operator op(const vec &lhs, const T &rhs) \
requires(enable_if) \

(same for overloads below)

{ \
vec result; \
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs; } \
return result; \
} \
friend vec operator op(const DataT &lhs, const vec &rhs) \
requires(enable_if) \
template<typename T> \
friend vec operator op(const T &lhs, const vec &rhs) \
requires(enable_if && std::convertible_to<T, DataT>) \
{ \
vec result; \
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs op rhs.m_elems[i]; } \
Expand Down Expand Up @@ -778,8 +804,9 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
for(int i = 0; i < NumElements; ++i) { lhs.m_elems[i] op rhs.m_elems[rhs.indices[i]]; } \
return lhs; \
} \
friend vec &operator op(vec & lhs, const DataT & rhs) \
requires(enable_if) \
template<typename T> \
friend vec &operator op(vec & lhs, const T & rhs) \
requires(enable_if && std::convertible_to<T, DataT>) \
{ \
for(int i = 0; i < NumElements; ++i) { lhs.m_elems[i] op rhs; } \
return lhs; \
Expand Down Expand Up @@ -847,12 +874,18 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs.m_elems[i]; } \
return result; \
} \
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const vec & lhs, const DataT & rhs) { \
template<typename T> \
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const vec & lhs, const T & rhs) \
requires(std::convertible_to<T, DataT>) \
{ \
vec<decltype(DataT {} op DataT{}), NumElements> result; \
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs; } \
return result; \
} \
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const DataT & lhs, const vec & rhs) { \
template<typename T> \
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const T & lhs, const vec & rhs) \
requires(std::convertible_to<T, DataT>) \
{ \
vec<decltype(DataT {} op DataT{}), NumElements> result; \
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs op rhs.m_elems[i]; } \
return result; \
Expand Down
Loading