Skip to content

Commit

Permalink
Add maybe_wrap utility
Browse files Browse the repository at this point in the history
Signed-off-by: Nahuel Espinosa <[email protected]>
  • Loading branch information
nahueespinosa committed Jan 15, 2024
1 parent 089de63 commit ec3c4db
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 20 deletions.
22 changes: 12 additions & 10 deletions beluga/include/beluga/actions/propagate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <execution>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/utility/maybe_wrap.hpp>
#include <beluga/views/particles.hpp>

#include <range/v3/action/action.hpp>
Expand All @@ -45,15 +46,15 @@ struct propagate_base_fn {
class Model,
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Model model) const -> Range& {
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Model&& model) const -> Range& {
static_assert(beluga::is_particle_range_v<Range>);
auto states = range | beluga::views::states | ranges::views::common;
std::transform(
policy, // rvalue policies are not supported in some STL implementations
std::begin(states), //
std::end(states), //
std::begin(states), //
std::move(model));
std::forward<Model>(model));
return range;
}

Expand All @@ -64,17 +65,18 @@ struct propagate_base_fn {
class ExecutionPolicy,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(Range&& range, Model model, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), std::move(model));
constexpr auto operator()(Range&& range, Model&& model, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), std::forward<Model>(model));
}

/// Overload that returns a view closure to compose with other views.
template <
class ExecutionPolicy, //
class Model, //
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(ExecutionPolicy policy, Model model) const {
return ranges::make_action_closure(ranges::bind_back(propagate_base_fn{}, std::move(model), std::move(policy)));
constexpr auto operator()(ExecutionPolicy policy, Model&& model) const {
return ranges::make_action_closure(
ranges::bind_back(propagate_base_fn{}, maybe_wrap<Model>(model), std::move(policy)));
}
};

Expand All @@ -87,14 +89,14 @@ struct propagate_fn : public propagate_base_fn {
class Range, //
class Model, //
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range, Model model) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), std::move(model));
constexpr auto operator()(Range&& range, Model&& model) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), std::forward<Model>(model));
}

/// Overload that returns a view closure to compose with other views.
template <class Model>
constexpr auto operator()(Model model) const {
return ranges::make_action_closure(ranges::bind_back(propagate_fn{}, std::move(model)));
constexpr auto operator()(Model&& model) const {
return ranges::make_action_closure(ranges::bind_back(propagate_fn{}, maybe_wrap<Model>(model)));
}
};

Expand Down
22 changes: 12 additions & 10 deletions beluga/include/beluga/actions/reweight.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <execution>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/utility/maybe_wrap.hpp>
#include <beluga/views/particles.hpp>

#include <range/v3/action/action.hpp>
Expand Down Expand Up @@ -49,7 +50,7 @@ struct reweight_base_fn {
class Model,
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Model model) const -> Range& {
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Model&& model) const -> Range& {
static_assert(beluga::is_particle_range_v<Range>);
auto states = range | beluga::views::states | ranges::views::common;
auto weights = range | beluga::views::weights | ranges::views::common;
Expand All @@ -59,7 +60,7 @@ struct reweight_base_fn {
std::end(states), //
std::begin(weights), //
std::begin(weights), //
[model = std::move(model)](const auto& s, auto w) { return w * model(s); });
[model = std::forward<Model>(model)](const auto& s, auto w) { return w * model(s); });
return range;
}

Expand All @@ -70,14 +71,15 @@ struct reweight_base_fn {
class ExecutionPolicy,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(Range&& range, Model model, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), std::move(model));
constexpr auto operator()(Range&& range, Model&& model, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), std::forward<Model>(model));
}

/// Overload that returns a view closure to compose with other views.
template <class ExecutionPolicy, class Model, std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(ExecutionPolicy policy, Model model) const {
return ranges::make_action_closure(ranges::bind_back(reweight_base_fn{}, std::move(model), std::move(policy)));
constexpr auto operator()(ExecutionPolicy policy, Model&& model) const {
return ranges::make_action_closure(
ranges::bind_back(reweight_base_fn{}, maybe_wrap<Model>(model), std::move(policy)));
}
};

Expand All @@ -87,14 +89,14 @@ struct reweight_fn : public reweight_base_fn {

/// Overload that defines a default execution policy.
template <class Range, class Model, std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range, Model model) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), std::move(model));
constexpr auto operator()(Range&& range, Model&& model) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), std::forward<Model>(model));
}

/// Overload that returns a view closure to compose with other views.
template <class Model>
constexpr auto operator()(Model model) const {
return ranges::make_action_closure(ranges::bind_back(reweight_fn{}, std::move(model)));
constexpr auto operator()(Model&& model) const {
return ranges::make_action_closure(ranges::bind_back(reweight_fn{}, maybe_wrap<Model>(model)));
}
};

Expand Down
44 changes: 44 additions & 0 deletions beluga/include/beluga/utility/maybe_wrap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_UTILITY_MAYBE_WRAP_HPP
#define BELUGA_UTILITY_MAYBE_WRAP_HPP

namespace beluga {

/// Conditionally wraps an lvalue-reference into a std::reference_wrapper.
/**
* This is useful to pass arguments using std::bind, std::bind_back, etc.
* Note that function-like objects don't need to be unwrapped to be called,
* as std::reference_wrapper implements the call operator.
*/
template <class T>
constexpr auto maybe_wrap(std::remove_reference_t<T>& t) noexcept {
if constexpr (std::is_const_v<std::remove_reference_t<T>>) {
return std::cref(t);
} else {
return std::ref(t);
}
}

/// Overload for rvalues. Simply forward.
template <class T>
constexpr auto&& maybe_wrap(std::remove_reference_t<T>&& t) noexcept {
static_assert(!std::is_lvalue_reference_v<T>);
return static_cast<T&&>(t);
}

} // namespace beluga

#endif
11 changes: 11 additions & 0 deletions beluga/test/beluga/actions/test_propagate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,15 @@ TEST(PropagateAction, Composition) {
ASSERT_TRUE(ranges::equal(beluga::views::states(input), std::vector{4, 4, 4, 4, 4}));
}

TEST(PropagateAction, StatefulModel) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
auto model = [value = 0](int) mutable { return value++; };
input |= beluga::views::sample | //
ranges::views::take_exactly(5) | //
beluga::actions::assign | //
beluga::actions::propagate(model);
ASSERT_TRUE(ranges::equal(beluga::views::states(input), std::vector{0, 1, 2, 3, 4}));
ASSERT_EQ(model(0), 5); // the model was passed by reference
}

} // namespace
11 changes: 11 additions & 0 deletions beluga/test/beluga/actions/test_reweight.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,15 @@ TEST(ReweightAction, Composition) {
ASSERT_TRUE(ranges::equal(beluga::views::weights(input), std::vector<beluga::Weight>{2, 2, 2, 2, 2}));
}

TEST(ReweightAction, StatefulModel) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
auto model = [value = 0](int) mutable { return value++; };
input |= beluga::views::sample | //
ranges::views::take_exactly(5) | //
beluga::actions::assign | //
beluga::actions::reweight(model);
ASSERT_TRUE(ranges::equal(beluga::views::weights(input), std::vector<beluga::Weight>{0, 1, 2, 3, 4}));
ASSERT_EQ(model(0), 5); // the model was passed by reference
}

} // namespace

0 comments on commit ec3c4db

Please sign in to comment.