diff --git a/include/unifex/create.hpp b/include/unifex/create.hpp index c9f888b8d..d8ba1f330 100644 --- a/include/unifex/create.hpp +++ b/include/unifex/create.hpp @@ -24,7 +24,7 @@ namespace unifex { namespace _create { -template +template struct _op { struct type { explicit type(Receiver rec, Fn fn, Context ctx) @@ -34,9 +34,17 @@ struct _op { : rec_((Receiver&&) rec), fn_((Fn&&) fn), ctx_((Context&&) ctx) {} template (typename... Ts) - (requires receiver_of) - void set_value(Ts&&... ts) noexcept(is_nothrow_receiver_of_v) { - unifex::set_value((Receiver&&) rec_, (Ts&&) ts...); + (requires (convertible_to && ...)) + void set_value(Ts&&... ts) noexcept { + UNIFEX_TRY { + // Satisfy the value completion contract by converting to the + // Sender's value_types. For example, if set_value is called with + // an lvalue reference but the create Sender sends non-reference + // values. + unifex::set_value(std::move(rec_), static_cast(static_cast(ts))...); + } UNIFEX_CATCH(...) { + unifex::set_error(std::move(rec_), std::current_exception()); + } } template (typename Error) @@ -79,10 +87,10 @@ struct _op { }; }; -template -using _operation = typename _op::type; +template +using _operation = typename _op::type; -template +template struct _snd_base { struct type { template class Variant> @@ -100,15 +108,16 @@ struct _snd_base { template (typename Self, typename Receiver) (requires derived_from, type> AND constructible_from> AND - constructible_from>) - friend _operation, Fn, Context> + constructible_from> AND + receiver_of) + friend _operation, Fn, Context, ValueTypes...> tag_invoke(tag_t, Self&& self, Receiver&& rec) noexcept(std::is_nothrow_constructible_v< - _operation, + _operation, Receiver, member_t, member_t>) { - return _operation, Fn, Context>{ + return _operation, Fn, Context, ValueTypes...>{ (Receiver&&) rec, ((Self&&) self).fn_, ((Self&&) self).ctx_}; @@ -121,7 +130,7 @@ struct _snd_base { template struct _snd { - struct type : _snd_base::type { + struct type : _snd_base::type { template class Variant, template class Tuple> using value_types = Variant>; }; diff --git a/include/unifex/finally.hpp b/include/unifex/finally.hpp index c3c012cf5..fee455a80 100644 --- a/include/unifex/finally.hpp +++ b/include/unifex/finally.hpp @@ -71,7 +71,7 @@ namespace unifex SourceSender, CompletionSender, Receiver, - std::decay_t...>::type; + Values...>::type; template < typename SourceSender, @@ -384,7 +384,7 @@ namespace unifex auto* const op = op_; UNIFEX_TRY { - unifex::activate_union_member...>>( + unifex::activate_union_member>( op->value_, static_cast(values)...); } UNIFEX_CATCH (...) { std::move(*this).set_error(std::current_exception()); @@ -411,8 +411,7 @@ namespace unifex }); unifex::start(completionOp); } UNIFEX_CATCH (...) { - using decayed_tuple_t = std::tuple...>; - unifex::deactivate_union_member(op->value_); + unifex::deactivate_union_member>(op->value_); unifex::set_error( static_cast(op->receiver_), std::current_exception()); } @@ -593,7 +592,7 @@ namespace unifex sender_value_types_t< remove_cvref_t, manual_lifetime_union, - decayed_tuple::template apply> + std::tuple> value_; }; @@ -647,7 +646,7 @@ namespace unifex template class Variant, template class Tuple> using value_types = typename sender_traits:: - template value_types::template apply>; + template value_types; // This can produce any of the error_types of SourceSender, or of // CompletionSender or an exception_ptr corresponding to an exception thrown diff --git a/test/create_test.cpp b/test/create_test.cpp index b1ae2d259..1f815f1ff 100644 --- a/test/create_test.cpp +++ b/test/create_test.cpp @@ -15,8 +15,10 @@ */ #include -#include #include +#include +#include +#include #include #include @@ -30,6 +32,9 @@ using namespace unifex; namespace { + +int global; + struct CreateTest : testing::Test { unifex::single_thread_context someThread; unifex::async_scope someScope; @@ -47,6 +52,14 @@ struct CreateTest : testing::Test { }); } + void anIntRefAPI(void* context, void (*completed)(void* context, int& result)) { + // Execute some work asynchronously on some other thread. When its + // work is finished, pass the result to the callback. + someScope.detached_spawn_call_on(someThread.get_scheduler(), [=]() noexcept { + completed(context, global); + }); + } + void aVoidAPI(void* context, void (*completed)(void* context)) { // Execute some work asynchronously on some other thread. When its // work is finished, pass the result to the callback. @@ -58,18 +71,225 @@ struct CreateTest : testing::Test { } // anonymous namespace TEST_F(CreateTest, BasicTest) { + { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of); + static_assert(!receiver_of); + anIntAPI(a, b, &rec, [](void* context, int result) { + unifex::void_cast(context).set_value(result); + }); + }); + }(1, 2); + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(*res, 3); + } + + { + auto snd = [this]() { + return create([this](auto& rec) { + static_assert(receiver_of); + anIntRefAPI(&rec, [](void* context, int& result) { + unifex::void_cast(context).set_value(result); + }); + }); + }(); + + std::optional> res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(&res->get(), &global); + } +} + +TEST_F(CreateTest, FinallyCreate) { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of, int>); + anIntAPI(a, b, &rec, [](void* context, int result) { + unifex::void_cast(context).set_value(result); + }); + }); + }(1, 2) | finally(just()); + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(*res, 3); +} + +TEST_F(CreateTest, DoubleCreateSetsIntValue) { auto snd = [this](int a, int b) { - return create([a, b, this](auto& rec) { - static_assert(receiver_of); + return create([a, b, this](auto& rec) { + static_assert(receiver_of, int>); anIntAPI(a, b, &rec, [](void* context, int result) { unifex::void_cast(context).set_value(result); }); }); }(1, 2); + static_assert(std::is_same_v>); + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(*res, 3); +} + +struct TrackingObject { + static int moves; + static int copies; + + explicit TrackingObject(int val) : val(val) {} + TrackingObject(const TrackingObject& other) : val(other.val) { + ++copies; + } + TrackingObject(TrackingObject&& other) : val(other.val) { + ++moves; + other.was_moved = true; + } + TrackingObject& operator=(const TrackingObject&) = delete; + TrackingObject& operator=(TrackingObject&&) = delete; + + int val; + bool was_moved = false; +}; +int TrackingObject::moves = 0; +int TrackingObject::copies = 0; + +TEST_F(CreateTest, CreateObjectNotCopied) { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of, TrackingObject>); + anIntAPI(a, b, &rec, [](void* context, int result) { + unifex::void_cast(context).set_value(TrackingObject{result}); + }); + }); + }(1, 2); + + TrackingObject::copies = 0; + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(res->val, 3); + EXPECT_EQ(TrackingObject::copies, 0); +} + +TEST_F(CreateTest, CreateObjectCopied) { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of, TrackingObject>); + anIntAPI(a, b, &rec, [](void* context, int result) { + TrackingObject obj{result}; + unifex::void_cast(context).set_value(obj); + }); + }); + }(1, 2); + + TrackingObject::copies = 0; + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(res->val, 3); + EXPECT_EQ(TrackingObject::copies, 1); +} + +TEST_F(CreateTest, CreateObjectLeadsToNewObject) { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of, TrackingObject>); + anIntAPI(a, b, &rec, [](void* context, int result) { + unifex::void_cast(context).set_value(TrackingObject{result}); + }); + }); + }(1, 2) | then([](TrackingObject&& obj) { + return obj.val; + }); + + TrackingObject::copies = 0; + TrackingObject::moves = 0; + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(*res, 3); + EXPECT_EQ(TrackingObject::copies, 0); + EXPECT_GE(TrackingObject::moves, 1); +} + +TEST_F(CreateTest, CreateWithConditionalMove) { + TrackingObject obj{0}; + + struct Data { + void* context; + TrackingObject* obj; + }; + Data data{nullptr, &obj}; + + auto snd = [this, &data](int a, int b) { + return create([a, b, &data, this](auto& rec) { + static_assert(receiver_of, TrackingObject&&>); + data.context = &rec; + anIntAPI(a, b, &data, [](void* context, int result) { + Data& data = unifex::void_cast(context); + data.obj->val = result; + unifex::void_cast(data.context).set_value(std::move(*data.obj)); + }); + }); + }(1, 2) | then([](TrackingObject&& obj) { + return obj.val; + }); + + TrackingObject::copies = 0; + TrackingObject::moves = 0; + std::optional res = sync_wait(std::move(snd)); ASSERT_TRUE(res.has_value()); EXPECT_EQ(*res, 3); + EXPECT_EQ(TrackingObject::copies, 0); + EXPECT_EQ(TrackingObject::moves, 0); + EXPECT_FALSE(obj.was_moved); +} + +TEST_F(CreateTest, CreateWithConversions) { + struct A { + int val; + }; + struct B { + B(A a) : val(a.val) {} + B(int val) : val(val) {} + operator A() const { + return A{val}; + } + int val; + }; + + { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of, A>); + anIntAPI(a, b, &rec, [](void* context, int result) { + unifex::void_cast(context).set_value(B{result}); + }); + }); + }(1, 2); + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(res->val, 3); + } + + { + auto snd = [this](int a, int b) { + return create([a, b, this](auto& rec) { + static_assert(receiver_of, int>); + anIntAPI(a, b, &rec, [](void* context, int result) { + unifex::void_cast(context).set_value(A{result}); + }); + }); + }(1, 2); + + std::optional res = sync_wait(std::move(snd)); + ASSERT_TRUE(res.has_value()); + EXPECT_EQ(res->val, 3); + } } TEST_F(CreateTest, VoidWithContextTest) { diff --git a/test/finally_test.cpp b/test/finally_test.cpp index 338836b8d..2977caac1 100644 --- a/test/finally_test.cpp +++ b/test/finally_test.cpp @@ -14,19 +14,24 @@ * limitations under the License. */ #include + #include -#include -#include -#include -#include #include #include #include +#include #include #include +#include +#include +#include +#include +#include #include #include +#include +#include #include @@ -45,6 +50,103 @@ TEST(Finally, Value) { EXPECT_EQ(res->second, context.get_thread_id()); } +TEST(Finally, Ref) { + { + int a = 0; + + auto sndr = just_from([&a]() -> int& { return a; }) + | finally(just()); + using Sndr = decltype(sndr); + + static_assert(std::is_same_v< + sender_value_types_t, + std::variant> + >); + static_assert(std::is_same_v< + sender_error_types_t, + std::variant + >); + static_assert(!sender_traits::sends_done); + + auto res = std::move(sndr) | sync_wait(); + + ASSERT_FALSE(!res); + EXPECT_EQ(&res->get(), &a); + } + + { + int a = 0; + + auto res = just_from([&a]() -> const int& { return a; }) + | finally(just()) + | sync_wait(); + + ASSERT_FALSE(!res); + EXPECT_EQ(&res->get(), &a); + } + + { + int a = 0; + + auto res = just_from([&a]() -> int& { return a; }) + | finally(just()) + | then([](int& i) -> int& { return i; }) + | sync_wait(); + + ASSERT_FALSE(!res); + EXPECT_EQ(&res->get(), &a); + } +} + +struct sends_error_ref { + + template < + template class Variant, + template class Tuple> + using value_types = Variant>; + + template