Skip to content

Commit

Permalink
Fix KLD condition
Browse files Browse the repository at this point in the history
  • Loading branch information
nahueespinosa committed Nov 12, 2022
1 parent a38afc6 commit efd7582
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
16 changes: 8 additions & 8 deletions beluga/core/benchmark/sampling_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ void BM_AdaptiveResample(benchmark::State& state) {
std::size_t max_samples = particle_count;
double resolution = 1.;
double kld_epsilon = 0.05;
double kld_upper_quantile = 0.95;
double kld_z = -1.28155156327703; // P = 0.9

for (auto _ : state) {
auto&& samples =
ranges::views::generate(
beluga::random_sample(beluga::views::all(container), beluga::views::weights(container), generator)) |
ranges::views::transform(beluga::set_cluster(resolution)) |
ranges::views::take_while(
beluga::kld_condition(min_samples, kld_epsilon, kld_upper_quantile), beluga::cluster<const Particle&>) |
ranges::views::take(max_samples) | ranges::views::common;
auto&& samples = ranges::views::generate(beluga::random_sample(
beluga::views::all(container), beluga::views::weights(container), generator)) |
ranges::views::transform(beluga::set_cluster(resolution)) |
ranges::views::take_while(
beluga::kld_condition(min_samples, kld_epsilon, kld_z),
beluga::cluster<const Particle&>) |
ranges::views::take(max_samples) | ranges::views::common;
auto first = std::begin(beluga::views::all(new_container));
auto last = std::copy(std::begin(samples), std::end(samples), first);
state.counters["SampleSize"] = std::distance(first, last);
Expand Down
22 changes: 15 additions & 7 deletions beluga/core/include/beluga/algorithm/sampling.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <functional>
#include <limits>
#include <random>
#include <unordered_set>

Expand Down Expand Up @@ -34,22 +35,29 @@ inline auto set_cluster(double resolution) {
};
}

inline auto kld_condition(std::size_t min, double epsilon = 0.05, double upper_quantile = 0.95) {
inline auto kld_condition(std::size_t min, double epsilon, double z = -1.28155156327703) {
// Compute minimum number of samples based on a Kullback-Leibler distance epsilon
// between the maximum likelihood estimate and the true distribution.
auto target_size = [upper_quantile, epsilon](std::size_t k) {
// Z is the standard normal quantile of order (1 - P), i.e. Phi(Z) = 1 - P, where P is
// the probability that the error in the estimated distribution will be less than epsilon.
// Here are some examples:
// P = 0.90 -> Z = -1.28155156327703
// P = 0.95 -> Z = -1.64485362793663
// P = 0.99 -> Z = -2.32634787735669
auto target_size = [two_epsilon = 2 * epsilon, z](std::size_t k) {
if (k <= 2U) {
return std::numeric_limits<std::size_t>::max();
}
double common = 2. / (9 * (k - 1));
double base = 1. - common - std::sqrt(common) * upper_quantile;
double result = ((k - 1) / epsilon) * base * base * base;
double base = 1. - common - std::sqrt(common) * z;
double result = ((k - 1) / two_epsilon) * base * base * base;
return static_cast<std::size_t>(std::ceil(result));
};

return [=, count = 0ULL, buckets = std::unordered_set<std::size_t>{}](std::size_t hash) mutable {
count++;
buckets.insert(hash);
auto cluster_count = buckets.size();
bool target_not_reached = cluster_count <= 2U ? true : count < target_size(cluster_count);
return count < min || target_not_reached;
return count < min || count < target_size(buckets.size());
};
}

Expand Down
34 changes: 22 additions & 12 deletions beluga/core/test/sampling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ TEST(RandomSample, Functional) {
}
}

class KLDConditionWithParam : public ::testing::TestWithParam<std::pair<std::size_t, std::size_t>> {};
class KLDConditionWithParam : public ::testing::TestWithParam<std::tuple<double, std::size_t, std::size_t>> {};

TEST_P(KLDConditionWithParam, Minimum) {
const std::size_t cluster_count = GetParam().first;
const std::size_t cluster_count = std::get<1>(GetParam());
const std::size_t fixed_min_samples = 1'000;
auto predicate = beluga::kld_condition(fixed_min_samples);
auto predicate = beluga::kld_condition(fixed_min_samples, 0.01, 0.95);
std::size_t cluster = 0;
for (std::size_t i = 0; i < fixed_min_samples - 1; ++i) {
ASSERT_TRUE(predicate(cluster)) << "Stopped at " << i + 1 << " samples (Expected: " << fixed_min_samples << ").";
Expand All @@ -67,9 +67,10 @@ TEST_P(KLDConditionWithParam, Minimum) {
}

TEST_P(KLDConditionWithParam, Limit) {
const std::size_t cluster_count = GetParam().first;
const std::size_t min_samples = GetParam().second;
auto predicate = beluga::kld_condition(0, 0.05, 0.95);
const double kld_k = std::get<0>(GetParam());
const std::size_t cluster_count = std::get<1>(GetParam());
const std::size_t min_samples = std::get<2>(GetParam());
auto predicate = beluga::kld_condition(0, 0.01, kld_k);
std::size_t cluster = 0;
for (std::size_t i = 0; i < min_samples - 1; ++i) {
ASSERT_TRUE(predicate(cluster)) << "Stopped at " << i + 1 << " samples (Expected: " << min_samples << ").";
Expand All @@ -80,15 +81,24 @@ TEST_P(KLDConditionWithParam, Limit) {
ASSERT_FALSE(predicate(cluster)) << "Didn't stop at " << min_samples << " samples.";
}

constexpr double kPercentile90th = -1.28155156327703;
constexpr double kPercentile99th = -2.32634787735669;

INSTANTIATE_TEST_SUITE_P(
KLDPairs,
KLDConditionWithParam,
testing::Values(
std::make_pair(3, 8),
std::make_pair(4, 18),
std::make_pair(5, 30),
std::make_pair(6, 44),
std::make_pair(7, 57),
std::make_pair(100, 1'713)));
std::make_tuple(kPercentile90th, 3, 228),
std::make_tuple(kPercentile90th, 4, 311),
std::make_tuple(kPercentile90th, 5, 388),
std::make_tuple(kPercentile90th, 6, 461),
std::make_tuple(kPercentile90th, 7, 531),
std::make_tuple(kPercentile90th, 100, 5871),
std::make_tuple(kPercentile99th, 3, 462),
std::make_tuple(kPercentile99th, 4, 569),
std::make_tuple(kPercentile99th, 5, 666),
std::make_tuple(kPercentile99th, 6, 756),
std::make_tuple(kPercentile99th, 7, 843),
std::make_tuple(kPercentile99th, 100, 6733)));

} // namespace

0 comments on commit efd7582

Please sign in to comment.