From c988736b3986a8b55f5d05ec89d7de7c42738335 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sun, 15 Sep 2024 15:37:10 +0200 Subject: [PATCH] Add a preprocessing step before the Taylor decomposition to transform powers with non-numerical exponent into combinations of exps and logs. This allows us to avoid having to implement the Taylor diff of pow() for non-numerical exponents. --- src/math/pow.cpp | 8 +++++ src/taylor_01.cpp | 78 +++++++++++++++++++++++++++++++++++++++++++++ test/taylor_pow.cpp | 68 ++++++++++++++++++--------------------- 3 files changed, 117 insertions(+), 37 deletions(-) diff --git a/src/math/pow.cpp b/src/math/pow.cpp index 384377c9c..c71e98ac2 100644 --- a/src/math/pow.cpp +++ b/src/math/pow.cpp @@ -564,6 +564,8 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp return ret; } +// LCOV_EXCL_START + // All the other cases. template , is_num_param>, int> = 0> llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &, @@ -574,6 +576,8 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, "An invalid argument type was encountered while trying to build the Taylor derivative of a pow()"); } +// LCOV_EXCL_STOP + llvm::Value *taylor_diff_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &f, const std::vector &deps, const std::vector &arr, llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size) @@ -971,6 +975,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &s, llvm::Type *fp_t, con return f; } +// LCOV_EXCL_START + // All the other cases. template , is_num_param>, int> = 0> llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &, @@ -980,6 +986,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const po "of a pow() in compact mode"); } +// LCOV_EXCL_STOP + llvm::Function *taylor_c_diff_func_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &fn, std::uint32_t n_uvars, std::uint32_t batch_size) { diff --git a/src/taylor_01.cpp b/src/taylor_01.cpp index 679a4a33c..5626b9205 100644 --- a/src/taylor_01.cpp +++ b/src/taylor_01.cpp @@ -8,12 +8,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -46,6 +48,7 @@ #include #include +#include #include #include #include @@ -54,7 +57,11 @@ #include #include #include +#include #include +#include +#include +#include #include #include #include @@ -769,6 +776,74 @@ void taylor_decompose_replace_numbers(taylor_dc_t &dc, std::vector:: } } +// NOLINTNEXTLINE(misc-no-recursion) +expression pow_to_explog(funcptr_map &func_map, const expression &ex) +{ + return std::visit( + // NOLINTNEXTLINE(misc-no-recursion) + [&](const T &v) { + if constexpr (std::same_as) { + const auto *f_id = v.get_ptr(); + + // Check if we already performed the transformation on ex. + if (const auto it = func_map.find(f_id); it != func_map.end()) { + return it->second; + } + + // Perform the transformation on the function arguments. + std::vector new_args; + new_args.reserve(v.args().size()); + for (const auto &orig_arg : v.args()) { + new_args.push_back(pow_to_explog(func_map, orig_arg)); + } + + // Prepare the return value. + std::optional retval; + + if (v.template extract() != nullptr + && !std::holds_alternative(new_args[1].value())) { + // The function is a pow() and the exponent is not a number: transform x**y -> exp(y*log(x)). + // + // NOTE: do not call directly log(new_args[0]) in order to avoid constant folding when the base + // is a number. For instance, if we have pow(2_dbl, par[0]), then we would end up computing + // log(2) in double precision. This would result in an inaccurate result if the fp type + // or precision in use during integration is higher than double. + // NOTE: because the exponent is not a number, no other constant folding should take + // place here. + retval.emplace(exp(new_args[1] * expression{func{detail::log_impl(new_args[0])}})); + } else { + // Create a copy of v with the new arguments. + retval.emplace(v.copy(std::move(new_args))); + } + + // Put the return value into the cache. + [[maybe_unused]] const auto [_, flag] = func_map.emplace(f_id, *retval); + // NOTE: an expression cannot contain itself. + assert(flag); // LCOV_EXCL_LINE + + return std::move(*retval); + } else { + return ex; + } + }, + ex.value()); +} + +// Helper to transform x**y -> exp(y*log(x)), if y is not a number. +std::vector pow_to_explog(const std::vector &v_ex) +{ + funcptr_map func_map; + + std::vector retval; + retval.reserve(v_ex.size()); + + for (const auto &e : v_ex) { + retval.push_back(pow_to_explog(func_map, e)); + } + + return retval; +} + } // namespace } // namespace detail @@ -798,6 +873,9 @@ taylor_decompose_sys(const std::vector> &sys_, std::ranges::transform(sys_, std::back_inserter(all_ex), &std::pair::second); all_ex.insert(all_ex.end(), sv_funcs_.begin(), sv_funcs_.end()); + // Transform x**y -> exp(y*log(x)), if y is not a number. + all_ex = detail::pow_to_explog(all_ex); + // Transform sums into subs. all_ex = detail::sum_to_sub(all_ex); diff --git a/test/taylor_pow.cpp b/test/taylor_pow.cpp index be2e6f00b..17e1632ff 100644 --- a/test/taylor_pow.cpp +++ b/test/taylor_pow.cpp @@ -26,6 +26,7 @@ #endif #include +#include #include #include #include @@ -97,7 +98,9 @@ TEST_CASE("taylor pow approx") kw::opt_level = 0, kw::compact_mode = true}; - REQUIRE(ir_contains(ta, "taylor_c_diff.pow.")); + REQUIRE(!ir_contains(ta, "taylor_c_diff.pow.")); + REQUIRE(ir_contains(ta, "taylor_c_diff.exp.")); + REQUIRE(ir_contains(ta, "taylor_c_diff.log.")); } { @@ -167,9 +170,7 @@ TEST_CASE("taylor pow") kw::opt_level = opt_level, kw::pars = {fp_t{1} / fp_t{3}}}; - if (opt_level == 0u && compact_mode) { - REQUIRE(ir_contains(ta, "@heyoka.taylor_c_diff.pow.num_par")); - } + REQUIRE(!ir_contains(ta, "@heyoka.taylor_c_diff.pow.num_par")); ta.step(true); @@ -705,39 +706,6 @@ TEST_CASE("taylor pow") compare_batch_scalar({prime(x) = pow(y, expression{number{fp_t{3}}} / expression{number{fp_t{2}}}), prime(y) = pow(x, expression{number{fp_t{-1}}} / expression{number{fp_t{3}}})}, opt_level, high_accuracy, compact_mode, rng, .1f, 20.f); - - // Failure modes for non-implemented cases. - { - REQUIRE_THROWS_MATCHES((taylor_adaptive_batch{{prime(x) = pow(1_dbl, x)}, - {fp_t{2}, fp_t{2}, fp_t{3}}, - 3, - kw::tol = .1, - kw::high_accuracy = high_accuracy, - kw::compact_mode = compact_mode, - kw::opt_level = opt_level}), - std::invalid_argument, - Message(compact_mode - ? "An invalid argument type was encountered while trying to build the " - "Taylor derivative of a pow() in compact mode" - : "An invalid argument type was encountered while trying to build the " - "Taylor derivative of a pow()")); - } - - { - REQUIRE_THROWS_MATCHES((taylor_adaptive_batch{{prime(y) = pow(y, x), prime(x) = x + y}, - {fp_t{2}, fp_t{2}, fp_t{3}, fp_t{2}, fp_t{2}, fp_t{3}}, - 3, - kw::tol = .1, - kw::high_accuracy = high_accuracy, - kw::compact_mode = compact_mode, - kw::opt_level = opt_level}), - std::invalid_argument, - Message(compact_mode - ? "An invalid argument type was encountered while trying to build the " - "Taylor derivative of a pow() in compact mode" - : "An invalid argument type was encountered while trying to build the " - "Taylor derivative of a pow()")); - } }; for (auto cm : {false, true}) { @@ -749,3 +717,29 @@ TEST_CASE("taylor pow") } } } + +// Small test for the preprocessing that turns pow into exp+log. +TEST_CASE("pow_to_explog") +{ + auto [x, y] = make_vars("x", "y"); + + auto tmp1 = x + pow(y, par[0]); + auto tmp2 = pow(x, tmp1); + auto tmp3 = pow(tmp1, y); + + auto ta = taylor_adaptive{{prime(x) = (tmp1 * tmp2) / tmp3, prime(y) = tmp1}, {}, kw::tol = 1e-1}; + + REQUIRE(ta.get_decomposition().size() == 16u); + + // Count the number of exps and logs. + auto n_exp = 0, n_log = 0; + for (const auto &[ex, _] : ta.get_decomposition()) { + if (const auto *fptr = std::get_if(&ex.value())) { + n_exp += (fptr->extract() != nullptr); + n_log += (fptr->extract() != nullptr); + } + } + + REQUIRE(n_exp == 3); + REQUIRE(n_log == 3); +}