From efd75828af018f490131a7a7b8fd62c3997473a6 Mon Sep 17 00:00:00 2001 From: Nahuel Espinosa Date: Sat, 12 Nov 2022 12:49:03 -0300 Subject: [PATCH] Fix KLD condition --- beluga/core/benchmark/sampling_benchmark.cpp | 16 ++++----- .../core/include/beluga/algorithm/sampling.h | 22 ++++++++---- beluga/core/test/sampling_test.cpp | 34 ++++++++++++------- 3 files changed, 45 insertions(+), 27 deletions(-) diff --git a/beluga/core/benchmark/sampling_benchmark.cpp b/beluga/core/benchmark/sampling_benchmark.cpp index 65f0f3fba..1862ad0fd 100644 --- a/beluga/core/benchmark/sampling_benchmark.cpp +++ b/beluga/core/benchmark/sampling_benchmark.cpp @@ -78,16 +78,16 @@ void BM_AdaptiveResample(benchmark::State& state) { std::size_t max_samples = particle_count; double resolution = 1.; double kld_epsilon = 0.05; - double kld_upper_quantile = 0.95; + double kld_z = -1.28155156327703; // P = 0.9 for (auto _ : state) { - auto&& samples = - ranges::views::generate( - beluga::random_sample(beluga::views::all(container), beluga::views::weights(container), generator)) | - ranges::views::transform(beluga::set_cluster(resolution)) | - ranges::views::take_while( - beluga::kld_condition(min_samples, kld_epsilon, kld_upper_quantile), beluga::cluster) | - ranges::views::take(max_samples) | ranges::views::common; + auto&& samples = ranges::views::generate(beluga::random_sample( + beluga::views::all(container), beluga::views::weights(container), generator)) | + ranges::views::transform(beluga::set_cluster(resolution)) | + ranges::views::take_while( + beluga::kld_condition(min_samples, kld_epsilon, kld_z), + beluga::cluster) | + ranges::views::take(max_samples) | ranges::views::common; auto first = std::begin(beluga::views::all(new_container)); auto last = std::copy(std::begin(samples), std::end(samples), first); state.counters["SampleSize"] = std::distance(first, last); diff --git a/beluga/core/include/beluga/algorithm/sampling.h b/beluga/core/include/beluga/algorithm/sampling.h index 
43caa1b8d..46a78b108 100644 --- a/beluga/core/include/beluga/algorithm/sampling.h +++ b/beluga/core/include/beluga/algorithm/sampling.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -34,22 +35,29 @@ inline auto set_cluster(double resolution) { }; } -inline auto kld_condition(std::size_t min, double epsilon = 0.05, double upper_quantile = 0.95) { +inline auto kld_condition(std::size_t min, double epsilon, double z = -1.28155156327703) { // Compute minimum number of samples based on a Kullback-Leibler distance epsilon // between the maximum likelihood estimate and the true distribution. - auto target_size = [upper_quantile, epsilon](std::size_t k) { + // Z is the upper standard normal quantile for (1 - P), where P is the probability + // that the error in the estimated distribution will be less than epsilon. + // Here are some examples: + // P = 0.90 -> Z = -1.28155156327703 + // P = 0.95 -> Z = -1.64485362793663 + // P = 0.99 -> Z = -2.32634787735669 + auto target_size = [two_epsilon = 2 * epsilon, z](std::size_t k) { + if (k <= 2U) { + return std::numeric_limits::max(); + } double common = 2. / (9 * (k - 1)); - double base = 1. - common - std::sqrt(common) * upper_quantile; - double result = ((k - 1) / epsilon) * base * base * base; + double base = 1. - common - std::sqrt(common) * z; + double result = ((k - 1) / two_epsilon) * base * base * base; return static_cast(std::ceil(result)); }; return [=, count = 0ULL, buckets = std::unordered_set{}](std::size_t hash) mutable { count++; buckets.insert(hash); - auto cluster_count = buckets.size(); - bool target_not_reached = cluster_count <= 2U ? 
true : count < target_size(cluster_count); - return count < min || target_not_reached; + return count < min || count < target_size(buckets.size()); }; } diff --git a/beluga/core/test/sampling_test.cpp b/beluga/core/test/sampling_test.cpp index 1e6ac5e09..d729bed96 100644 --- a/beluga/core/test/sampling_test.cpp +++ b/beluga/core/test/sampling_test.cpp @@ -51,12 +51,12 @@ TEST(RandomSample, Functional) { } } -class KLDConditionWithParam : public ::testing::TestWithParam> {}; +class KLDConditionWithParam : public ::testing::TestWithParam> {}; TEST_P(KLDConditionWithParam, Minimum) { - const std::size_t cluster_count = GetParam().first; + const std::size_t cluster_count = std::get<1>(GetParam()); const std::size_t fixed_min_samples = 1'000; - auto predicate = beluga::kld_condition(fixed_min_samples); + auto predicate = beluga::kld_condition(fixed_min_samples, 0.01, 0.95); std::size_t cluster = 0; for (std::size_t i = 0; i < fixed_min_samples - 1; ++i) { ASSERT_TRUE(predicate(cluster)) << "Stopped at " << i + 1 << " samples (Expected: " << fixed_min_samples << ")."; @@ -67,9 +67,10 @@ TEST_P(KLDConditionWithParam, Minimum) { } TEST_P(KLDConditionWithParam, Limit) { - const std::size_t cluster_count = GetParam().first; - const std::size_t min_samples = GetParam().second; - auto predicate = beluga::kld_condition(0, 0.05, 0.95); + const double kld_k = std::get<0>(GetParam()); + const std::size_t cluster_count = std::get<1>(GetParam()); + const std::size_t min_samples = std::get<2>(GetParam()); + auto predicate = beluga::kld_condition(0, 0.01, kld_k); std::size_t cluster = 0; for (std::size_t i = 0; i < min_samples - 1; ++i) { ASSERT_TRUE(predicate(cluster)) << "Stopped at " << i + 1 << " samples (Expected: " << min_samples << ")."; @@ -80,15 +81,24 @@ TEST_P(KLDConditionWithParam, Limit) { ASSERT_FALSE(predicate(cluster)) << "Didn't stop at " << min_samples << " samples."; } +constexpr double kPercentile90th = -1.28155156327703; +constexpr double kPercentile99th = 
-2.32634787735669; + INSTANTIATE_TEST_SUITE_P( KLDPairs, KLDConditionWithParam, testing::Values( - std::make_pair(3, 8), - std::make_pair(4, 18), - std::make_pair(5, 30), - std::make_pair(6, 44), - std::make_pair(7, 57), - std::make_pair(100, 1'713))); + std::make_tuple(kPercentile90th, 3, 228), + std::make_tuple(kPercentile90th, 4, 311), + std::make_tuple(kPercentile90th, 5, 388), + std::make_tuple(kPercentile90th, 6, 461), + std::make_tuple(kPercentile90th, 7, 531), + std::make_tuple(kPercentile90th, 100, 5871), + std::make_tuple(kPercentile99th, 3, 462), + std::make_tuple(kPercentile99th, 4, 569), + std::make_tuple(kPercentile99th, 5, 666), + std::make_tuple(kPercentile99th, 6, 756), + std::make_tuple(kPercentile99th, 7, 843), + std::make_tuple(kPercentile99th, 100, 6733))); } // namespace