Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
glpuga committed Jan 2, 2024
1 parent 0370372 commit d2e4b84
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 69 deletions.
131 changes: 65 additions & 66 deletions beluga/include/beluga/estimation/cluster_based_estimator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

// external
#include <range/v3/action/sort.hpp>
#include <range/v3/algorithm/max_element.hpp>
#include <range/v3/numeric/accumulate.hpp>
#include <range/v3/range/conversion.hpp>
#include <range/v3/view/cache1.hpp>
#include <range/v3/view/filter.hpp>
#include <range/v3/view/zip.hpp>
#include <sophus/se2.hpp>
Expand All @@ -49,38 +51,38 @@ namespace beluga {

namespace cse_detail {

/// @brief A struct that holds the data of a single cell in the grid.
/// \brief A struct that holds the data of a single cell in the grid.
struct GridCellData {
double weight{0.0}; ///< average weight of the cell
std::size_t num_particles{0}; ///< number of particles in the cell
Sophus::SE2d representative_pose_in_world; ///< state of a particle that is within the cell
std::optional<std::size_t> cluster_id; ///< cluster id of the cell
};

/// @brief A map that holds the sparse data about the particles grouped in cells. Used by the clusterization algorithm.
/// \brief A map that holds the sparse data about the particles grouped in cells. Used by the clusterization algorithm.
using GridCellDataMap2D = std::unordered_map<std::size_t, GridCellData>;

/// @brief Function that creates an vector containing the hashes of each of the states in the input range.
/// @tparam Range Type of the states range.
/// @tparam Hasher Hash function type to convert states into hashes.
/// @param states The range of states.
/// @param spatial_hash_function_ The hash object instance.
/// @return A vector containing the hashes of each of the states in the input range.
/// \brief Function that creates an vector containing the hashes of each of the states in the input range.
/// \tparam Range Type of the states range.
/// \tparam Hasher Hash function type to convert states into hashes.
/// \param states The range of states.
/// \param spatial_hash_function_ The hash object instance.
/// \return A vector containing the hashes of each of the states in the input range.
template <class Range, class Hasher>
[[nodiscard]] auto precalculate_particle_hashes(Range&& states, const Hasher& spatial_hash_function_) {
const auto state_to_range = [&](const auto& state) { return spatial_hash_function_(state); };
return states | ranges::views::transform(state_to_range) | ranges::to<std::vector<std::size_t>>();
const auto state_to_hash = [&](const auto& state) { return spatial_hash_function_(state); };
return states | ranges::views::transform(state_to_hash) | ranges::to<std::vector<std::size_t>>();
}

/// @brief Populate the grid cell data map with the data from the particles and their weights.
/// @tparam GridCellDataMapType Type of the grid cell data map.
/// @tparam Range Type of the states range.
/// @tparam Weights Type of the weights range.
/// @tparam Hashes Type of the hashes range.
/// @param states Range containing the states of the particles.
/// @param weights Range containing the weights of the particles.
/// @param hashes Range containing the hashes of the particles.
/// @return New instance of the grid cell data map populated with the information from the states.
/// \brief Populate the grid cell data map with the data from the particles and their weights.
/// \tparam GridCellDataMapType Type of the grid cell data map.
/// \tparam Range Type of the states range.
/// \tparam Weights Type of the weights range.
/// \tparam Hashes Type of the hashes range.
/// \param states Range containing the states of the particles.
/// \param weights Range containing the weights of the particles.
/// \param hashes Range containing the hashes of the particles.
/// \return New instance of the grid cell data map populated with the information from the states.
template <class GridCellDataMapType, class Range, class Weights, class Hashes>
[[nodiscard]] auto populate_grid_cell_data_from_particles(Range&& states, Weights&& weights, const Hashes& hashes) {
GridCellDataMapType grid_cell_data;
Expand Down Expand Up @@ -108,11 +110,11 @@ template <class GridCellDataMapType, class Range, class Weights, class Hashes>
return grid_cell_data;
}

/// @brief Calculate the weight threshold that corresponds to a given percentile of the weights.
/// @tparam GridCellDataMapType Type of the grid cell data map.
/// @param grid_cell_data The grid cell data map.
/// @param threshold The percentile of the weights to calculate the threshold for (range: 0.0 to 1.0)
/// @return Threshold value that corresponds to the given percentile of the weights.
/// \brief Calculate the weight threshold that corresponds to a given percentile of the weights.
/// \tparam GridCellDataMapType Type of the grid cell data map.
/// \param grid_cell_data The grid cell data map.
/// \param threshold The percentile of the weights to calculate the threshold for (range: 0.0 to 1.0)
/// \return Threshold value that corresponds to the given percentile of the weights.
template <class GridCellDataMapType>
[[nodiscard]] auto calculate_percentile_weight_threshold(GridCellDataMapType&& grid_cell_data, double threshold) {
const auto extract_weight_f = [](const auto& grid_cell) { return grid_cell.second.weight; };
Expand All @@ -121,10 +123,10 @@ template <class GridCellDataMapType>
return weights[static_cast<std::size_t>(static_cast<double>(weights.size()) * threshold)];
}

/// @brief Cap the weight of each cell in the grid cell data map to a given value.
/// @tparam GridCellDataMapType Type of the grid cell data map.
/// @param grid_cell_data The grid cell data map.
/// @param weight_cap The maximum weight value to be assigned to each cell.
/// \brief Cap the weight of each cell in the grid cell data map to a given value.
/// \tparam GridCellDataMapType Type of the grid cell data map.
/// \param grid_cell_data The grid cell data map.
/// \param weight_cap The maximum weight value to be assigned to each cell.
template <class GridCellDataMapType>
void cap_grid_cell_data_weights(GridCellDataMapType&& grid_cell_data, double weight_cap) {
for (auto& [hash, entry] : grid_cell_data) {
Expand All @@ -133,10 +135,10 @@ void cap_grid_cell_data_weights(GridCellDataMapType&& grid_cell_data, double wei
}
}

/// @brief Creates the priority queue used by the clustering information from the grid cell data map.
/// @tparam GridCellDataMapType Type of the grid cell data map.
/// @param grid_cell_data The grid cell data map.
/// @return A priority queue containing the information from the grid cell data map.
/// \brief Creates the priority queue used by the clustering information from the grid cell data map.
/// \tparam GridCellDataMapType Type of the grid cell data map.
/// \param grid_cell_data The grid cell data map.
/// \return A priority queue containing the information from the grid cell data map.
template <class GridCellDataMapType>
[[nodiscard]] auto populate_priority_queue(GridCellDataMapType&& grid_cell_data) {
struct PriorityQueueItem {
Expand All @@ -161,14 +163,14 @@ template <class GridCellDataMapType>
PriorityQueueItemCompare{}, std::move(queue_container));
}

/// @brief Function that runs the clustering algorithm and assigns a cluster id to each cell in the grid cell data map.
/// @tparam GridCellDataMapType Type of the grid cell data map.
/// @tparam Hasher Type of the hash function used to convert states into hashes.
/// @tparam Neighbors Type of the range containing the neighbors of a cell.
/// @param grid_cell_data The grid cell data map.
/// @param spatial_hash_function_ The hash object instance.
/// @param neighbors Range containing the neighbors of a cell.
/// @param weight_cap The maximum weight value to be assigned to each cell.
/// \brief Function that runs the clustering algorithm and assigns a cluster id to each cell in the grid cell data map.
/// \tparam GridCellDataMapType Type of the grid cell data map.
/// \tparam Hasher Type of the hash function used to convert states into hashes.
/// \tparam Neighbors Type of the range containing the neighbors of a cell.
/// \param grid_cell_data The grid cell data map.
/// \param spatial_hash_function_ The hash object instance.
/// \param neighbors Range containing the neighbors of a cell.
/// \param weight_cap The maximum weight value to be assigned to each cell.
template <class GridCellDataMapType, class Hasher, class Neighbors>
void map_cells_to_clusters(
GridCellDataMapType&& grid_cell_data,
Expand Down Expand Up @@ -213,6 +215,7 @@ void map_cells_to_clusters(
auto valid_neighbor_hashes_view = //
neighbors | //
ranges::views::transform(get_neighbor_hash) | //
ranges::views::cache1 | //
ranges::views::filter(filter_invalid_neighbors); //

for (const auto& neighbor_hash : valid_neighbor_hashes_view) {
Expand All @@ -226,19 +229,17 @@ void map_cells_to_clusters(
}
}

/// @brief For each cluster, estimate the mean and covariance of the states that belong to it.
/// @tparam StateType Type used for the pose estimation.
/// @tparam CovarianceType Type used for the covariance estimation.
/// @tparam GridCellDataMapType Type of the grid cell data map.
/// @tparam Range Range type of the states.
/// @tparam Weights Range type of the weights.
/// @tparam Hashes Range type of the hashes.
/// @param grid_cell_data Grid cell data map that has already been processed by the clustering algorithm.
/// @param states Range containing the states of the particles.
/// @param weights Range containing the weights of the particles.
/// @param hashes Range containing the hashes of the particles.
/// @return A vector of tuples, containing the weight, mean and covariance of each cluster, in no particular order.
template <class StateType, class CovarianceType, class GridCellDataMapType, class Range, class Weights, class Hashes>
/// \brief For each cluster, estimate the mean and covariance of the states that belong to it.
/// \tparam GridCellDataMapType Type of the grid cell data map.
/// \tparam Range Range type of the states.
/// \tparam Weights Range type of the weights.
/// \tparam Hashes Range type of the hashes.
/// \param grid_cell_data Grid cell data map that has already been processed by the clustering algorithm.
/// \param states Range containing the states of the particles.
/// \param weights Range containing the weights of the particles.
/// \param hashes Range containing the hashes of the particles.
/// \return A vector of tuples, containing the weight, mean and covariance of each cluster, in no particular order.
template <class GridCellDataMapType, class Range, class Weights, class Hashes>
[[nodiscard]] auto
estimate_clusters(GridCellDataMapType&& grid_cell_data, Range&& states, Weights&& weights, Hashes&& hashes) {
struct ClusterInfo {
Expand All @@ -250,16 +251,16 @@ estimate_clusters(GridCellDataMapType&& grid_cell_data, Range&& states, Weights&
// find out the weight of each of the clusters, along with how many clusters there are
for (const auto& [weight, hash] : ranges::views::zip(weights, hashes)) {
const auto& grid_cell = grid_cell_data[hash];
if (!grid_cell.cluster_id.has_value()) {
continue;
}
auto [it, inserted] = cluster_info.try_emplace(*grid_cell.cluster_id, ClusterInfo{});
it->second.total_weight += weight;
++it->second.particle_count;
}

// for each cluster found, estimate the mean and covariance of the states that belong to it
using StateType = std::decay_t<decltype(std::get<0>(beluga::estimate(states, weights)))>;
using CovarianceType = std::decay_t<decltype(std::get<1>(beluga::estimate(states, weights)))>;
std::vector<std::tuple<double, StateType, CovarianceType>> cluster_estimates;

const auto mask_filter = [](const auto& sample) { return std::get<1>(sample); };

for (auto& [cluster_id, info] : cluster_info) {
Expand Down Expand Up @@ -293,7 +294,7 @@ estimate_clusters(GridCellDataMapType&& grid_cell_data, Range&& states, Weights&

/// Primary template for a cluster-based estimation algorithm.
/**
* Particles are groups into clusters around local maximums and the mean and covariance of the cluster with the highest
* Particles are groups into clusters around local maxima and the mean and covariance of the cluster with the highest
* weight is returned.
*
* This class implements the EstimationInterface interface
Expand Down Expand Up @@ -321,14 +322,14 @@ class ClusterBasedStateEstimator : public Mixin {
static constexpr double kSpatialHashResolution = 0.20; ///< clustering algorithm spatial resolution
static constexpr double kAngularHashResolution = 0.524; ///< clustering algorithm angular resolution

/// @brief spatial hash function used to group particles in cells
/// \brief spatial hash function used to group particles in cells
const beluga::spatial_hash<Sophus::SE2d> spatial_hash_function_{
kSpatialHashResolution, // x
kSpatialHashResolution, // y
kAngularHashResolution // theta
};

/// @brief Adjacent grid cells used by the clustering algorithm.
/// \brief Adjacent grid cells used by the clustering algorithm.
const std::vector<Sophus::SE2d> adjacent_grid_cells_ = {
Sophus::SE2d{Sophus::SO2d{0.0}, Sophus::Vector2d{+kSpatialHashResolution, 0.0}},
Sophus::SE2d{Sophus::SO2d{0.0}, Sophus::Vector2d{-kSpatialHashResolution, 0.0}},
Expand All @@ -348,19 +349,17 @@ std::pair<StateType, CovarianceType> ClusterBasedStateEstimator<Mixin, StateType
cse_detail::cap_grid_cell_data_weights(grid_cell_data, weight_cap);
cse_detail::map_cells_to_clusters(grid_cell_data, spatial_hash_function_, adjacent_grid_cells_, weight_cap);

auto per_cluster_estimates = cse_detail::estimate_clusters<StateType, CovarianceType>(
grid_cell_data, this->self().states(), this->self().weights(), hashes);
auto per_cluster_estimates =
cse_detail::estimate_clusters(grid_cell_data, this->self().states(), this->self().weights(), hashes);

if (per_cluster_estimates.empty()) {
// hmmm... maybe the particles are too fragmented? calculate the overall mean and covariance
return beluga::estimate(this->self().states(), this->self().weights());
}

// order by decreasing weight and return the first one
std::sort(per_cluster_estimates.begin(), per_cluster_estimates.end(), [](const auto& a, const auto& b) {
return std::get<0>(a) > std::get<0>(b);
});
return {std::get<1>(per_cluster_estimates[0]), std::get<2>(per_cluster_estimates[0])};
const auto [_, mean, covariance] =
*ranges::max_element(per_cluster_estimates, std::less{}, [](const auto& t) { return std::get<0>(t); });
return {mean, covariance};
}

/// An alias template for the simple state estimator in 2D.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct ClusterBasedEstimationDetailTesting : public testing::Test {
const auto xwidth = xmax - xmin;
const auto ywidth = ymax - ymin;

// simulate particles in a grid with (phases x phases) clusters with
// simulate particles in a grid with 4 (2x2) clusters with
// different peak heights. The highest on is the one located on the
// higher-right.

Expand Down Expand Up @@ -246,7 +246,7 @@ TEST_F(ClusterBasedEstimationDetailTesting, MapGridCellsToClustersStep) {
return spatial_hash_function(state);
};

const auto hash_to_id = [&](const auto& hash) { return map[hash].cluster_id ? map[hash].cluster_id.value() : 9999; };
const auto hash_to_id = [&](const auto& hash) { return map[hash].cluster_id.value(); };

auto quadrant_1_unique_ids = quadrant_1_view | //
ranges::views::transform(coord_to_hash) | //
Expand Down Expand Up @@ -307,7 +307,7 @@ TEST_F(ClusterBasedEstimationDetailTesting, ClusterStatEstimationStep) {
cap_grid_cell_data_weights(grid_cell_data, weight_cap);
map_cells_to_clusters(grid_cell_data, spatial_hash_function, adjacent_grid_cells, weight_cap);

auto per_cluster_estimates = estimate_clusters<SE2d, Matrix3d>(grid_cell_data, states, weights, hashes);
auto per_cluster_estimates = estimate_clusters(grid_cell_data, states, weights, hashes);

// check that the number of clusters is correct
ASSERT_EQ(per_cluster_estimates.size(), 4);
Expand Down

0 comments on commit d2e4b84

Please sign in to comment.