Skip to content

Commit

Permalink
Adding sin and cosine approximations
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Nov 22, 2023
1 parent fa44c50 commit d06887a
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 15 deletions.
1 change: 1 addition & 0 deletions include/math_approx/math_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ namespace math_approx

#include "src/tanh_approx.hpp"
#include "src/sigmoid_approx.hpp"
#include "src/sin_approx.hpp"
12 changes: 12 additions & 0 deletions include/math_approx/src/basic_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ T rsqrt (T x)
// return x * r;
}

template <typename T>
T select (bool q, T t, T f)
{
return q ? t : f;
}

#if defined(XSIMD_HPP)
template <typename T>
struct scalar_of<xsimd::batch<T>>
Expand All @@ -54,5 +60,11 @@ xsimd::batch<T> rsqrt (xsimd::batch<T> x)
r *= (S) -0.5;
return x * r;
}

template <typename T>
xsimd::batch<T> select (xsimd::batch_bool<T> q, xsimd::batch<T> t, xsimd::batch<T> f)
{
return xsimd::select (q, t, f);
}
#endif
} // namespace math_approx
2 changes: 1 addition & 1 deletion include/math_approx/src/tanh_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace tanh_detail
template <int order, typename T>
T tanh (T x)
{
static_assert (order % 2 == 1 && order <= 11 && order >= 3, "Order must e an odd number within [3, 9]");
static_assert (order % 2 == 1 && order <= 11 && order >= 3, "Order must e an odd number within [3, 11]");

T x_poly {};
if constexpr (order == 11)
Expand Down
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ endfunction(setup_catch_test)

setup_catch_test(tanh_approx_test)
setup_catch_test(sigmoid_approx_test)
setup_catch_test(sin_approx_test)
setup_catch_test(cos_approx_test)
49 changes: 46 additions & 3 deletions test/src/cos_approx_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,46 @@
//
// Created by jatin on 11/22/23.
//
#include "test_helpers.hpp"
#include <catch2/catch_test_macros.hpp>
#include <iostream>

#include <math_approx/math_approx.hpp>

TEST_CASE ("Cosine Approx Test")
{
#if ! defined(WIN32)
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f);
#else
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x)
{ return std::cos (x); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound)
{
const auto y_approx = test_helpers::compute_all (all_floats, f_approx);

const auto error = test_helpers::compute_error (y_exact, y_approx);
const auto max_error = test_helpers::abs_max (error);

std::cout << max_error << std::endl;
REQUIRE (std::abs (max_error) < err_bound);
};

SECTION ("9th-Order")
{
test_approx ([] (auto x)
{ return math_approx::cos<9> (x); },
7.0e-7f);
}
SECTION ("7th-Order")
{
test_approx ([] (auto x)
{ return math_approx::cos<7> (x); },
1.8e-5f);
}
SECTION ("5th-Order")
{
test_approx ([] (auto x)
{ return math_approx::cos<5> (x); },
7.5e-4f);
}
}
49 changes: 46 additions & 3 deletions test/src/sin_approx_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,46 @@
//
// Created by jatin on 11/22/23.
//
#include "test_helpers.hpp"
#include <catch2/catch_test_macros.hpp>
#include <iostream>

#include <math_approx/math_approx.hpp>

TEST_CASE ("Sine Approx Test")
{
#if ! defined(WIN32)
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f);
#else
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x)
{ return std::sin (x); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound)
{
const auto y_approx = test_helpers::compute_all (all_floats, f_approx);

const auto error = test_helpers::compute_error (y_exact, y_approx);
const auto max_error = test_helpers::abs_max (error);

// std::cout << max_error << std::endl;
REQUIRE (std::abs (max_error) < err_bound);
};

SECTION ("9th-Order")
{
test_approx ([] (auto x)
{ return math_approx::sin<9> (x); },
8.5e-7f);
}
SECTION ("7th-Order")
{
test_approx ([] (auto x)
{ return math_approx::sin<7> (x); },
1.8e-5f);
}
SECTION ("5th-Order")
{
test_approx ([] (auto x)
{ return math_approx::sin<5> (x); },
7.5e-4f);
}
}
14 changes: 6 additions & 8 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,17 @@ void plot_function (std::span<const float> all_floats,
int main()
{
plt::figure();
const auto range = std::make_pair (-10.0f, 10.0f);
const auto range = std::make_pair (-3.141f, 3.141f);
static constexpr auto tol = 1.0e-2f;

const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
const auto y_exact = test_helpers::compute_all (all_floats, [] (float x)
{ return 1.0f / (1.0f + std::exp (-x)); });
{ return std::cos (x); });

plot_error (
all_floats,
y_exact,
[] (float x)
{ return math_approx::sigmoid_exp<3> (x); },
"Sigmoid-Exp-5");
// // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<5> (x); }, "Sin-5");
// // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<7> (x); }, "Sin-7");
plot_ulp_error (all_floats, y_exact, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9");
// plot_function (all_floats, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down

0 comments on commit d06887a

Please sign in to comment.