Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add propagate and reweight actions #289

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions beluga/include/beluga/actions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#define BELUGA_ACTIONS_HPP

#include <beluga/actions/assign.hpp>
#include <beluga/actions/propagate.hpp>
#include <beluga/actions/reweight.hpp>

/**
* \file
Expand Down
113 changes: 113 additions & 0 deletions beluga/include/beluga/actions/propagate.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_ACTIONS_PROPAGATE_HPP
#define BELUGA_ACTIONS_PROPAGATE_HPP

#include <algorithm>
#include <execution>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/views/particles.hpp>

#include <range/v3/action/action.hpp>
#include <range/v3/view/common.hpp>

namespace beluga::actions {

namespace detail {

/// Implementation detail for a propagate range adaptor object.
struct propagate_base_fn {
/// Overload that implements the propagate algorithm.
/**
* \tparam ExecutionPolicy An [execution policy](https://en.cppreference.com/w/cpp/algorithm/execution_policy_tag_t).
* \tparam Range An [input range](https://en.cppreference.com/w/cpp/ranges/input_range) of particles.
* \tparam Model A callable that can compute the new state from the previous state.
* \param policy The execution policy to use.
* \param range An existing range of particles to apply this action to.
* \param model A callable instance to compute the states from the previous states.
*/
template <
class ExecutionPolicy,
class Range,
class Model,
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Model model) const -> Range& {
static_assert(beluga::is_particle_range_v<Range>);
auto states = range | beluga::views::states | ranges::views::common;
std::transform(
policy, // rvalue policies are not supported in some STL implementations
std::begin(states), //
std::end(states), //
std::begin(states), //
std::move(model));
return range;
}

/// Overload that re-orders arguments from a view closure.
template <
class Range,
class Model,
class ExecutionPolicy,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(Range&& range, Model model, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), std::move(model));
}

/// Overload that returns a view closure to compose with other views.
template <
class ExecutionPolicy, //
class Model, //
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(ExecutionPolicy policy, Model model) const {
return ranges::make_action_closure(ranges::bind_back(propagate_base_fn{}, std::move(model), std::move(policy)));
}
};

/// Implementation detail for a propagate range adaptor object with a default execution policy.
struct propagate_fn : public propagate_base_fn {
using propagate_base_fn::operator();

/// Overload that defines a default execution policy.
template <
class Range, //
class Model, //
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range, Model model) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), std::move(model));
}

/// Overload that returns a view closure to compose with other views.
template <class Model>
constexpr auto operator()(Model model) const {
return ranges::make_action_closure(ranges::bind_back(propagate_fn{}, std::move(model)));
}
};

} // namespace detail

/// [Range adaptor object](https://en.cppreference.com/w/cpp/named_req/RangeAdaptorObject) that
/// can update the state in a particle range using a motion model.
/**
* This action updates particle states based on their current value and a state-transition function (or motion model).
* Every other particle attribute (such as importance sampling weights) is left unchanged.
*/
inline constexpr detail::propagate_fn propagate;

} // namespace beluga::actions

#endif
114 changes: 114 additions & 0 deletions beluga/include/beluga/actions/reweight.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_ACTIONS_REWEIGHT_HPP
#define BELUGA_ACTIONS_REWEIGHT_HPP

#include <algorithm>
#include <execution>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/views/particles.hpp>

#include <range/v3/action/action.hpp>
#include <range/v3/algorithm/max_element.hpp>
#include <range/v3/view/common.hpp>

namespace beluga::actions {

namespace detail {

/// Implementation detail for a reweight range adaptor object.
struct reweight_base_fn {
/// Overload that implements the reweight algorithm.
/**
* \tparam ExecutionPolicy An [execution policy](https://en.cppreference.com/w/cpp/algorithm/execution_policy_tag_t).
* \tparam Range An [input range](https://en.cppreference.com/w/cpp/ranges/input_range) of particles.
* \tparam Model A callable that can compute the importance weight given a particle state.
* \param policy The execution policy to use.
* \param range An existing range of particles to apply this action to.
* \param model A callable instance to compute the weights given the particle states.
*
* For each particle, we multiply the current weight by the new importance weight to accumulate information from
* sensor updates.
*/
template <
class ExecutionPolicy,
class Range,
class Model,
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Model model) const -> Range& {
static_assert(beluga::is_particle_range_v<Range>);
auto states = range | beluga::views::states | ranges::views::common;
auto weights = range | beluga::views::weights | ranges::views::common;
std::transform(
policy, //
std::begin(states), //
std::end(states), //
std::begin(weights), //
std::begin(weights), //
[model = std::move(model)](const auto& s, auto w) { return w * model(s); });
return range;
}

/// Overload that re-orders arguments from a view closure.
template <
class Range,
class Model,
class ExecutionPolicy,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(Range&& range, Model model, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), std::move(model));
}

/// Overload that returns a view closure to compose with other views.
template <class ExecutionPolicy, class Model, std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(ExecutionPolicy policy, Model model) const {
return ranges::make_action_closure(ranges::bind_back(reweight_base_fn{}, std::move(model), std::move(policy)));
}
};

/// Implementation detail for a reweight range adaptor object with a default execution policy.
struct reweight_fn : public reweight_base_fn {
using reweight_base_fn::operator();

/// Overload that defines a default execution policy.
template <class Range, class Model, std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range, Model model) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), std::move(model));
}

/// Overload that returns a view closure to compose with other views.
template <class Model>
constexpr auto operator()(Model model) const {
return ranges::make_action_closure(ranges::bind_back(reweight_fn{}, std::move(model)));
}
};

} // namespace detail

/// [Range adaptor object](https://en.cppreference.com/w/cpp/named_req/RangeAdaptorObject) that
/// can update the weights in a particle range using a sensor model.
/**
* This action updates particle weights by importance weight multiplication.
* These importance weights are computed by a given measurement likelihood
* function (or sensor model) for current particle states.
*/
inline constexpr detail::reweight_fn reweight;

} // namespace beluga::actions

#endif
2 changes: 2 additions & 0 deletions beluga/test/beluga/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
add_executable(
test_beluga
actions/test_assign.cpp
actions/test_propagate.cpp
actions/test_reweight.cpp
algorithm/raycasting/test_bresenham.cpp
algorithm/test_distance_map.cpp
algorithm/test_estimation.cpp
Expand Down
66 changes: 66 additions & 0 deletions beluga/test/beluga/actions/test_propagate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gmock/gmock.h>

#include <beluga/actions/assign.hpp>
#include <beluga/actions/propagate.hpp>
#include <beluga/views/sample.hpp>

#include <range/v3/algorithm/equal.hpp>
#include <range/v3/view/take_exactly.hpp>

namespace {

TEST(PropagateAction, DefaultExecutionPolicy) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
input |= beluga::actions::propagate([](int value) { return ++value; });
ASSERT_EQ(input.front(), std::make_tuple(6, 1.0));
}

TEST(PropagateAction, SequencedExecutionPolicy) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
input |= beluga::actions::propagate(std::execution::seq, [](int value) { return ++value; });
ASSERT_EQ(input.front(), std::make_tuple(6, 1.0));
}

TEST(PropagateAction, ParallelExecutionPolicy) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
input |= beluga::actions::propagate(std::execution::par, [](int value) { return ++value; });
ASSERT_EQ(input.front(), std::make_tuple(6, 1.0));
}

TEST(PropagateAction, Composition) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
input |= beluga::actions::propagate([](int value) { return --value; }) | //
beluga::views::sample | //
ranges::views::take_exactly(5) | //
beluga::actions::assign;
auto states = input | beluga::views::states | ranges::to<std::vector>;
ASSERT_THAT(states, testing::ElementsAre(4, 4, 4, 4, 4));
}

TEST(PropagateAction, StatefulModel) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(1.0))};
auto model = [value = 0](int) mutable { return value++; };
input |= beluga::views::sample | //
ranges::views::take_exactly(5) | //
beluga::actions::assign | //
beluga::actions::propagate(std::ref(model));
auto states = input | beluga::views::states | ranges::to<std::vector>;
ASSERT_THAT(states, testing::ElementsAre(0, 1, 2, 3, 4));
ASSERT_EQ(model(0), 5); // the model was passed by reference
}

} // namespace
63 changes: 63 additions & 0 deletions beluga/test/beluga/actions/test_reweight.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gmock/gmock.h>

#include <beluga/actions/assign.hpp>
#include <beluga/actions/reweight.hpp>

#include <range/v3/algorithm/equal.hpp>
#include <range/v3/view/intersperse.hpp>

namespace {

TEST(ReweightAction, DefaultExecutionPolicy) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(2.0))};
input |= beluga::actions::reweight([](int value) { return value; });
ASSERT_EQ(input.front(), std::make_tuple(5, 10.0));
}

TEST(ReweightAction, SequencedExecutionPolicy) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(2.0))};
input |= beluga::actions::reweight(std::execution::seq, [](int value) { return value; });
ASSERT_EQ(input.front(), std::make_tuple(5, 10.0));
}

TEST(ReweightAction, ParallelExecutionPolicy) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(2.0))};
input |= beluga::actions::reweight(std::execution::par, [](int value) { return value; });
ASSERT_EQ(input.front(), std::make_tuple(5, 10.0));
}

TEST(ReweightAction, Composition) {
auto input = std::vector{std::make_tuple(4, beluga::Weight(0.5)), std::make_tuple(4, beluga::Weight(1.0))};
input |= beluga::actions::reweight([](int value) { return value; }) | //
ranges::views::intersperse(std::make_tuple(5, beluga::Weight(1.0))) | //
beluga::actions::assign;
auto weights = input | beluga::views::weights | ranges::to<std::vector>;
ASSERT_THAT(weights, testing::ElementsAre(2, 1, 4));
}

TEST(ReweightAction, StatefulModel) {
auto input = std::vector{std::make_tuple(4, beluga::Weight(0.5)), std::make_tuple(4, beluga::Weight(1.0))};
auto model = [value = 0](int) mutable { return value++; };
input |= ranges::views::intersperse(std::make_tuple(5, beluga::Weight(1.0))) | //
beluga::actions::assign | //
beluga::actions::reweight(std::ref(model));
auto weights = input | beluga::views::weights | ranges::to<std::vector>;
ASSERT_THAT(weights, testing::ElementsAre(0, 1, 2));
ASSERT_EQ(model(0), 3); // the model was passed by reference
}

} // namespace
Loading