Skip to content

Commit

Permalink
Add a preprocessing step before the Taylor decomposition to transform…
Browse files Browse the repository at this point in the history
… powers

with non-numerical exponent into combinations of exps and logs.

This allows us to avoid having to implement the Taylor diff of pow() for non-numerical
exponents.
  • Loading branch information
bluescarni committed Sep 15, 2024
1 parent 907ba11 commit c988736
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 37 deletions.
8 changes: 8 additions & 0 deletions src/math/pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,8 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
return ret;
}

// LCOV_EXCL_START

// All the other cases.
template <typename U1, typename U2, std::enable_if_t<!std::conjunction_v<is_num_param<U1>, is_num_param<U2>>, int> = 0>
llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &,
Expand All @@ -574,6 +576,8 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &,
"An invalid argument type was encountered while trying to build the Taylor derivative of a pow()");
}

// LCOV_EXCL_STOP

llvm::Value *taylor_diff_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &f, const std::vector<std::uint32_t> &deps,
const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
Expand Down Expand Up @@ -971,6 +975,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &s, llvm::Type *fp_t, con
return f;
}

// LCOV_EXCL_START

// All the other cases.
template <typename U1, typename U2, std::enable_if_t<!std::conjunction_v<is_num_param<U1>, is_num_param<U2>>, int> = 0>
llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &,
Expand All @@ -980,6 +986,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const po
"of a pow() in compact mode");
}

// LCOV_EXCL_STOP

llvm::Function *taylor_c_diff_func_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &fn, std::uint32_t n_uvars,
std::uint32_t batch_size)
{
Expand Down
78 changes: 78 additions & 0 deletions src/taylor_01.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstdint>
#include <deque>
#include <exception>
#include <iterator>
#include <limits>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -46,6 +48,7 @@

#include <heyoka/config.hpp>
#include <heyoka/detail/cm_utils.hpp>
#include <heyoka/detail/func_cache.hpp>
#include <heyoka/detail/llvm_func_create.hpp>
#include <heyoka/detail/llvm_helpers.hpp>
#include <heyoka/detail/logging_impl.hpp>
Expand All @@ -54,7 +57,11 @@
#include <heyoka/detail/type_traits.hpp>
#include <heyoka/detail/visibility.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/func.hpp>
#include <heyoka/llvm_state.hpp>
#include <heyoka/math/exp.hpp>
#include <heyoka/math/log.hpp>
#include <heyoka/math/pow.hpp>
#include <heyoka/math/prod.hpp>
#include <heyoka/math/sum.hpp>
#include <heyoka/number.hpp>
Expand Down Expand Up @@ -769,6 +776,74 @@ void taylor_decompose_replace_numbers(taylor_dc_t &dc, std::vector<expression>::
}
}

// NOLINTNEXTLINE(misc-no-recursion)
expression pow_to_explog(funcptr_map<expression> &func_map, const expression &ex)
{
return std::visit(
// NOLINTNEXTLINE(misc-no-recursion)
[&]<typename T>(const T &v) {
if constexpr (std::same_as<T, func>) {
const auto *f_id = v.get_ptr();

// Check if we already performed the transformation on ex.
if (const auto it = func_map.find(f_id); it != func_map.end()) {
return it->second;
}

// Perform the transformation on the function arguments.
std::vector<expression> new_args;
new_args.reserve(v.args().size());
for (const auto &orig_arg : v.args()) {
new_args.push_back(pow_to_explog(func_map, orig_arg));
}

// Prepare the return value.
std::optional<expression> retval;

if (v.template extract<detail::pow_impl>() != nullptr
&& !std::holds_alternative<number>(new_args[1].value())) {
// The function is a pow() and the exponent is not a number: transform x**y -> exp(y*log(x)).
//
// NOTE: do not call directly log(new_args[0]) in order to avoid constant folding when the base
// is a number. For instance, if we have pow(2_dbl, par[0]), then we would end up computing
// log(2) in double precision. This would result in an inaccurate result if the fp type
// or precision in use during integration is higher than double.
// NOTE: because the exponent is not a number, no other constant folding should take
// place here.
retval.emplace(exp(new_args[1] * expression{func{detail::log_impl(new_args[0])}}));
} else {
// Create a copy of v with the new arguments.
retval.emplace(v.copy(std::move(new_args)));
}

// Put the return value into the cache.
[[maybe_unused]] const auto [_, flag] = func_map.emplace(f_id, *retval);
// NOTE: an expression cannot contain itself.
assert(flag); // LCOV_EXCL_LINE

return std::move(*retval);
} else {
return ex;
}
},
ex.value());
}

// Helper to transform x**y -> exp(y*log(x)), if y is not a number.
std::vector<expression> pow_to_explog(const std::vector<expression> &v_ex)
{
funcptr_map<expression> func_map;

std::vector<expression> retval;
retval.reserve(v_ex.size());

for (const auto &e : v_ex) {
retval.push_back(pow_to_explog(func_map, e));
}

return retval;
}

} // namespace

} // namespace detail
Expand Down Expand Up @@ -798,6 +873,9 @@ taylor_decompose_sys(const std::vector<std::pair<expression, expression>> &sys_,
std::ranges::transform(sys_, std::back_inserter(all_ex), &std::pair<expression, expression>::second);
all_ex.insert(all_ex.end(), sv_funcs_.begin(), sv_funcs_.end());

// Transform x**y -> exp(y*log(x)), if y is not a number.
all_ex = detail::pow_to_explog(all_ex);

// Transform sums into subs.
all_ex = detail::sum_to_sub(all_ex);

Expand Down
68 changes: 31 additions & 37 deletions test/taylor_pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#endif

#include <heyoka/expression.hpp>
#include <heyoka/func.hpp>
#include <heyoka/kw.hpp>
#include <heyoka/math.hpp>
#include <heyoka/number.hpp>
Expand Down Expand Up @@ -97,7 +98,9 @@ TEST_CASE("taylor pow approx")
kw::opt_level = 0,
kw::compact_mode = true};

REQUIRE(ir_contains(ta, "taylor_c_diff.pow."));
REQUIRE(!ir_contains(ta, "taylor_c_diff.pow."));
REQUIRE(ir_contains(ta, "taylor_c_diff.exp."));
REQUIRE(ir_contains(ta, "taylor_c_diff.log."));
}

{
Expand Down Expand Up @@ -167,9 +170,7 @@ TEST_CASE("taylor pow")
kw::opt_level = opt_level,
kw::pars = {fp_t{1} / fp_t{3}}};

if (opt_level == 0u && compact_mode) {
REQUIRE(ir_contains(ta, "@heyoka.taylor_c_diff.pow.num_par"));
}
REQUIRE(!ir_contains(ta, "@heyoka.taylor_c_diff.pow.num_par"));

ta.step(true);

Expand Down Expand Up @@ -705,39 +706,6 @@ TEST_CASE("taylor pow")
compare_batch_scalar<fp_t>({prime(x) = pow(y, expression{number{fp_t{3}}} / expression{number{fp_t{2}}}),
prime(y) = pow(x, expression{number{fp_t{-1}}} / expression{number{fp_t{3}}})},
opt_level, high_accuracy, compact_mode, rng, .1f, 20.f);

// Failure modes for non-implemented cases.
{
REQUIRE_THROWS_MATCHES((taylor_adaptive_batch<fp_t>{{prime(x) = pow(1_dbl, x)},
{fp_t{2}, fp_t{2}, fp_t{3}},
3,
kw::tol = .1,
kw::high_accuracy = high_accuracy,
kw::compact_mode = compact_mode,
kw::opt_level = opt_level}),
std::invalid_argument,
Message(compact_mode
? "An invalid argument type was encountered while trying to build the "
"Taylor derivative of a pow() in compact mode"
: "An invalid argument type was encountered while trying to build the "
"Taylor derivative of a pow()"));
}

{
REQUIRE_THROWS_MATCHES((taylor_adaptive_batch<fp_t>{{prime(y) = pow(y, x), prime(x) = x + y},
{fp_t{2}, fp_t{2}, fp_t{3}, fp_t{2}, fp_t{2}, fp_t{3}},
3,
kw::tol = .1,
kw::high_accuracy = high_accuracy,
kw::compact_mode = compact_mode,
kw::opt_level = opt_level}),
std::invalid_argument,
Message(compact_mode
? "An invalid argument type was encountered while trying to build the "
"Taylor derivative of a pow() in compact mode"
: "An invalid argument type was encountered while trying to build the "
"Taylor derivative of a pow()"));
}
};

for (auto cm : {false, true}) {
Expand All @@ -749,3 +717,29 @@ TEST_CASE("taylor pow")
}
}
}

// Small test for the preprocessing that turns pow into exp+log.
TEST_CASE("pow_to_explog")
{
auto [x, y] = make_vars("x", "y");

auto tmp1 = x + pow(y, par[0]);
auto tmp2 = pow(x, tmp1);
auto tmp3 = pow(tmp1, y);

auto ta = taylor_adaptive<double>{{prime(x) = (tmp1 * tmp2) / tmp3, prime(y) = tmp1}, {}, kw::tol = 1e-1};

REQUIRE(ta.get_decomposition().size() == 16u);

// Count the number of exps and logs.
auto n_exp = 0, n_log = 0;
for (const auto &[ex, _] : ta.get_decomposition()) {
if (const auto *fptr = std::get_if<func>(&ex.value())) {
n_exp += (fptr->extract<detail::exp_impl>() != nullptr);
n_log += (fptr->extract<detail::log_impl>() != nullptr);
}
}

REQUIRE(n_exp == 3);
REQUIRE(n_log == 3);
}

0 comments on commit c988736

Please sign in to comment.