Skip to content

Commit

Permalink
Update finally to support value channel references (#545)
Browse files Browse the repository at this point in the history
This aligns with the general idea that senders should be able to send
references and values, so the adaptor algorithms should preserve the
types in value channels when there is no specific requirement to modify
the value channel type.

In making this change, the CreateTest.AwaitTest test failed to compile.
This appears to be the create operation accepting universal reference
`Ts&&...` arguments which are then forwarded to the underlying receiver's
`set_value`. However, in this test, the type of `result` is `int&`,
but the create sender only sends `int`. Before this change (and before
the schedule affine task), something else ended up decaying the
parameter's type further downstream. With the scheduler affine task
and its reliance on `finally`, the `finally` receiver's `set_value`
strictly enforces that it's only called with the expected type, which
caused the compilation error with my changes. Fundamentally, the create
sender was incorrectly sending `int&` despite advertising it only
sends `int`. I updated the create algorithm to faithfully send exactly
what it advertises to fix the problem.

I left finally's existing behavior of decay_t-ing error types since I
ran into other issues trying to get that to work, and it seems less
useful/likely (though, I don't see why it shouldn't be done later).

Note that per discussion in #541, this is a source and behavior
breaking change.
  • Loading branch information
ccotter authored Jul 5, 2023
1 parent e4049bd commit 0df7f0b
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 25 deletions.
33 changes: 21 additions & 12 deletions include/unifex/create.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace unifex {
namespace _create {

template <typename Receiver, typename Fn, typename Context>
template <typename Receiver, typename Fn, typename Context, typename... ValueTypes>
struct _op {
struct type {
explicit type(Receiver rec, Fn fn, Context ctx)
Expand All @@ -34,9 +34,17 @@ struct _op {
: rec_((Receiver&&) rec), fn_((Fn&&) fn), ctx_((Context&&) ctx) {}

template (typename... Ts)
(requires receiver_of<Receiver, Ts...>)
void set_value(Ts&&... ts) noexcept(is_nothrow_receiver_of_v<Receiver, Ts...>) {
unifex::set_value((Receiver&&) rec_, (Ts&&) ts...);
(requires (convertible_to<Ts, ValueTypes> && ...))
void set_value(Ts&&... ts) noexcept {
UNIFEX_TRY {
// Satisfy the value completion contract by converting to the
// Sender's value_types. For example, if set_value is called with
// an lvalue reference but the create Sender sends non-reference
// values.
unifex::set_value(std::move(rec_), static_cast<ValueTypes>(static_cast<Ts&&>(ts))...);
} UNIFEX_CATCH(...) {
unifex::set_error(std::move(rec_), std::current_exception());
}
}

template (typename Error)
Expand Down Expand Up @@ -79,10 +87,10 @@ struct _op {
};
};

template <typename Receiver, typename Fn, typename Context>
using _operation = typename _op<Receiver, Fn, Context>::type;
template <typename Receiver, typename Fn, typename Context, typename... ValueTypes>
using _operation = typename _op<Receiver, Fn, Context, ValueTypes...>::type;

template <typename Fn, typename Context>
template <typename Fn, typename Context, typename... ValueTypes>
struct _snd_base {
struct type {
template <template<typename...> class Variant>
Expand All @@ -100,15 +108,16 @@ struct _snd_base {
template (typename Self, typename Receiver)
(requires derived_from<remove_cvref_t<Self>, type> AND
constructible_from<Fn, member_t<Self, Fn>> AND
constructible_from<Context, member_t<Self, Context>>)
friend _operation<remove_cvref_t<Receiver>, Fn, Context>
constructible_from<Context, member_t<Self, Context>> AND
receiver_of<Receiver, ValueTypes...>)
friend _operation<remove_cvref_t<Receiver>, Fn, Context, ValueTypes...>
tag_invoke(tag_t<connect>, Self&& self, Receiver&& rec)
noexcept(std::is_nothrow_constructible_v<
_operation<Receiver, Fn, Context>,
_operation<Receiver, Fn, Context, ValueTypes...>,
Receiver,
member_t<Self, Fn>,
member_t<Self, Context>>) {
return _operation<remove_cvref_t<Receiver>, Fn, Context>{
return _operation<remove_cvref_t<Receiver>, Fn, Context, ValueTypes...>{
(Receiver&&) rec,
((Self&&) self).fn_,
((Self&&) self).ctx_};
Expand All @@ -121,7 +130,7 @@ struct _snd_base {

template <typename Fn, typename Context, typename... ValueTypes>
struct _snd {
struct type : _snd_base<Fn, Context>::type {
struct type : _snd_base<Fn, Context, ValueTypes...>::type {
template <template<typename...> class Variant, template <typename...> class Tuple>
using value_types = Variant<Tuple<ValueTypes...>>;
};
Expand Down
11 changes: 5 additions & 6 deletions include/unifex/finally.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace unifex
SourceSender,
CompletionSender,
Receiver,
std::decay_t<Values>...>::type;
Values...>::type;

template <
typename SourceSender,
Expand Down Expand Up @@ -384,7 +384,7 @@ namespace unifex
auto* const op = op_;

UNIFEX_TRY {
unifex::activate_union_member<std::tuple<std::decay_t<Values>...>>(
unifex::activate_union_member<std::tuple<Values...>>(
op->value_, static_cast<Values&&>(values)...);
} UNIFEX_CATCH (...) {
std::move(*this).set_error(std::current_exception());
Expand All @@ -411,8 +411,7 @@ namespace unifex
});
unifex::start(completionOp);
} UNIFEX_CATCH (...) {
using decayed_tuple_t = std::tuple<std::decay_t<Values>...>;
unifex::deactivate_union_member<decayed_tuple_t>(op->value_);
unifex::deactivate_union_member<std::tuple<Values...>>(op->value_);
unifex::set_error(
static_cast<Receiver&&>(op->receiver_), std::current_exception());
}
Expand Down Expand Up @@ -593,7 +592,7 @@ namespace unifex
sender_value_types_t<
remove_cvref_t<SourceSender>,
manual_lifetime_union,
decayed_tuple<std::tuple>::template apply>
std::tuple>
value_;
};

Expand Down Expand Up @@ -647,7 +646,7 @@ namespace unifex
template <typename...> class Variant,
template <typename...> class Tuple>
using value_types = typename sender_traits<SourceSender>::
template value_types<Variant, decayed_tuple<Tuple>::template apply>;
template value_types<Variant, Tuple>;

// This can produce any of the error_types of SourceSender, or of
// CompletionSender or an exception_ptr corresponding to an exception thrown
Expand Down
226 changes: 223 additions & 3 deletions test/create_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
*/

#include <unifex/create.hpp>
#include <unifex/single_thread_context.hpp>
#include <unifex/async_scope.hpp>
#include <unifex/finally.hpp>
#include <unifex/just.hpp>
#include <unifex/single_thread_context.hpp>
#include <unifex/sync_wait.hpp>

#include <optional>
Expand All @@ -30,6 +32,9 @@
using namespace unifex;

namespace {

int global;

struct CreateTest : testing::Test {
unifex::single_thread_context someThread;
unifex::async_scope someScope;
Expand All @@ -47,6 +52,14 @@ struct CreateTest : testing::Test {
});
}

void anIntRefAPI(void* context, void (*completed)(void* context, int& result)) {
// Execute some work asynchronously on some other thread. When its
// work is finished, pass the result to the callback.
someScope.detached_spawn_call_on(someThread.get_scheduler(), [=]() noexcept {
completed(context, global);
});
}

void aVoidAPI(void* context, void (*completed)(void* context)) {
// Execute some work asynchronously on some other thread. When its
// work is finished, pass the result to the callback.
Expand All @@ -58,18 +71,225 @@ struct CreateTest : testing::Test {
} // anonymous namespace

TEST_F(CreateTest, BasicTest) {
{
auto snd = [this](int a, int b) {
return create<int>([a, b, this](auto& rec) {
static_assert(receiver_of<decltype(rec), int>);
static_assert(!receiver_of<decltype(rec), int*>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(result);
});
});
}(1, 2);

std::optional<int> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(*res, 3);
}

{
auto snd = [this]() {
return create<int&>([this](auto& rec) {
static_assert(receiver_of<decltype(rec), int&>);
anIntRefAPI(&rec, [](void* context, int& result) {
unifex::void_cast<decltype(rec)>(context).set_value(result);
});
});
}();

std::optional<std::reference_wrapper<int>> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(&res->get(), &global);
}
}

TEST_F(CreateTest, FinallyCreate) {
auto snd = [this](int a, int b) {
return create<int>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, int>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(result);
});
});
}(1, 2) | finally(just());

std::optional<int> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(*res, 3);
}

TEST_F(CreateTest, DoubleCreateSetsIntValue) {
auto snd = [this](int a, int b) {
return create<int>([a, b, this](auto& rec) {
static_assert(receiver_of<decltype(rec), int>);
return create<double>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, int>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(result);
});
});
}(1, 2);

static_assert(std::is_same_v<decltype(sync_wait(std::move(snd))), std::optional<double>>);
std::optional<double> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(*res, 3);
}

struct TrackingObject {
static int moves;
static int copies;

explicit TrackingObject(int val) : val(val) {}
TrackingObject(const TrackingObject& other) : val(other.val) {
++copies;
}
TrackingObject(TrackingObject&& other) : val(other.val) {
++moves;
other.was_moved = true;
}
TrackingObject& operator=(const TrackingObject&) = delete;
TrackingObject& operator=(TrackingObject&&) = delete;

int val;
bool was_moved = false;
};
int TrackingObject::moves = 0;
int TrackingObject::copies = 0;

TEST_F(CreateTest, CreateObjectNotCopied) {
auto snd = [this](int a, int b) {
return create<TrackingObject>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, TrackingObject>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(TrackingObject{result});
});
});
}(1, 2);

TrackingObject::copies = 0;

std::optional<TrackingObject> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(res->val, 3);
EXPECT_EQ(TrackingObject::copies, 0);
}

TEST_F(CreateTest, CreateObjectCopied) {
auto snd = [this](int a, int b) {
return create<TrackingObject>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, TrackingObject>);
anIntAPI(a, b, &rec, [](void* context, int result) {
TrackingObject obj{result};
unifex::void_cast<decltype(rec)>(context).set_value(obj);
});
});
}(1, 2);

TrackingObject::copies = 0;

std::optional<TrackingObject> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(res->val, 3);
EXPECT_EQ(TrackingObject::copies, 1);
}

TEST_F(CreateTest, CreateObjectLeadsToNewObject) {
auto snd = [this](int a, int b) {
return create<TrackingObject>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, TrackingObject>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(TrackingObject{result});
});
});
}(1, 2) | then([](TrackingObject&& obj) {
return obj.val;
});

TrackingObject::copies = 0;
TrackingObject::moves = 0;

std::optional<int> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(*res, 3);
EXPECT_EQ(TrackingObject::copies, 0);
EXPECT_GE(TrackingObject::moves, 1);
}

TEST_F(CreateTest, CreateWithConditionalMove) {
TrackingObject obj{0};

struct Data {
void* context;
TrackingObject* obj;
};
Data data{nullptr, &obj};

auto snd = [this, &data](int a, int b) {
return create<TrackingObject&&>([a, b, &data, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, TrackingObject&&>);
data.context = &rec;
anIntAPI(a, b, &data, [](void* context, int result) {
Data& data = unifex::void_cast<Data&>(context);
data.obj->val = result;
unifex::void_cast<decltype(rec)>(data.context).set_value(std::move(*data.obj));
});
});
}(1, 2) | then([](TrackingObject&& obj) {
return obj.val;
});

TrackingObject::copies = 0;
TrackingObject::moves = 0;

std::optional<int> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(*res, 3);
EXPECT_EQ(TrackingObject::copies, 0);
EXPECT_EQ(TrackingObject::moves, 0);
EXPECT_FALSE(obj.was_moved);
}

TEST_F(CreateTest, CreateWithConversions) {
struct A {
int val;
};
struct B {
B(A a) : val(a.val) {}
B(int val) : val(val) {}
operator A() const {
return A{val};
}
int val;
};

{
auto snd = [this](int a, int b) {
return create<A>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, A>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(B{result});
});
});
}(1, 2);

std::optional<A> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(res->val, 3);
}

{
auto snd = [this](int a, int b) {
return create<B>([a, b, this](auto& rec) {
static_assert(receiver_of<std::decay_t<decltype(rec)>, int>);
anIntAPI(a, b, &rec, [](void* context, int result) {
unifex::void_cast<decltype(rec)>(context).set_value(A{result});
});
});
}(1, 2);

std::optional<B> res = sync_wait(std::move(snd));
ASSERT_TRUE(res.has_value());
EXPECT_EQ(res->val, 3);
}
}

TEST_F(CreateTest, VoidWithContextTest) {
Expand Down
Loading

0 comments on commit 0df7f0b

Please sign in to comment.