diff --git a/include/ygm/container/detail/array_impl.hpp b/include/ygm/container/detail/array_impl.hpp index d1464b3f..0caf6826 100644 --- a/include/ygm/container/detail/array_impl.hpp +++ b/include/ygm/container/detail/array_impl.hpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace ygm::container::detail { @@ -121,9 +122,23 @@ class array_impl { template void for_all(Function fn) { m_comm.barrier(); - for (int i = 0; i < m_local_vec.size(); ++i) { - index_type g_index = global_index(i); - fn(g_index, m_local_vec[i]); + local_for_all(fn); + } + + template + void local_for_all(Function fn) { + if constexpr (std::is_invocable()) { + for (int i = 0; i < m_local_vec.size(); ++i) { + index_type g_index = global_index(i); + fn(g_index, m_local_vec[i]); + } + } else if constexpr (std::is_invocable()) { + std::for_each(std::begin(m_local_vec), std::end(m_local_vec), fn); + } else { + static_assert(ygm::detail::always_false<>, + "local array lambda must be invocable with (const " + "index_type, value_type &) or (value_type &) signatures"); } } diff --git a/test/test_array.cpp b/test/test_array.cpp index eee1bd84..cb2fb31d 100644 --- a/test/test_array.cpp +++ b/test/test_array.cpp @@ -111,5 +111,28 @@ int main(int argc, char **argv) { }); } + // Test value-only for_all + { + int size = 64; + + ygm::container::array arr(world, size); + + if (world.rank0()) { + for (int i = 0; i < size; ++i) { + arr.async_set(i, 1); + } + } + + world.barrier(); + + for (int i = 0; i < size; ++i) { + arr.async_increment(i); + } + + arr.for_all([&world](const auto value) { + ASSERT_RELEASE(value == world.size() + 1); + }); + } + return 0; }