Skip to content

Commit

Permalink
Add alternative sigmoid approximation (#3)
Browse files Browse the repository at this point in the history
* Add sigmoid_exp approximation

* Undo comments

* Tweaking error bounds
  • Loading branch information
jatinchowdhury18 authored Jan 19, 2024
1 parent 07408e3 commit 0c68d4d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 11 deletions.
21 changes: 14 additions & 7 deletions include/math_approx/src/sigmoid_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,18 @@ T sigmoid (T x)
return (S) 0.5 * x_poly * rsqrt (x_poly * x_poly + (S) 1) + (S) 0.5;
}

// So far this has tested slower than the above approx (for equivalent error),
// but maybe it will be useful for someone!
// template <int order, typename T>
// T sigmoid_exp (T x)
// {
// return (T) 1 / ((T) 1 + math_approx::exp<order> (-x));
// }

/**
* Approximation of sigmoid(x) := 1 / (1 + e^-x),
* using math_approx::exp (x).
*
* So far this has tested slower than the above approximation
* for similar absolute error, but has better relative error
* characteristics.
*/
template <int order, bool C1_continuous = false, typename T>
T sigmoid_exp (T x)
{
return (T) 1 / ((T) 1 + math_approx::exp<order, C1_continuous> (-x));
}
} // namespace math_approx
68 changes: 67 additions & 1 deletion test/src/sigmoid_approx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TEST_CASE ("Sigmoid Approx Test")
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all<float> (all_floats, [] (auto x)
{ return 1.0f / (1.0f + std::exp (-x)); });
{ return 1.0f / (1.0f + std::exp (-x)); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound)
{
Expand Down Expand Up @@ -50,3 +50,69 @@ TEST_CASE ("Sigmoid Approx Test")
2.0e-3f);
}
}

TEST_CASE ("Sigmoid (Exp) Approx Test")
{
#if ! defined(WIN32)
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f);
#else
const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f);
#endif
const auto y_exact = test_helpers::compute_all<float> (all_floats, [] (auto x)
{ return 1.0f / (1.0f + std::exp (-x)); });

const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound)
{
const auto y_approx = test_helpers::compute_all<float> (all_floats, f_approx);

const auto error = test_helpers::compute_error<float> (y_exact, y_approx);
const auto rel_error = test_helpers::compute_rel_error<float> (y_exact, y_approx);
const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx);

const auto max_error = test_helpers::abs_max<float> (error);
const auto max_rel_error = test_helpers::abs_max<float> (rel_error);
const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end());

std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl;
REQUIRE (std::abs (max_error) < err_bound);
REQUIRE (std::abs (max_rel_error) < rel_err_bound);
if (ulp_bound > 0)
REQUIRE (max_ulp_error < ulp_bound);
};

SECTION ("6th-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<6> (x); },
1.5e-7f,
6.5e-7f,
12);
}

SECTION ("5th-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<5> (x); },
1.5e-7f,
7.5e-7f,
12);
}

SECTION ("4th-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<4> (x); },
9.5e-7f,
4.5e-6f,
65);
}

SECTION ("3rd-Order (Exp)")
{
test_approx ([] (auto x)
{ return math_approx::sigmoid_exp<3> (x); },
3.0e-4f,
1.5e-4f,
0);
}
}
6 changes: 6 additions & 0 deletions tools/bench/sigmoid_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ SIGMOID_BENCH (sigmoid_std, [] (auto x) { return 1.0f / (1.0f + std::exp (-x));
SIGMOID_BENCH (sigmoid_approx9, math_approx::sigmoid<9>)
SIGMOID_BENCH (sigmoid_approx7, math_approx::sigmoid<7>)
SIGMOID_BENCH (sigmoid_approx5, math_approx::sigmoid<5>)
SIGMOID_BENCH (sigmoid_exp_approx6, math_approx::sigmoid_exp<6>)
SIGMOID_BENCH (sigmoid_exp_approx5, math_approx::sigmoid_exp<5>)
SIGMOID_BENCH (sigmoid_exp_approx4, math_approx::sigmoid_exp<4>)

#define SIGMOID_SIMD_BENCH(name, func) \
void name (benchmark::State& state) \
Expand All @@ -47,5 +50,8 @@ SIGMOID_SIMD_BENCH (sigmoid_xsimd, [] (auto x) { return 1.0f / (1.0f + xsimd::ex
SIGMOID_SIMD_BENCH (sigmoid_simd_approx9, math_approx::tanh<9>)
SIGMOID_SIMD_BENCH (sigmoid_simd_approx7, math_approx::tanh<7>)
SIGMOID_SIMD_BENCH (sigmoid_simd_approx5, math_approx::tanh<5>)
SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx6, math_approx::sigmoid_exp<6>)
SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx5, math_approx::sigmoid_exp<5>)
SIGMOID_SIMD_BENCH (sigmoid_exp_simd_approx4, math_approx::sigmoid_exp<4>)

BENCHMARK_MAIN();
13 changes: 10 additions & 3 deletions tools/plotter/plotter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,24 @@ void plot_function (std::span<const float> all_floats,
plt::named_plot<float, float> (name, all_floats, y_approx);
}

template <typename T>
T sigmoid_ref (T x)
{
return (T) 1 / ((T) 1 + std::exp (-x));
}

#define FLOAT_FUNC(func) [] (float x) { return func (x); }

int main()
{
plt::figure();
const auto range = std::make_pair (1.0f, 10.0f);
const auto range = std::make_pair (-10.0f, 10.0f);
static constexpr auto tol = 1.0e-2f;

const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (std::acosh));
plot_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::acosh<5>) ), "acosh-5");
const auto y_exact = test_helpers::compute_all<float> (all_floats, FLOAT_FUNC (sigmoid_ref));
plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::sigmoid_exp<5, true>) ), "sigmoid_exp-5_c1");
plot_ulp_error (all_floats, y_exact, FLOAT_FUNC ((math_approx::sigmoid_exp<6, true>) ), "sigmoid_exp-6_c1");

plt::legend ({ { "loc", "upper right" } });
plt::xlim (range.first, range.second);
Expand Down

0 comments on commit 0c68d4d

Please sign in to comment.