Skip to content

Commit

Permalink
Add random intersperse view
Browse files Browse the repository at this point in the history
This patch adds a new random intersperse view to help with adaptive
sampling and replace the `make_random_selector` implementation. This is
a range adaptor object that will insert values from a generator function
between contiguous elements from the source range based on a given
probability.

Signed-off-by: Nahuel Espinosa <[email protected]>
  • Loading branch information
nahueespinosa committed Jan 1, 2024
1 parent 4f080fe commit b9867fc
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 0 deletions.
185 changes: 185 additions & 0 deletions beluga/include/beluga/views/random_intersperse.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_VIEWS_RANDOM_INTERSPERSE_HPP
#define BELUGA_VIEWS_RANDOM_INTERSPERSE_HPP

#include <functional>
#include <optional>
#include <random>

#include <range/v3/functional/bind_back.hpp>
#include <range/v3/utility/random.hpp>
#include <range/v3/view/adaptor.hpp>

/**
* \file
* \brief Implementation of a random_intersperse range adaptor object.
*/

namespace beluga::views {

namespace detail {

/// Implementation of the random_intersperse view as a view adaptor.
/**
* \tparam Range A [forward range](https://en.cppreference.com/w/cpp/ranges/forward_range).
* \tparam Fn A callable type which takes no arguments and returns values to be inserted.
* \tparam URNG A random number generator that satisfies the
* [UniformRandomBitGenerator](https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator)
* requirements.
*/
template <class Range, class Fn, class URNG = typename ranges::detail::default_random_engine>
struct random_intersperse_view
: public ranges::view_adaptor<
random_intersperse_view<Range, Fn, URNG>,
Range,
// The cardinality value can be infinite, unknown, finite or a specific number if known.
// If the adapted range cardinality is finite (or a specific number) then we know the
// resulting view is finite. Else, we propagate the infinite or unknown specification.
// Care must be taken when the adapted range is finite but the intersperse probability
// is 1.0, leading to an infinite range in practice.
ranges::is_finite<Range>::value ? ranges::finite : ranges::range_cardinality<Range>::value> {
public:
/// Default constructor.
random_intersperse_view() = default;

/// Construct the view from a range to be adapted.
/**
* \param range The range to be adapted.
* \param fn The generator function that returns values to be inserted.
* \param probability The probability of inserting a value on each iteration.
* \param engine The random number generator object.
*/
constexpr random_intersperse_view(
Range range,
Fn fn,
double probability,
URNG& engine = ranges::detail::get_random_engine())
: random_intersperse_view::view_adaptor{std::move(range)},
fn_{std::move(fn)},
distribution_{probability},
engine_{std::addressof(engine)} {}

private:
// `ranges::range_access` needs access to the adaptor members.
friend ranges::range_access;

using fn_return_type = decltype(std::declval<Fn>()());
static_assert(std::is_convertible_v<fn_return_type, ranges::range_value_t<Range>>);

/// Adaptor subclass that implements the random_intersperse logic.
struct adaptor : public ranges::adaptor_base {
public:
/// Default constructor.
adaptor() = default;

/// Construct an iterator adaptor from the parent view.
constexpr explicit adaptor(random_intersperse_view* view) noexcept : view_(view) {}

/// Return the inserted value or dereference the current iterator.
[[nodiscard]] constexpr auto read(ranges::iterator_t<Range> it) const { return fn_return_.value_or(*it); }

/// Generate a new value to be inserted or increment the input iterator.
constexpr void next(ranges::iterator_t<Range>& it) {
fn_return_.reset();
if (view_->should_intersperse()) {
fn_return_ = view_->fn_();
} else {
++it;
}
}

void prev(ranges::iterator_t<Range>& it) = delete;
void advance() = delete;
void distance_to() = delete;

private:
random_intersperse_view* view_;
std::optional<fn_return_type> fn_return_;
};

/// Return the adaptor for the begin iterator.
[[nodiscard]] constexpr auto begin_adaptor() { return adaptor{this}; }

/// Return whether we should intersperse a value or increment the input iterator.
[[nodiscard]] constexpr bool should_intersperse() { return distribution_(*engine_); }

ranges::semiregular_box_t<Fn> fn_;
std::bernoulli_distribution distribution_;
URNG* engine_;
};

/// Implementation detail for a random_intersperse range adaptor object.
struct random_intersperse_fn {
/// Default insertion probability on each iteration.
static constexpr double kDefaultProbability = 0.5;

/// Overload that implements the andom_intersperse algorithm.
/**
* \tparam Range A [forward range](https://en.cppreference.com/w/cpp/ranges/forward_range).
* \tparam Fn A callable type which takes no arguments and returns values to be inserted.
* \tparam URNG A random number generator that satisfies the
* [UniformRandomBitGenerator](https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator)
* requirements.
* \param range The range to be adapted.
* \param fn Fn instance used to insert values between source elements.
* \param probability The probability of inserting a value on each iteration.
* \param engine The random number generator object.
*/
template <class Range, class Fn, class URNG = typename ranges::detail::default_random_engine>
constexpr auto operator()(
Range&& range,
Fn fn,
double probability = kDefaultProbability,
URNG& engine = ranges::detail::get_random_engine()) const {
return random_intersperse_view{ranges::views::all(std::forward<Range>(range)), std::move(fn), probability, engine};
}

/// Overload that unwraps the engine reference from a view closure.
template <class Range, class Fn, class URNG>
constexpr auto operator()(Range&& range, Fn fn, double probability, std::reference_wrapper<URNG> engine) const {
return (*this)(ranges::views::all(std::forward<Range>(range)), std::move(fn), probability, engine.get());
}

/// Overload that returns a view closure to compose with other views.
/**
* \tparam Fn A callable type which takes no arguments and returns values to be inserted.
* \tparam URNG A random number generator that satisfies the
* [UniformRandomBitGenerator](https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator)
* requirements.
* \param fn Fn instance used to insert values between source elements.
* \param probability The probability of inserting a value on each iteration.
* \param engine The random number generator object.
*/
template <class Fn, class URNG = typename ranges::detail::default_random_engine>
constexpr auto operator()(
Fn fn,
double probability = kDefaultProbability,
URNG& engine = ranges::detail::get_random_engine()) const {
return ranges::make_view_closure(
ranges::bind_back(random_intersperse_fn{}, std::move(fn), probability, std::ref(engine)));
}
};

} // namespace detail

/// [Range adaptor object](https://en.cppreference.com/w/cpp/named_req/RangeAdaptorObject) that
/// will insert values from a generator function between contiguous elements from the source based
/// on a given probability.
inline constexpr detail::random_intersperse_fn random_intersperse;

} // namespace beluga::views

#endif
1 change: 1 addition & 0 deletions beluga/test/beluga/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ add_executable(
test_storage.cpp
test_tuple_vector.cpp
type_traits/test_strongly_typed_numeric.cpp
views/test_random_intersperse.cpp
views/test_take_evenly.cpp)

target_link_libraries(
Expand Down
142 changes: 142 additions & 0 deletions beluga/test/beluga/views/test_random_intersperse.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gmock/gmock.h>

#include "beluga/views/random_intersperse.hpp"

#include <range/v3/algorithm/count.hpp>
#include <range/v3/range/conversion.hpp>
#include <range/v3/view/generate.hpp>
#include <range/v3/view/iota.hpp>
#include <range/v3/view/take.hpp>

namespace {

TEST(RandomIntersperseView, ConceptChecksFromContiguousRange) {
auto input = std::array{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
auto output = beluga::views::random_intersperse(input, []() { return 0; });

static_assert(ranges::common_range<decltype(input)>);
static_assert(!ranges::common_range<decltype(output)>);

static_assert(!ranges::viewable_range<decltype(input)>);
static_assert(ranges::viewable_range<decltype(output)>);

static_assert(ranges::forward_range<decltype(input)>);
static_assert(ranges::forward_range<decltype(output)>);

static_assert(ranges::sized_range<decltype(input)>);
static_assert(!ranges::sized_range<decltype(output)>);

static_assert(ranges::bidirectional_range<decltype(input)>);
static_assert(!ranges::bidirectional_range<decltype(output)>);

static_assert(ranges::random_access_range<decltype(input)>);
static_assert(!ranges::random_access_range<decltype(output)>);

static_assert(ranges::contiguous_range<decltype(input)>);
static_assert(!ranges::contiguous_range<decltype(output)>);
}

TEST(RandomIntersperseView, ConceptChecksFromInfiniteRange) {
auto input = ranges::views::generate([]() { return 1; });
auto output = beluga::views::random_intersperse(input, []() { return 0; });

static_assert(!ranges::common_range<decltype(input)>);
static_assert(!ranges::common_range<decltype(output)>);

static_assert(ranges::viewable_range<decltype(input)>);
static_assert(ranges::viewable_range<decltype(output)>);

static_assert(!ranges::forward_range<decltype(input)>);
static_assert(!ranges::forward_range<decltype(output)>);

static_assert(!ranges::sized_range<decltype(input)>);
static_assert(!ranges::sized_range<decltype(output)>);

static_assert(!ranges::bidirectional_range<decltype(input)>);
static_assert(!ranges::bidirectional_range<decltype(output)>);

static_assert(!ranges::random_access_range<decltype(input)>);
static_assert(!ranges::random_access_range<decltype(output)>);

static_assert(!ranges::contiguous_range<decltype(input)>);
static_assert(!ranges::contiguous_range<decltype(output)>);
}

TEST(RandomIntersperseView, GuaranteedIntersperseFirstElement) {
auto input = std::array{10, 20, 30};
auto output = input | beluga::views::random_intersperse([i = 0]() mutable { return i++; }, 1.0);
auto it = ranges::begin(output);
ASSERT_EQ(*it, 10); // The first element is always from the input range
}

TEST(RandomIntersperseView, GuaranteedIntersperseDoubleDereference) {
auto input = std::array{10, 20, 30};
auto output = input | beluga::views::random_intersperse([i = 0]() mutable { return i++; }, 1.0);
auto it = ranges::begin(output);
++it;
ASSERT_EQ(*it, 0);
ASSERT_EQ(*it, 0);
++it;
ASSERT_EQ(*it, 1);
}

TEST(RandomIntersperseView, GuaranteedIntersperseTakeFive) {
auto input = std::array{10, 20, 30};
auto output = input | //
beluga::views::random_intersperse([]() { return 4; }, 1.0) | //
ranges::views::take(5) | //
ranges::to<std::vector>;
ASSERT_EQ(ranges::size(output), 5);
ASSERT_THAT(output, testing::ElementsAre(10, 4, 4, 4, 4));
}

TEST(RandomIntersperseView, ZeroProbabilityIntersperseTakeFive) {
auto input = std::array{10, 20, 30};
auto output = input | //
beluga::views::random_intersperse([]() { return 4; }, 0.0) | //
ranges::views::take(5) | //
ranges::to<std::vector>;
ASSERT_EQ(ranges::size(output), 3);
ASSERT_THAT(output, testing::ElementsAre(10, 20, 30));
}

TEST(RandomIntersperseView, ZeroProbabilityMultipass) {
auto input = std::array{10, 20, 30};
auto output = input | //
beluga::views::random_intersperse([]() { return 4; }, 0.0) | //
ranges::views::take(3);
ASSERT_THAT(output | ranges::to<std::vector>, testing::ElementsAre(10, 20, 30));
ASSERT_THAT(output | ranges::to<std::vector>, testing::ElementsAre(10, 20, 30));
}

class RandomIntersperseViewWithParam : public ::testing::TestWithParam<double> {};

TEST_P(RandomIntersperseViewWithParam, TestPercentage) {
const double expected_p = GetParam();
const int size = 10'000;
auto output = ranges::views::iota(1, size + 1) | beluga::views::random_intersperse([]() { return 0; }, expected_p);
const double count = static_cast<double>(ranges::count(output, 0));
const double actual_p = count / (size + count);
ASSERT_NEAR(expected_p, actual_p, 0.01);
}

INSTANTIATE_TEST_SUITE_P(
RandomIntersperseViewParams,
RandomIntersperseViewWithParam,
testing::Values(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9));

} // namespace

0 comments on commit b9867fc

Please sign in to comment.