diff --git a/src/IROperator.cpp b/src/IROperator.cpp
index 3492c9e828c3..b0a45c4e629d 100644
--- a/src/IROperator.cpp
+++ b/src/IROperator.cpp
@@ -854,6 +854,38 @@ Expr halide_exp(const Expr &x_full) {
     return result;
 }
 
+Tuple halide_extended_exp(const Expr &x_full) {
+    Type type = x_full.type();
+    internal_assert(type.element_of() == Float(32));
+
+    float ln2_part1 = 0.6931457519f;  // ln2_part1 + ln2_part2 ~= ln(2)
+    float ln2_part2 = 1.4286067653e-6f;
+    float one_over_ln2 = 1.0f / logf(2.0f);
+
+    Expr scaled = x_full * one_over_ln2;
+    Expr k_real = floor(scaled);  // exponent: floor(x_full / ln(2)), kept as a float
+
+    Expr x = x_full - k_real * ln2_part1;  // reduced argument x_full - k_real * ln(2), subtracted in two steps
+    x = x - k_real * ln2_part2;
+
+    float coeff[] = {
+        0.00031965933071842413f,
+        0.00119156835564003744f,
+        0.00848988645943932717f,
+        0.04160188091348320655f,
+        0.16667983794100929562f,
+        0.49999899033463041098f,
+        1.0f,
+        1.0f};
+    Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0]));
+
+    // If k_real is not finite, force the pseudo-mantissa to 1 so it is neither a NaN nor an infinity.
+    result = strict_float(select(!is_finite(k_real), 1, result));
+    result = common_subexpression_elimination(result);
+
+    return {result, k_real};  // {pseudo-mantissa, exponent}
+}
+
 Expr halide_erf(const Expr &x_full) {
     user_assert(x_full.type() == Float(32)) << "halide_erf only works for Float(32)";
 
diff --git a/src/IROperator.h b/src/IROperator.h
index a96ef6223c0d..1ea81f84937e 100644
--- a/src/IROperator.h
+++ b/src/IROperator.h
@@ -189,6 +189,34 @@ Expr halide_exp(const Expr &a);
 Expr halide_erf(const Expr &a);
 // @}
 
+/** Extended exponential which produces two output values,
+ * each of the same precision as the input, as described in
+ * "The Two-Pass Softmax Algorithm" by Marat Dukhan and
+ * Artsiom Ablavatski [https://arxiv.org/abs/2001.04438].
+ *
+ * The first element of the returned Tuple is a pseudo-mantissa while
+ * the second is an exponent which is an integer. The product of the
+ * pseudo-mantissa and 2 raised to the returned exponent is the
+ * desired result e^a.  For arguments up to slightly greater than
+ * 11629079, the pseudo-mantissa is guaranteed to be within the
+ * interval (-e, e). For larger arguments, the exponent result of the
+ * tuple may not be able to represent the exact integer necessary to
+ * keep the pseudo-mantissa within bounds. Thus it can become
+ * progressively larger in magnitude as the argument increases.
+ *
+ * Ideally this routine will maintain a degree of accuracy through the
+ * entire range and be able to produce results out to the end of the
+ * numeric range. At present neither of these properties holds due to
+ * the following issues:
+ *  - Range reduction may overflow when scaling the argument.
+ *  - Range reduction is increasingly inaccurate in reducing the value
+ *    due to the implementation. This results in overflow in the polynomial
+ *    evaluation.
+ *  - Even if the above two issues were resolved, the approximation polynomial
+ *    would have to run on values outside its intended approximation range.
+ */
+Tuple halide_extended_exp(const Expr &a);
+
 /** Raise an expression to an integer power by repeatedly multiplying
  * it by itself. */
 Expr raise_to_integer_power(Expr a, int64_t b);
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index ae4a6776ac72..da2aae1a03be 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -94,6 +94,7 @@ tests(GROUPS correctness
         erf.cpp
         exception.cpp
         explicit_inline_reductions.cpp
+        extended_exp.cpp
         extern_bounds_inference.cpp
         extern_consumer.cpp
         extern_error.cpp
diff --git a/test/correctness/extended_exp.cpp b/test/correctness/extended_exp.cpp
new file mode 100644
index 000000000000..22b316eebe7e
--- /dev/null
+++ b/test/correctness/extended_exp.cpp
@@ -0,0 +1,140 @@
+#include "Halide.h"
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+using namespace Halide;
+using Halide::Internal::halide_exp;
+using Halide::Internal::halide_extended_exp;
+
+// Compare naive two pass softmax, which will overflow easily, to two
+// pass algorithm from "The Two-Pass Softmax Algorithm" by Marat
+// Dukhan and Artsiom Ablavatski [https://arxiv.org/abs/2001.04438],
+// which is implemented using halide_extended_exp.
+void two_pass_softmax_test(float scale) {
+    Var x("x");
+    RDom r(0, 1024);
+
+    Func input("input");
+    input(x) = 0.0f;
+    input(r) = random_float() * scale;
+
+    // Naive two pass algorithm. Doesn't work for large values or large size inputs.
+    Func in_exp("in_exp");
+    in_exp(x) = halide_exp(input(x));
+    Func exp_sum("exp_sum");
+    exp_sum() = sum(in_exp(r));
+
+    Func naive_softmax("naive_softmax");
+    naive_softmax(x) = in_exp(x) / exp_sum();
+
+    // Three pass algorithm that works for all inputs.
+    Func max_input("max_input");
+    max_input() = maximum(input(r));
+    Func biased_in_exp("biased_in_exp");
+    biased_in_exp(x) = halide_exp(input(x) - max_input());
+    Func biased_exp_sum("biased_exp_sum");
+    biased_exp_sum() = sum(biased_in_exp(r));
+
+    Func three_pass_softmax("three_pass_softmax");
+    three_pass_softmax(x) = biased_in_exp(x) / biased_exp_sum();
+
+    // Two pass extended exp algorithm.
+    Func in_extended_exp("in_extended_exp");
+    in_extended_exp(x) = halide_extended_exp(input(x));
+    Expr mantissa = in_extended_exp(x)[0];
+    Expr exponent = in_extended_exp(x)[1];
+
+    Func extended_exp_sum("extended_exp_sum");
+    extended_exp_sum() = Tuple(0.0f, std::numeric_limits<float>::lowest());  // mantissa, exponent
+    Expr max_exp = max(extended_exp_sum()[1], in_extended_exp(r)[1]);
+    Expr mantissa_sum = in_extended_exp(r)[0] * pow(2, in_extended_exp(r)[1] - max_exp) +
+                        extended_exp_sum()[0] * pow(2, extended_exp_sum()[1] - max_exp);
+    extended_exp_sum() = Tuple(mantissa_sum, max_exp);
+
+    Expr lambda = 1 / extended_exp_sum()[0];
+    Func two_pass_softmax("two_pass_softmax");
+    two_pass_softmax(x) = in_extended_exp(x)[0] * lambda * pow(2, in_extended_exp(x)[1] - extended_exp_sum()[1]);
+
+    Func relative_error("relative_error");
+    relative_error(x) = abs(three_pass_softmax(x) - two_pass_softmax(x)) / max(.000001f, three_pass_softmax(x));
+    Func max_relative_error("max_relative_error");
+    max_relative_error() = maximum(relative_error(r));
+    Func max_prob("max_prob");
+    max_prob() = maximum(two_pass_softmax(r));
+    Func min_prob("min_prob");
+    min_prob() = minimum(two_pass_softmax(r));
+    Func sum_prob("sum_prob");
+    sum_prob() = sum(two_pass_softmax(r));
+
+    Func result("result");
+    result() = Tuple(max_relative_error(), max_prob(), min_prob(), sum_prob());
+    exp_sum.compute_root();
+    biased_exp_sum.compute_root();
+    extended_exp_sum.compute_root();
+    naive_softmax.compute_root();
+    three_pass_softmax.compute_root();
+    two_pass_softmax.compute_root();
+
+    auto output = result.realize();
+
+    float max_relative_error_result = ((Buffer<float> &)output[0])();
+    float max_probability = ((Buffer<float> &)output[1])();
+    float min_probability = ((Buffer<float> &)output[2])();
+    float sum_probability = ((Buffer<float> &)output[3])();
+
+    if (max_relative_error_result > .0001f) {
+        std::cout << "Failed: Softmax results do not match.\n";
+        exit(1);
+    }
+
+    if (max_probability > 1.0f) {
+        std::cout << "Failed: Softmax probability is greater than 1.0f.\n";
+        exit(1);
+    }
+
+    if (min_probability < 0.0f) {
+        std::cout << "Failed: Softmax probability is negative.\n";
+        exit(1);
+    }
+
+    if (sum_probability > 1.0001f) {
+        std::cout << "Failed: Softmax probability sum is too large.\n";
+        exit(1);
+    }
+}
+
+void expect(float x, float mantissa, float exponent) {
+    float computed_mantissa;
+    float computed_exponent;
+    evaluate(halide_extended_exp(x), &computed_mantissa, &computed_exponent);
+    if (fabs(computed_mantissa) > exp(1.0f)) {
+        std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa
+                  << " exponent " << computed_exponent << "\n";
+    }
+    if (fabs(mantissa - computed_mantissa) > .00001 ||
+        fabs(exponent - computed_exponent) > .00001) {
+        std::cout << "Failed: halide_extended_exp(" << x << ") == {"
+                  << computed_mantissa << ", " << computed_exponent
+                  << "} expected {"
+                  << mantissa << ", " << exponent << "}\n";
+        exit(1);
+    }
+}
+
+int main(int argc, char **argv) {
+    std::cout << std::hexfloat;
+    expect(0, 1, 0);
+    expect(1, exp(1.0f) / 2, 1);
+    expect(88, 1.94149, 126);
+    expect(0x1.62e43p+23f, 0x1.085012p+0, 0x1p+24);
+    expect(std::numeric_limits<float>::lowest(), 1.0f, -std::numeric_limits<float>::infinity());
+    expect(std::numeric_limits<float>::max(), 1.0f, std::numeric_limits<float>::infinity());
+    two_pass_softmax_test(1.0f);
+    two_pass_softmax_test(10000.0f);
+    two_pass_softmax_test(-10000.0f);
+    two_pass_softmax_test(std::numeric_limits<float>::max());
+    two_pass_softmax_test(std::numeric_limits<float>::lowest());
+    std::cout << "Success!\n";
+}