-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This patch adds a new random intersperse view to help with adaptive sampling and replace the `make_random_selector` implementation. This is a range adaptor object that will insert values from a generator function between contiguous elements from the source range based on a given probability. Signed-off-by: Nahuel Espinosa <[email protected]>
- Loading branch information
1 parent
4f080fe
commit b9867fc
Showing
3 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
// Copyright 2024 Ekumen, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef BELUGA_VIEWS_RANDOM_INTERSPERSE_HPP | ||
#define BELUGA_VIEWS_RANDOM_INTERSPERSE_HPP | ||
|
||
#include <functional> | ||
#include <optional> | ||
#include <random> | ||
|
||
#include <range/v3/functional/bind_back.hpp> | ||
#include <range/v3/utility/random.hpp> | ||
#include <range/v3/view/adaptor.hpp> | ||
|
||
/** | ||
* \file | ||
* \brief Implementation of a random_intersperse range adaptor object. | ||
*/ | ||
|
||
namespace beluga::views { | ||
|
||
namespace detail { | ||
|
||
/// Implementation of the random_intersperse view as a view adaptor. | ||
/** | ||
* \tparam Range A [forward range](https://en.cppreference.com/w/cpp/ranges/forward_range). | ||
* \tparam Fn A callable type which takes no arguments and returns values to be inserted. | ||
* \tparam URNG A random number generator that satisfies the | ||
* [UniformRandomBitGenerator](https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator) | ||
* requirements. | ||
*/ | ||
template <class Range, class Fn, class URNG = typename ranges::detail::default_random_engine> | ||
struct random_intersperse_view | ||
: public ranges::view_adaptor< | ||
random_intersperse_view<Range, Fn, URNG>, | ||
Range, | ||
// The cardinality value can be infinite, unknown, finite or a specific number if known. | ||
// If the adapted range cardinality is finite (or a specific number) then we know the | ||
// resulting view is finite. Else, we propagate the infinite or unknown specification. | ||
// Care must be taken when the adapted range is finite but the intersperse probability | ||
// is 1.0, leading to an infinite range in practice. | ||
ranges::is_finite<Range>::value ? ranges::finite : ranges::range_cardinality<Range>::value> { | ||
public: | ||
/// Default constructor. | ||
random_intersperse_view() = default; | ||
|
||
/// Construct the view from a range to be adapted. | ||
/** | ||
* \param range The range to be adapted. | ||
* \param fn The generator function that returns values to be inserted. | ||
* \param probability The probability of inserting a value on each iteration. | ||
* \param engine The random number generator object. | ||
*/ | ||
constexpr random_intersperse_view( | ||
Range range, | ||
Fn fn, | ||
double probability, | ||
URNG& engine = ranges::detail::get_random_engine()) | ||
: random_intersperse_view::view_adaptor{std::move(range)}, | ||
fn_{std::move(fn)}, | ||
distribution_{probability}, | ||
engine_{std::addressof(engine)} {} | ||
|
||
private: | ||
// `ranges::range_access` needs access to the adaptor members. | ||
friend ranges::range_access; | ||
|
||
using fn_return_type = decltype(std::declval<Fn>()()); | ||
static_assert(std::is_convertible_v<fn_return_type, ranges::range_value_t<Range>>); | ||
|
||
/// Adaptor subclass that implements the random_intersperse logic. | ||
struct adaptor : public ranges::adaptor_base { | ||
public: | ||
/// Default constructor. | ||
adaptor() = default; | ||
|
||
/// Construct an iterator adaptor from the parent view. | ||
constexpr explicit adaptor(random_intersperse_view* view) noexcept : view_(view) {} | ||
|
||
/// Return the inserted value or dereference the current iterator. | ||
[[nodiscard]] constexpr auto read(ranges::iterator_t<Range> it) const { return fn_return_.value_or(*it); } | ||
|
||
/// Generate a new value to be inserted or increment the input iterator. | ||
constexpr void next(ranges::iterator_t<Range>& it) { | ||
fn_return_.reset(); | ||
if (view_->should_intersperse()) { | ||
fn_return_ = view_->fn_(); | ||
} else { | ||
++it; | ||
} | ||
} | ||
|
||
void prev(ranges::iterator_t<Range>& it) = delete; | ||
void advance() = delete; | ||
void distance_to() = delete; | ||
|
||
private: | ||
random_intersperse_view* view_; | ||
std::optional<fn_return_type> fn_return_; | ||
}; | ||
|
||
/// Return the adaptor for the begin iterator. | ||
[[nodiscard]] constexpr auto begin_adaptor() { return adaptor{this}; } | ||
|
||
/// Return whether we should intersperse a value or increment the input iterator. | ||
[[nodiscard]] constexpr bool should_intersperse() { return distribution_(*engine_); } | ||
|
||
ranges::semiregular_box_t<Fn> fn_; | ||
std::bernoulli_distribution distribution_; | ||
URNG* engine_; | ||
}; | ||
|
||
/// Implementation detail for a random_intersperse range adaptor object. | ||
struct random_intersperse_fn { | ||
/// Default insertion probability on each iteration. | ||
static constexpr double kDefaultProbability = 0.5; | ||
|
||
/// Overload that implements the andom_intersperse algorithm. | ||
/** | ||
* \tparam Range A [forward range](https://en.cppreference.com/w/cpp/ranges/forward_range). | ||
* \tparam Fn A callable type which takes no arguments and returns values to be inserted. | ||
* \tparam URNG A random number generator that satisfies the | ||
* [UniformRandomBitGenerator](https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator) | ||
* requirements. | ||
* \param range The range to be adapted. | ||
* \param fn Fn instance used to insert values between source elements. | ||
* \param probability The probability of inserting a value on each iteration. | ||
* \param engine The random number generator object. | ||
*/ | ||
template <class Range, class Fn, class URNG = typename ranges::detail::default_random_engine> | ||
constexpr auto operator()( | ||
Range&& range, | ||
Fn fn, | ||
double probability = kDefaultProbability, | ||
URNG& engine = ranges::detail::get_random_engine()) const { | ||
return random_intersperse_view{ranges::views::all(std::forward<Range>(range)), std::move(fn), probability, engine}; | ||
} | ||
|
||
/// Overload that unwraps the engine reference from a view closure. | ||
template <class Range, class Fn, class URNG> | ||
constexpr auto operator()(Range&& range, Fn fn, double probability, std::reference_wrapper<URNG> engine) const { | ||
return (*this)(ranges::views::all(std::forward<Range>(range)), std::move(fn), probability, engine.get()); | ||
} | ||
|
||
/// Overload that returns a view closure to compose with other views. | ||
/** | ||
* \tparam Fn A callable type which takes no arguments and returns values to be inserted. | ||
* \tparam URNG A random number generator that satisfies the | ||
* [UniformRandomBitGenerator](https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator) | ||
* requirements. | ||
* \param fn Fn instance used to insert values between source elements. | ||
* \param probability The probability of inserting a value on each iteration. | ||
* \param engine The random number generator object. | ||
*/ | ||
template <class Fn, class URNG = typename ranges::detail::default_random_engine> | ||
constexpr auto operator()( | ||
Fn fn, | ||
double probability = kDefaultProbability, | ||
URNG& engine = ranges::detail::get_random_engine()) const { | ||
return ranges::make_view_closure( | ||
ranges::bind_back(random_intersperse_fn{}, std::move(fn), probability, std::ref(engine))); | ||
} | ||
}; | ||
|
||
} // namespace detail | ||
|
||
/// [Range adaptor object](https://en.cppreference.com/w/cpp/named_req/RangeAdaptorObject) that | ||
/// will insert values from a generator function between contiguous elements from the source based | ||
/// on a given probability. | ||
inline constexpr detail::random_intersperse_fn random_intersperse; | ||
|
||
} // namespace beluga::views | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// Copyright 2024 Ekumen, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <gmock/gmock.h> | ||
|
||
#include "beluga/views/random_intersperse.hpp" | ||
|
||
#include <range/v3/algorithm/count.hpp> | ||
#include <range/v3/range/conversion.hpp> | ||
#include <range/v3/view/generate.hpp> | ||
#include <range/v3/view/iota.hpp> | ||
#include <range/v3/view/take.hpp> | ||
|
||
namespace { | ||
|
||
TEST(RandomIntersperseView, ConceptChecksFromContiguousRange) { | ||
auto input = std::array{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; | ||
auto output = beluga::views::random_intersperse(input, []() { return 0; }); | ||
|
||
static_assert(ranges::common_range<decltype(input)>); | ||
static_assert(!ranges::common_range<decltype(output)>); | ||
|
||
static_assert(!ranges::viewable_range<decltype(input)>); | ||
static_assert(ranges::viewable_range<decltype(output)>); | ||
|
||
static_assert(ranges::forward_range<decltype(input)>); | ||
static_assert(ranges::forward_range<decltype(output)>); | ||
|
||
static_assert(ranges::sized_range<decltype(input)>); | ||
static_assert(!ranges::sized_range<decltype(output)>); | ||
|
||
static_assert(ranges::bidirectional_range<decltype(input)>); | ||
static_assert(!ranges::bidirectional_range<decltype(output)>); | ||
|
||
static_assert(ranges::random_access_range<decltype(input)>); | ||
static_assert(!ranges::random_access_range<decltype(output)>); | ||
|
||
static_assert(ranges::contiguous_range<decltype(input)>); | ||
static_assert(!ranges::contiguous_range<decltype(output)>); | ||
} | ||
|
||
TEST(RandomIntersperseView, ConceptChecksFromInfiniteRange) { | ||
auto input = ranges::views::generate([]() { return 1; }); | ||
auto output = beluga::views::random_intersperse(input, []() { return 0; }); | ||
|
||
static_assert(!ranges::common_range<decltype(input)>); | ||
static_assert(!ranges::common_range<decltype(output)>); | ||
|
||
static_assert(ranges::viewable_range<decltype(input)>); | ||
static_assert(ranges::viewable_range<decltype(output)>); | ||
|
||
static_assert(!ranges::forward_range<decltype(input)>); | ||
static_assert(!ranges::forward_range<decltype(output)>); | ||
|
||
static_assert(!ranges::sized_range<decltype(input)>); | ||
static_assert(!ranges::sized_range<decltype(output)>); | ||
|
||
static_assert(!ranges::bidirectional_range<decltype(input)>); | ||
static_assert(!ranges::bidirectional_range<decltype(output)>); | ||
|
||
static_assert(!ranges::random_access_range<decltype(input)>); | ||
static_assert(!ranges::random_access_range<decltype(output)>); | ||
|
||
static_assert(!ranges::contiguous_range<decltype(input)>); | ||
static_assert(!ranges::contiguous_range<decltype(output)>); | ||
} | ||
|
||
TEST(RandomIntersperseView, GuaranteedIntersperseFirstElement) { | ||
auto input = std::array{10, 20, 30}; | ||
auto output = input | beluga::views::random_intersperse([i = 0]() mutable { return i++; }, 1.0); | ||
auto it = ranges::begin(output); | ||
ASSERT_EQ(*it, 10); // The first element is always from the input range | ||
} | ||
|
||
TEST(RandomIntersperseView, GuaranteedIntersperseDoubleDereference) { | ||
auto input = std::array{10, 20, 30}; | ||
auto output = input | beluga::views::random_intersperse([i = 0]() mutable { return i++; }, 1.0); | ||
auto it = ranges::begin(output); | ||
++it; | ||
ASSERT_EQ(*it, 0); | ||
ASSERT_EQ(*it, 0); | ||
++it; | ||
ASSERT_EQ(*it, 1); | ||
} | ||
|
||
TEST(RandomIntersperseView, GuaranteedIntersperseTakeFive) { | ||
auto input = std::array{10, 20, 30}; | ||
auto output = input | // | ||
beluga::views::random_intersperse([]() { return 4; }, 1.0) | // | ||
ranges::views::take(5) | // | ||
ranges::to<std::vector>; | ||
ASSERT_EQ(ranges::size(output), 5); | ||
ASSERT_THAT(output, testing::ElementsAre(10, 4, 4, 4, 4)); | ||
} | ||
|
||
TEST(RandomIntersperseView, ZeroProbabilityIntersperseTakeFive) { | ||
auto input = std::array{10, 20, 30}; | ||
auto output = input | // | ||
beluga::views::random_intersperse([]() { return 4; }, 0.0) | // | ||
ranges::views::take(5) | // | ||
ranges::to<std::vector>; | ||
ASSERT_EQ(ranges::size(output), 3); | ||
ASSERT_THAT(output, testing::ElementsAre(10, 20, 30)); | ||
} | ||
|
||
TEST(RandomIntersperseView, ZeroProbabilityMultipass) { | ||
auto input = std::array{10, 20, 30}; | ||
auto output = input | // | ||
beluga::views::random_intersperse([]() { return 4; }, 0.0) | // | ||
ranges::views::take(3); | ||
ASSERT_THAT(output | ranges::to<std::vector>, testing::ElementsAre(10, 20, 30)); | ||
ASSERT_THAT(output | ranges::to<std::vector>, testing::ElementsAre(10, 20, 30)); | ||
} | ||
|
||
class RandomIntersperseViewWithParam : public ::testing::TestWithParam<double> {}; | ||
|
||
TEST_P(RandomIntersperseViewWithParam, TestPercentage) { | ||
const double expected_p = GetParam(); | ||
const int size = 10'000; | ||
auto output = ranges::views::iota(1, size + 1) | beluga::views::random_intersperse([]() { return 0; }, expected_p); | ||
const double count = static_cast<double>(ranges::count(output, 0)); | ||
const double actual_p = count / (size + count); | ||
ASSERT_NEAR(expected_p, actual_p, 0.01); | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P( | ||
RandomIntersperseViewParams, | ||
RandomIntersperseViewWithParam, | ||
testing::Values(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)); | ||
|
||
} // namespace |