Optimize estimate clusters implementation

Signed-off-by: Nahuel Espinosa <[email protected]>
nahueespinosa committed Jul 9, 2024
1 parent b3dd2c4 commit 054dd02
Showing 1 changed file with 41 additions and 29 deletions.
70 changes: 41 additions & 29 deletions beluga/include/beluga/algorithm/cluster_based_estimation.hpp
@@ -25,6 +25,7 @@
 
 // external
 #include <range/v3/algorithm/max_element.hpp>
+#include <range/v3/algorithm/sort.hpp>
 #include <range/v3/numeric/accumulate.hpp>
 #include <range/v3/range/conversion.hpp>
 #include <range/v3/view/cache1.hpp>
@@ -39,6 +40,12 @@
 #include <beluga/algorithm/spatial_hash.hpp>
 #include <beluga/views.hpp>
 
+#if RANGE_V3_MAJOR == 0 && RANGE_V3_MINOR < 12
+#include <range/v3/view/group_by.hpp>
+#else
+#include <range/v3/view/chunk_by.hpp>
+#endif
+
 /**
  * \file
  * \brief Implementation of a cluster-based estimation algorithm.
@@ -343,43 +350,48 @@ template <class States, class Weights, class Clusters>
     Cluster cluster;
 
     /// Convenient factory method to pass to `zip_with`.
-    static constexpr auto construct(const State& s, Weight w, Cluster c) { return Particle{s, w, c}; }
+    static constexpr auto create(const State& s, Weight w, Cluster c) { return Particle{s, w, c}; }
   };
 
   struct Estimate {
     Weight weight;
     EstimateState mean;
     EstimateCovariance covariance;
-
-    /// Constructor used with `emplace_back`.
-    constexpr Estimate(Weight w, EstimateState m, EstimateCovariance c)
-        : weight{w}, mean{std::move(m)}, covariance{std::move(c)} {}
   };
 
-  // For each cluster found, estimate the mean and covariance of the states that belong to it.
-  const auto unique_clusters = clusters | ranges::to<std::unordered_set>;
-  auto estimates = std::vector<Estimate>{};
-  estimates.reserve(unique_clusters.size());
-
-  for (const auto cluster : unique_clusters) {
-    auto particles = ranges::views::zip_with(&Particle::construct, states, weights, clusters) |  //
-                     ranges::views::cache1 |                                                     //
-                     ranges::views::filter([cluster](const auto& p) { return p.cluster == cluster; }) |  //
-                     ranges::to<std::vector>;
-
-    if (particles.size() < 2) {
-      // If there's only one sample in the cluster we can't estimate the covariance.
-      continue;
-    }
-
-    auto filtered_states = particles | ranges::views::transform(&Particle::state);
-    auto filtered_weights = particles | ranges::views::transform(&Particle::weight);
-    const auto [mean, covariance] = estimate(filtered_states, filtered_weights);
-    const auto total_weight = ranges::accumulate(filtered_weights, 0.0);
-    estimates.emplace_back(total_weight, std::move(mean), std::move(covariance));
-  }
-
-  return estimates;
+  auto particles = ranges::views::zip_with(&Particle::create, states, weights, clusters) |  //
+                   ranges::to<std::vector>;
+
+  ranges::sort(particles, std::less{}, &Particle::cluster);
+
+  // For each cluster, estimate the mean and covariance of the states that belong to it.
+  return particles |
+#if RANGE_V3_MAJOR == 0 && RANGE_V3_MINOR < 12
+         // Compatibility support for old Range-v3 versions that don't have a `chunk_by` view.
+         // The difference between the deprecated `group_by` and the standard `chunk_by` is:
+         // - group_by: The predicate is evaluated between the first element in the group and the current one.
+         // - chunk_by: The predicate is evaluated between adjacent elements.
+         //
+         // See also https://github.com/ericniebler/range-v3/issues/1637
+         //
+         // For this specific application, we can use them interchangeably.
+         ranges::views::group_by([](const auto& p1, const auto& p2) { return p1.cluster == p2.cluster; }) |  //
+#else
+         ranges::views::chunk_by([](const auto& p1, const auto& p2) { return p1.cluster == p2.cluster; }) |  //
+#endif
+         ranges::views::cache1 |  //
+         ranges::views::filter([](auto subrange) {
+           // If there's only one sample in the cluster we can't estimate the covariance.
+           return subrange.size() > 1;
+         }) |
+         ranges::views::transform([](auto subrange) {
+           auto states = subrange | ranges::views::transform(&Particle::state);
+           auto weights = subrange | ranges::views::transform(&Particle::weight);
+           const auto [mean, covariance] = estimate(states, weights);
+           const auto total_weight = ranges::accumulate(weights, 0.0);
+           return Estimate{total_weight, std::move(mean), std::move(covariance)};
+         }) |
+         ranges::to<std::vector>;
 }
 
 /// Computes a cluster-based estimate from a particle set.
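
The change above replaces a per-cluster filtering loop with a single sort by cluster id followed by one grouping pass, so the particle range is traversed once instead of once per cluster. Below is a minimal standalone sketch of that sort-then-chunk pattern, for reference only; it is not part of the commit. The `Sample` struct, the sample values, and the printed totals are illustrative assumptions, and the sketch assumes range-v3 0.12 or newer for `ranges::views::chunk_by`.

#include <functional>
#include <iostream>
#include <vector>

#include <range/v3/algorithm/sort.hpp>
#include <range/v3/numeric/accumulate.hpp>
#include <range/v3/range/conversion.hpp>
#include <range/v3/view/chunk_by.hpp>
#include <range/v3/view/transform.hpp>

struct Sample {
  int cluster;
  double weight;
};

int main() {
  auto samples = std::vector<Sample>{{2, 0.1}, {1, 0.4}, {2, 0.2}, {1, 0.3}};

  // Sorting by cluster id makes samples from the same cluster adjacent,
  // so a single `chunk_by` pass yields one subrange per cluster.
  ranges::sort(samples, std::less{}, &Sample::cluster);

  auto totals =
      samples |
      ranges::views::chunk_by([](const auto& a, const auto& b) { return a.cluster == b.cluster; }) |
      ranges::views::transform([](auto subrange) {
        // Accumulate the weights of one cluster's samples.
        auto weights = subrange | ranges::views::transform(&Sample::weight);
        return ranges::accumulate(weights, 0.0);
      }) |
      ranges::to<std::vector>;

  for (const double total : totals) {
    std::cout << total << '\n';  // per-cluster weight sums: 0.7, then 0.3
  }
}

On range-v3 versions older than 0.12, the `chunk_by` line can be swapped for `ranges::views::group_by` with the same predicate; for an equality predicate over a key that is already sorted, both produce the same groups, which is why the commit can dispatch between them with a preprocessor check.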
