From a1590f8ecf7390affd1d25787b767287f5eda752 Mon Sep 17 00:00:00 2001 From: jatin Date: Wed, 22 Nov 2023 11:20:07 -0800 Subject: [PATCH] Trig approx improvements --- .github/workflows/run_tests.yml | 2 +- CMakeLists.txt | 3 + include/math_approx/math_approx.hpp | 2 +- .../src/{sin_approx.hpp => trig_approx.hpp} | 15 +++-- test/CMakeLists.txt | 3 +- test/src/cos_approx_test.cpp | 46 -------------- ...n_approx_test.cpp => trig_approx_test.cpp} | 41 ++++++++++++ tools/bench/CMakeLists.txt | 3 + tools/bench/trig_bench.cpp | 63 +++++++++++++++++++ tools/plotter/plotter.cpp | 4 +- 10 files changed, 125 insertions(+), 57 deletions(-) rename include/math_approx/src/{sin_approx.hpp => trig_approx.hpp} (91%) delete mode 100644 test/src/cos_approx_test.cpp rename test/src/{sin_approx_test.cpp => trig_approx_test.cpp} (52%) create mode 100644 tools/bench/trig_bench.cpp diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 5f39b26..6ea5321 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -42,4 +42,4 @@ jobs: - name: CMake Test run: | ctest --test-dir build -C RelWithDebInfo --show-only - ctest --test-dir build -C RelWithDebInfo -j4 --output-on-failure + ctest --test-dir build -C RelWithDebInfo -j2 --output-on-failure diff --git a/CMakeLists.txt b/CMakeLists.txt index 864a1d5..ab2ce04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,9 @@ endif() add_library(math_approx INTERFACE) target_include_directories(math_approx INTERFACE include) +if(MSVC) + target_compile_definitions(math_approx INTERFACE _USE_MATH_DEFINES=1) +endif() if (TARGET xsimd) message(STATUS "math_approx -- Linking with XSIMD...") target_link_libraries(math_approx INTERFACE xsimd) diff --git a/include/math_approx/math_approx.hpp b/include/math_approx/math_approx.hpp index 6f5a076..406aceb 100644 --- a/include/math_approx/math_approx.hpp +++ b/include/math_approx/math_approx.hpp @@ -8,4 +8,4 @@ namespace math_approx #include "src/tanh_approx.hpp" #include "src/sigmoid_approx.hpp" -#include "src/sin_approx.hpp" +#include "src/trig_approx.hpp" diff --git a/include/math_approx/src/sin_approx.hpp b/include/math_approx/src/trig_approx.hpp similarity index 91% rename from include/math_approx/src/sin_approx.hpp rename to include/math_approx/src/trig_approx.hpp index cd85a37..48809aa 100644 --- a/include/math_approx/src/sin_approx.hpp +++ b/include/math_approx/src/trig_approx.hpp @@ -99,11 +99,16 @@ T cos_mpi_pi (T x) using S = scalar_of_t; static constexpr auto pi = static_cast (M_PI); - static constexpr auto pi_o_2 = pi * (S) 0.5;; + static constexpr auto pi_sq = pi * pi; + static constexpr auto pi_o_2 = pi * (S) 0.5; + + using std::abs; +#if defined(XSIMD_HPP) + using xsimd::abs; +#endif + x = abs (x); - const auto hpmx = (x > (S) 0 ? (S) 1 : (S) -1) * pi_o_2 - x; - const auto thpmx = (x > (S) 0 ? (S) 3 : (S) -3) * pi_o_2 - x; - const auto nhpmx = (x > (S) 0 ? (S) -1 : (S) 1) * pi_o_2 - x; + const auto hpmx = pi_o_2 - x; const auto hpmx_sq = hpmx * hpmx; T x_poly {}; @@ -114,7 +119,7 @@ T cos_mpi_pi (T x) else if constexpr (order == 5) x_poly = sin_detail::sin_poly_5 (hpmx, hpmx_sq); - return thpmx * nhpmx * (x > (S) 0 ? (S) -1 : (S) 1) * x_poly; + return (pi_sq - hpmx_sq) * x_poly; } template diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d588dfa..8ec6f36 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -30,5 +30,4 @@ endfunction(setup_catch_test) setup_catch_test(tanh_approx_test) setup_catch_test(sigmoid_approx_test) -setup_catch_test(sin_approx_test) -setup_catch_test(cos_approx_test) +setup_catch_test(trig_approx_test) diff --git a/test/src/cos_approx_test.cpp b/test/src/cos_approx_test.cpp deleted file mode 100644 index c821b13..0000000 --- a/test/src/cos_approx_test.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "test_helpers.hpp" -#include -#include - -#include - -TEST_CASE ("Cosine Approx Test") -{ -#if ! defined(WIN32) - const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f); -#else - const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); -#endif - const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) - { return std::cos (x); }); - - const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound) - { - const auto y_approx = test_helpers::compute_all (all_floats, f_approx); - - const auto error = test_helpers::compute_error (y_exact, y_approx); - const auto max_error = test_helpers::abs_max (error); - - std::cout << max_error << std::endl; - REQUIRE (std::abs (max_error) < err_bound); - }; - - SECTION ("9th-Order") - { - test_approx ([] (auto x) - { return math_approx::cos<9> (x); }, - 7.0e-7f); - } - SECTION ("7th-Order") - { - test_approx ([] (auto x) - { return math_approx::cos<7> (x); }, - 1.8e-5f); - } - SECTION ("5th-Order") - { - test_approx ([] (auto x) - { return math_approx::cos<5> (x); }, - 7.5e-4f); - } -} diff --git a/test/src/sin_approx_test.cpp b/test/src/trig_approx_test.cpp similarity index 52% rename from test/src/sin_approx_test.cpp rename to test/src/trig_approx_test.cpp index 11cdaca..f4a4799 100644 --- a/test/src/sin_approx_test.cpp +++ b/test/src/trig_approx_test.cpp @@ -44,3 +44,44 @@ TEST_CASE ("Sine Approx Test") 7.5e-4f); } } + +TEST_CASE ("Cosine Approx Test") +{ +#if ! defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return std::cos (x); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto max_error = test_helpers::abs_max (error); + + // std::cout << max_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + }; + + SECTION ("9th-Order") + { + test_approx ([] (auto x) + { return math_approx::cos<9> (x); }, + 7.5e-7f); + } + SECTION ("7th-Order") + { + test_approx ([] (auto x) + { return math_approx::cos<7> (x); }, + 1.8e-5f); + } + SECTION ("5th-Order") + { + test_approx ([] (auto x) + { return math_approx::cos<5> (x); }, + 7.5e-4f); + } +} diff --git a/tools/bench/CMakeLists.txt b/tools/bench/CMakeLists.txt index d8b4916..42c8769 100644 --- a/tools/bench/CMakeLists.txt +++ b/tools/bench/CMakeLists.txt @@ -10,3 +10,6 @@ target_link_libraries(tanh_approx_bench PRIVATE benchmark::benchmark math_approx add_executable(sigmoid_approx_bench sigmoid_bench.cpp) target_link_libraries(sigmoid_approx_bench PRIVATE benchmark::benchmark math_approx) + +add_executable(trig_approx_bench trig_bench.cpp) +target_link_libraries(trig_approx_bench PRIVATE benchmark::benchmark math_approx) diff --git a/tools/bench/trig_bench.cpp b/tools/bench/trig_bench.cpp new file mode 100644 index 0000000..796bd27 --- /dev/null +++ b/tools/bench/trig_bench.cpp @@ -0,0 +1,63 @@ +#include +#include + +static constexpr size_t N = 2000; +const auto data = [] +{ + std::vector x; + x.resize (N, 0.0f); + for (size_t i = 0; i < N; ++i) + x[i] = -10.0f + 20.0f * (float) i / (float) N; + return x; +}(); + +#define TRIG_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ +for (auto _ : state) \ +{ \ +for (auto& x : data) \ +{ \ +auto y = func (x); \ +benchmark::DoNotOptimize (y); \ +} \ +} \ +} \ +BENCHMARK (name); + +TRIG_BENCH (cos_std, std::cos) +TRIG_BENCH (cos_approx9, math_approx::cos<9>) +TRIG_BENCH (cos_approx7, math_approx::cos<7>) +TRIG_BENCH (cos_approx5, math_approx::cos<5>) + +TRIG_BENCH (sin_std, std::sin) +TRIG_BENCH (sin_approx9, math_approx::sin<9>) +TRIG_BENCH (sin_approx7, math_approx::sin<7>) +TRIG_BENCH (sin_approx5, math_approx::sin<5>) + +#define TRIG_SIMD_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ +for (auto _ : state) \ +{ \ +for (auto& x : data) \ +{ \ +auto y = func (xsimd::broadcast (x)); \ +static_assert (std::is_same_v, decltype(y)>); \ +benchmark::DoNotOptimize (y); \ +} \ +} \ +} \ +BENCHMARK (name); + +TRIG_SIMD_BENCH (sin_xsimd, xsimd::sin) +TRIG_SIMD_BENCH (sin_simd_approx9, math_approx::sin<9>) +TRIG_SIMD_BENCH (sin_simd_approx7, math_approx::sin<7>) +TRIG_SIMD_BENCH (sin_simd_approx5, math_approx::sin<5>) + +TRIG_SIMD_BENCH (cos_xsimd, xsimd::cos) +TRIG_SIMD_BENCH (cos_simd_approx9, math_approx::cos<9>) +TRIG_SIMD_BENCH (cos_simd_approx7, math_approx::cos<7>) +TRIG_SIMD_BENCH (cos_simd_approx5, math_approx::cos<5>) + +BENCHMARK_MAIN(); diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp index 8ec6b9d..c41ca97 100644 --- a/tools/plotter/plotter.cpp +++ b/tools/plotter/plotter.cpp @@ -66,8 +66,8 @@ int main() // // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<5> (x); }, "Sin-5"); // // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<7> (x); }, "Sin-7"); - plot_ulp_error (all_floats, y_exact, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9"); - // plot_function (all_floats, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9"); + // plot_ulp_error (all_floats, y_exact, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9"); + plot_function (all_floats, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9"); plt::legend ({ { "loc", "upper right" } }); plt::xlim (range.first, range.second);