From 23262957087c1130f48b4e2804f9f5f0c6c64830 Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Tue, 23 Apr 2024 23:54:15 -0700 Subject: [PATCH 1/7] Initial checking of halide_extended_exp support. --- src/IROperator.cpp | 30 +++++++ src/IROperator.h | 28 +++++++ test/correctness/extended_exp.cpp | 127 ++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 test/correctness/extended_exp.cpp diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 0f318f777561..eaeda11ed6a8 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -852,6 +852,36 @@ Expr halide_exp(const Expr &x_full) { return result; } +Tuple halide_extended_exp(const Expr &x_full) { + Type type = x_full.type(); + internal_assert(type.element_of() == Float(32)); + + float ln2_part1 = 0.6931457519f; + float ln2_part2 = 1.4286067653e-6f; + float one_over_ln2 = 1.0f / logf(2.0f); + + Expr scaled = x_full * one_over_ln2; + Expr k_real = floor(scaled); + + Expr x = x_full - k_real * ln2_part1; + x -= k_real * ln2_part2; + + float coeff[] = { + 0.00031965933071842413f, + 0.00119156835564003744f, + 0.00848988645943932717f, + 0.04160188091348320655f, + 0.16667983794100929562f, + 0.49999899033463041098f, + 1.0f, + 1.0f}; + Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0])); + + result = common_subexpression_elimination(result); + + return {result, k_real}; +} + Expr halide_erf(const Expr &x_full) { user_assert(x_full.type() == Float(32)) << "halide_erf only works for Float(32)"; diff --git a/src/IROperator.h b/src/IROperator.h index a96ef6223c0d..1ea81f84937e 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -189,6 +189,34 @@ Expr halide_exp(const Expr &a); Expr halide_erf(const Expr &a); // @} +/** Extended exponential which produces two output values, + * each of the same precision as the input, as described in + * "The Two-Pass Softmax Algorithm" by Marat Dukhan and + * Artsiom Ablavatski [https://arxiv.org/abs/2001.04438]. + * + * The first element of the returned Tuple is a psuedo-mantissa while + * the second is an exponent which is an integer. The product of the + * pseudo-mantissa and 2 raised to the returned exponent is the + * desired result e^a. For arguments up to slightly greater than + * 11629079, the pseudo-mantissa is guaranteed to be within the + * interval (-e, e). For larger arguments, the exponent result of the + * tuple may not be able to represent the exact integer necessary to + * keep the pseudo-mantissa within bounds. Thus it can become + * progressively larger in magnitude as the argument increases. + * + * Ideally this routine will maintain a degree of accuracy through the + * entire range and be able to produce results out to the end of the + * numeric range. At present neither of these properties are true due to + * the following issues: + * - Range reduction may overflow when scaling the argument. + * - Range reduction is increasingly inaccurate in reducing the value + * due to the implementation. This results in overflow in the polynomial + * evaluation. + * - Even if the above to issues were resolved, the approximation polynomial + * would have to run on values outside its intended approximation range. + */ +Tuple halide_extended_exp(const Expr &a); + /** Raise an expression to an integer power by repeatedly multiplying * it by itself. */ Expr raise_to_integer_power(Expr a, int64_t b); diff --git a/test/correctness/extended_exp.cpp b/test/correctness/extended_exp.cpp new file mode 100644 index 000000000000..18ec5f773494 --- /dev/null +++ b/test/correctness/extended_exp.cpp @@ -0,0 +1,127 @@ +#include "Halide.h" +#include +#include +#include +#include + +using namespace Halide; +using Halide::Internal::halide_exp; +using Halide::Internal::halide_extended_exp; + +// Compare naive two pass softmax, which will overflow easily, to two +// pass algorithm from "The Two-Pass Softmax Algorithm" by Marat +// Dukhan and Artsiom Ablavatski [https://arxiv.org/abs/2001.04438], +// which is implemented using halide_extended_exp. +void two_pass_softmax_test(float scale) { + Var x("x"); + RDom r(0, 1024); + + Func input("input"); + input(x) = 0.0f; + input(r) = random_float(); + + Func in_exp("in_exp"); + in_exp(x) = halide_exp(input(x)); + Func exp_sum("exp_sum"); + exp_sum() = sum(in_exp(r)); + + Func naive_softmax("naive_softmax"); + naive_softmax(x) = in_exp(x) / exp_sum(); + + Func in_extended_exp("in_extended_exp"); + in_extended_exp(x) = halide_extended_exp(input(x)); + Expr mantissa = in_extended_exp(x)[0]; + Expr exponent = in_extended_exp(x)[1]; + + Func extended_exp_sum("extended_exp_sum"); + extended_exp_sum() = Tuple(0.0f, std::numeric_limits::lowest()); // mantissa, exponent + Expr max_exp = max(extended_exp_sum()[1], in_extended_exp(r)[1]); + Expr mantissa_sum = in_extended_exp(r)[0] * pow(2, in_extended_exp(r)[1] - max_exp) + + extended_exp_sum()[0] * pow(2, extended_exp_sum()[1] - max_exp); + extended_exp_sum() = Tuple(mantissa_sum, max_exp); + + Expr lambda = 1 / extended_exp_sum()[0]; + Func two_pass_softmax("two_pass_softmax"); + two_pass_softmax(x) = in_extended_exp(x)[0] * lambda * pow(2, in_extended_exp(x)[1] - extended_exp_sum()[1]); + + Func relative_error("relative_error"); + relative_error(x) = abs(naive_softmax(x) - two_pass_softmax(x)) / naive_softmax(x); + Func max_relative_error("max_relative_error"); + max_relative_error() = maximum(relative_error(r)); +#if 1 + Func max_prob("max_prob"); + max_prob() = maximum(two_pass_softmax(r)); + Func min_prob("min_prob"); + min_prob() = minimum(two_pass_softmax(r)); + Func sum_prob("sum_prob"); + sum_prob() = sum(two_pass_softmax(r)); +#else + Func max_prob("max_prob"); + max_prob() = maximum(naive_softmax(r)); + Func min_prob("min_prob"); + min_prob() = minimum(naive_softmax(r)); + Func sum_prob("sum_prob"); + sum_prob() = sum(naive_softmax(r)); +#endif + + Func result("result"); + result() = Tuple(max_relative_error(), max_prob(), min_prob(), sum_prob()); + exp_sum.compute_root(); + extended_exp_sum.compute_root(); + naive_softmax.compute_root(); + two_pass_softmax.compute_root(); + + auto output = result.realize(); + + float max_relative_error_result = ((Buffer &)output[0])(); + float max_probability = ((Buffer &)output[1])(); + float min_probability = ((Buffer &)output[2])(); + float sum_probability = ((Buffer &)output[3])(); + + std::cout << "Two pass softmax with scale " << scale + << "\nMax relative error: " << max_relative_error_result + << "\nmax probability: " << max_probability + << "\nmin probability: " << min_probability + << "\nsum of probabilities: " << sum_probability << "\n"; + + if (max_relative_error_result > .0001f) { + std::cout << "Failed: Softmax results do not match.\n"; + exit(1); + } +} + +void expect(float x, float mantissa, float exponent) { + float computed_mantissa; + float computed_exponent; + evaluate(halide_extended_exp(x), &computed_mantissa, &computed_exponent); + if (fabs(computed_mantissa) > exp(1.0f)) { + std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa << " exponent " << computed_exponent << "\n"; + } + if (fabs(mantissa - computed_mantissa) > .00001 || + fabs(exponent - computed_exponent) > .00001) { + std::cout << "Falied: halide_extended_exp(" << x << ") == {" + << computed_mantissa << ", " << computed_exponent + << "} expected {" + << mantissa << ", " << exponent << "}\n"; + exit(1); + } +} + +int main(int argc, char **argv) { + std::cout << std::hexfloat; + expect(0, 1, 0); + expect(1, exp(1.0f) / 2, 1); + expect(88, 1.94149, 126); + expect(0x1.62e43p+23f, 0x1.085012p+0, 0x1p+24); + // Implementation does not support these yet. +#if 0 + expect(std::numeric_limits::lowest(), 0, 0); + expect(std::numeric_limits::max(), 0, 0); +#endif + two_pass_softmax_test(1.0f); + two_pass_softmax_test(10000.0f); + two_pass_softmax_test(-10000.0f); + two_pass_softmax_test(std::numeric_limits::max()); + two_pass_softmax_test(std::numeric_limits::lowest()); + std::cout << "Success\n"; +} From de4a6fa178f1ad5b6dfda182fb959875497fe6af Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Wed, 24 Apr 2024 10:25:32 -0700 Subject: [PATCH 2/7] Fix formatting. --- test/correctness/extended_exp.cpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/correctness/extended_exp.cpp b/test/correctness/extended_exp.cpp index 18ec5f773494..1ccd8dd5e444 100644 --- a/test/correctness/extended_exp.cpp +++ b/test/correctness/extended_exp.cpp @@ -79,10 +79,10 @@ void two_pass_softmax_test(float scale) { float sum_probability = ((Buffer &)output[3])(); std::cout << "Two pass softmax with scale " << scale - << "\nMax relative error: " << max_relative_error_result - << "\nmax probability: " << max_probability - << "\nmin probability: " << min_probability - << "\nsum of probabilities: " << sum_probability << "\n"; + << "\nMax relative error: " << max_relative_error_result + << "\nmax probability: " << max_probability + << "\nmin probability: " << min_probability + << "\nsum of probabilities: " << sum_probability << "\n"; if (max_relative_error_result > .0001f) { std::cout << "Failed: Softmax results do not match.\n"; @@ -95,15 +95,16 @@ void expect(float x, float mantissa, float exponent) { float computed_exponent; evaluate(halide_extended_exp(x), &computed_mantissa, &computed_exponent); if (fabs(computed_mantissa) > exp(1.0f)) { - std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa << " exponent " << computed_exponent << "\n"; + std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa + << " exponent " << computed_exponent << "\n"; } if (fabs(mantissa - computed_mantissa) > .00001 || - fabs(exponent - computed_exponent) > .00001) { + fabs(exponent - computed_exponent) > .00001) { std::cout << "Falied: halide_extended_exp(" << x << ") == {" - << computed_mantissa << ", " << computed_exponent - << "} expected {" - << mantissa << ", " << exponent << "}\n"; - exit(1); + << computed_mantissa << ", " << computed_exponent + << "} expected {" + << mantissa << ", " << exponent << "}\n"; + exit(1); } } From bcc724027b051afd94c20e63882af37011d7560b Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Wed, 24 Apr 2024 10:33:16 -0700 Subject: [PATCH 3/7] Fix formatting. --- test/correctness/extended_exp.cpp | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/test/correctness/extended_exp.cpp b/test/correctness/extended_exp.cpp index 1ccd8dd5e444..1f55a3ff5c39 100644 --- a/test/correctness/extended_exp.cpp +++ b/test/correctness/extended_exp.cpp @@ -15,7 +15,7 @@ using Halide::Internal::halide_extended_exp; void two_pass_softmax_test(float scale) { Var x("x"); RDom r(0, 1024); - + Func input("input"); input(x) = 0.0f; input(r) = random_float(); @@ -43,26 +43,17 @@ void two_pass_softmax_test(float scale) { Expr lambda = 1 / extended_exp_sum()[0]; Func two_pass_softmax("two_pass_softmax"); two_pass_softmax(x) = in_extended_exp(x)[0] * lambda * pow(2, in_extended_exp(x)[1] - extended_exp_sum()[1]); - + Func relative_error("relative_error"); relative_error(x) = abs(naive_softmax(x) - two_pass_softmax(x)) / naive_softmax(x); Func max_relative_error("max_relative_error"); max_relative_error() = maximum(relative_error(r)); -#if 1 Func max_prob("max_prob"); max_prob() = maximum(two_pass_softmax(r)); Func min_prob("min_prob"); min_prob() = minimum(two_pass_softmax(r)); Func sum_prob("sum_prob"); sum_prob() = sum(two_pass_softmax(r)); -#else - Func max_prob("max_prob"); - max_prob() = maximum(naive_softmax(r)); - Func min_prob("min_prob"); - min_prob() = minimum(naive_softmax(r)); - Func sum_prob("sum_prob"); - sum_prob() = sum(naive_softmax(r)); -#endif Func result("result"); result() = Tuple(max_relative_error(), max_prob(), min_prob(), sum_prob()); @@ -83,7 +74,7 @@ void two_pass_softmax_test(float scale) { << "\nmax probability: " << max_probability << "\nmin probability: " << min_probability << "\nsum of probabilities: " << sum_probability << "\n"; - + if (max_relative_error_result > .0001f) { std::cout << "Failed: Softmax results do not match.\n"; exit(1); @@ -96,7 +87,7 @@ void expect(float x, float mantissa, float exponent) { evaluate(halide_extended_exp(x), &computed_mantissa, &computed_exponent); if (fabs(computed_mantissa) > exp(1.0f)) { std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa - << " exponent " << computed_exponent << "\n"; + << " exponent " << computed_exponent << "\n"; } if (fabs(mantissa - computed_mantissa) > .00001 || fabs(exponent - computed_exponent) > .00001) { From 23f6db1b2dc5bdc06e0e4cb7ab5e24284dea5167 Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Wed, 24 Apr 2024 11:39:33 -0700 Subject: [PATCH 4/7] Add extended_exp test to CMakeLists.txt. --- test/correctness/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 604ceda468f5..63c0ed6d041d 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -93,6 +93,7 @@ tests(GROUPS correctness erf.cpp exception.cpp explicit_inline_reductions.cpp + extended_exp.cpp extern_bounds_inference.cpp extern_consumer.cpp extern_error.cpp From 39e35f7c97a9bedc6c63da0a5008898bdb981a50 Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Wed, 24 Apr 2024 12:35:46 -0700 Subject: [PATCH 5/7] Appease stupid string match success check. --- src/IROperator.cpp | 4 ++-- test/correctness/extended_exp.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index eaeda11ed6a8..8a5a1cc49375 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -863,8 +863,8 @@ Tuple halide_extended_exp(const Expr &x_full) { Expr scaled = x_full * one_over_ln2; Expr k_real = floor(scaled); - Expr x = x_full - k_real * ln2_part1; - x -= k_real * ln2_part2; + Expr x = strict_float(x_full - k_real * ln2_part1); + x = strict_float(x - k_real * ln2_part2); float coeff[] = { 0.00031965933071842413f, diff --git a/test/correctness/extended_exp.cpp b/test/correctness/extended_exp.cpp index 1f55a3ff5c39..705199de4e94 100644 --- a/test/correctness/extended_exp.cpp +++ b/test/correctness/extended_exp.cpp @@ -115,5 +115,5 @@ int main(int argc, char **argv) { two_pass_softmax_test(-10000.0f); two_pass_softmax_test(std::numeric_limits::max()); two_pass_softmax_test(std::numeric_limits::lowest()); - std::cout << "Success\n"; + std::cout << "Success!\n"; } From 422009a96f697dbd57356e75ab49be74e0948dc6 Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Wed, 24 Apr 2024 12:46:50 -0700 Subject: [PATCH 6/7] Remove strict_float experiment. --- src/IROperator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 8a5a1cc49375..8f58208fa17e 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -863,8 +863,8 @@ Tuple halide_extended_exp(const Expr &x_full) { Expr scaled = x_full * one_over_ln2; Expr k_real = floor(scaled); - Expr x = strict_float(x_full - k_real * ln2_part1); - x = strict_float(x - k_real * ln2_part2); + Expr x = x_full - k_real * ln2_part1; + x = x - k_real * ln2_part2; float coeff[] = { 0.00031965933071842413f, From 3aab14e23767f3b2c7b6a2c0e64c339a5f50b325 Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Tue, 30 Apr 2024 23:28:17 -0700 Subject: [PATCH 7/7] Improve numerics slightly by returning positive and negative inifinity values for exponent part of extended exp. Improve test by comparing to three pass algorithm. --- src/IROperator.cpp | 2 ++ test/correctness/extended_exp.cpp | 47 ++++++++++++++++++++++--------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index c73dc448e488..b0a45c4e629d 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -879,6 +879,8 @@ Tuple halide_extended_exp(const Expr &x_full) { 1.0f}; Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0])); + // Ensure that the mantissa part is not a NaN or itself an infinity. + result = strict_float(select(!is_finite(k_real), 1, result)); result = common_subexpression_elimination(result); return {result, k_real}; diff --git a/test/correctness/extended_exp.cpp b/test/correctness/extended_exp.cpp index 705199de4e94..22b316eebe7e 100644 --- a/test/correctness/extended_exp.cpp +++ b/test/correctness/extended_exp.cpp @@ -18,8 +18,9 @@ void two_pass_softmax_test(float scale) { Func input("input"); input(x) = 0.0f; - input(r) = random_float(); + input(r) = random_float() * scale; + // Naive two pass algorithm. Doesn't work for large values or large size inputs. Func in_exp("in_exp"); in_exp(x) = halide_exp(input(x)); Func exp_sum("exp_sum"); @@ -28,6 +29,18 @@ void two_pass_softmax_test(float scale) { Func naive_softmax("naive_softmax"); naive_softmax(x) = in_exp(x) / exp_sum(); + // Three pass algorithm that works for all inputs. + Func max_input("max_input"); + max_input() = maximum(input(r)); + Func biased_in_exp("biased_in_exp"); + biased_in_exp(x) = halide_exp(input(x) - max_input()); + Func biased_exp_sum("biased_exp_sum"); + biased_exp_sum() = sum(biased_in_exp(r)); + + Func three_pass_softmax("three_pass_softmax"); + three_pass_softmax(x) = biased_in_exp(x) / biased_exp_sum(); + + // Two pass extended exp algorithm. Func in_extended_exp("in_extended_exp"); in_extended_exp(x) = halide_extended_exp(input(x)); Expr mantissa = in_extended_exp(x)[0]; @@ -45,7 +58,7 @@ void two_pass_softmax_test(float scale) { two_pass_softmax(x) = in_extended_exp(x)[0] * lambda * pow(2, in_extended_exp(x)[1] - extended_exp_sum()[1]); Func relative_error("relative_error"); - relative_error(x) = abs(naive_softmax(x) - two_pass_softmax(x)) / naive_softmax(x); + relative_error(x) = abs(three_pass_softmax(x) - two_pass_softmax(x)) / max(.000001f, three_pass_softmax(x)); Func max_relative_error("max_relative_error"); max_relative_error() = maximum(relative_error(r)); Func max_prob("max_prob"); @@ -58,8 +71,10 @@ void two_pass_softmax_test(float scale) { Func result("result"); result() = Tuple(max_relative_error(), max_prob(), min_prob(), sum_prob()); exp_sum.compute_root(); + biased_exp_sum.compute_root(); extended_exp_sum.compute_root(); naive_softmax.compute_root(); + three_pass_softmax.compute_root(); two_pass_softmax.compute_root(); auto output = result.realize(); @@ -69,16 +84,25 @@ void two_pass_softmax_test(float scale) { float min_probability = ((Buffer &)output[2])(); float sum_probability = ((Buffer &)output[3])(); - std::cout << "Two pass softmax with scale " << scale - << "\nMax relative error: " << max_relative_error_result - << "\nmax probability: " << max_probability - << "\nmin probability: " << min_probability - << "\nsum of probabilities: " << sum_probability << "\n"; - if (max_relative_error_result > .0001f) { std::cout << "Failed: Softmax results do not match.\n"; exit(1); } + + if (max_probability > 1.0f) { + std::cout << "Failed: Softmax probability is greater than 1.0f.\n"; + exit(1); + } + + if (min_probability < 0.0f) { + std::cout << "Failed: Softmax probability is negative.\n"; + exit(1); + } + + if (sum_probability > 1.0001f) { + std::cout << "Failed: Softmax probability sum is too large.\n"; + exit(1); + } } void expect(float x, float mantissa, float exponent) { @@ -105,11 +129,8 @@ int main(int argc, char **argv) { expect(1, exp(1.0f) / 2, 1); expect(88, 1.94149, 126); expect(0x1.62e43p+23f, 0x1.085012p+0, 0x1p+24); - // Implementation does not support these yet. -#if 0 - expect(std::numeric_limits::lowest(), 0, 0); - expect(std::numeric_limits::max(), 0, 0); -#endif + expect(std::numeric_limits::lowest(), 1.0f, -std::numeric_limits::infinity()); + expect(std::numeric_limits::max(), 1.0f, std::numeric_limits::infinity()); two_pass_softmax_test(1.0f); two_pass_softmax_test(10000.0f); two_pass_softmax_test(-10000.0f);