-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for extended exp operation as halide_extended_exp. #8206
Draft
zvookin
wants to merge
8
commits into
main
Choose a base branch
from
halide_extended_exp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
2326295
Initial checking of halide_extended_exp support.
de4a6fa
Fix formatting.
bcc7240
Fix formatting.
23f6db1
Add extended_exp test to CMakeLists.txt.
39e35f7
Appease stupid string match success check.
422009a
Remove strict_float experiment.
6e0673a
Merge branch 'main' into halide_extended_exp
3aab14e
Improve numerics slightly by returning positive and negative inifinity
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -189,6 +189,34 @@ Expr halide_exp(const Expr &a); | |
Expr halide_erf(const Expr &a); | ||
// @} | ||
|
||
/** Extended exponential which produces two output values, | ||
* each of the same precision as the input, as described in | ||
* "The Two-Pass Softmax Algorithm" by Marat Dukhan and | ||
* Artsiom Ablavatski [https://arxiv.org/abs/2001.04438]. | ||
* | ||
* The first element of the returned Tuple is a psuedo-mantissa while | ||
* the second is an exponent which is an integer. The product of the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So is the returned Tuple a pair of (float32, int32), or is it (float32, float32) where the second is always an integral value? |
||
* pseudo-mantissa and 2 raised to the returned exponent is the | ||
* desired result e^a. For arguments up to slightly greater than | ||
* 11629079, the pseudo-mantissa is guaranteed to be within the | ||
* interval (-e, e). For larger arguments, the exponent result of the | ||
* tuple may not be able to represent the exact integer necessary to | ||
* keep the pseudo-mantissa within bounds. Thus it can become | ||
* progressively larger in magnitude as the argument increases. | ||
* | ||
* Ideally this routine will maintain a degree of accuracy through the | ||
* entire range and be able to produce results out to the end of the | ||
* numeric range. At present neither of these properties are true due to | ||
* the following issues: | ||
* - Range reduction may overflow when scaling the argument. | ||
* - Range reduction is increasingly inaccurate in reducing the value | ||
* due to the implementation. This results in overflow in the polynomial | ||
* evaluation. | ||
* - Even if the above to issues were resolved, the approximation polynomial | ||
* would have to run on values outside its intended approximation range. | ||
*/ | ||
Tuple halide_extended_exp(const Expr &a); | ||
|
||
/** Raise an expression to an integer power by repeatedly multiplying | ||
* it by itself. */ | ||
Expr raise_to_integer_power(Expr a, int64_t b); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
#include "Halide.h" | ||
#include <cmath> | ||
#include <iomanip> | ||
#include <iostream> | ||
#include <limits> | ||
|
||
using namespace Halide; | ||
using Halide::Internal::halide_exp; | ||
using Halide::Internal::halide_extended_exp; | ||
|
||
// Compare naive two pass softmax, which will overflow easily, to two | ||
// pass algorithm from "The Two-Pass Softmax Algorithm" by Marat | ||
// Dukhan and Artsiom Ablavatski [https://arxiv.org/abs/2001.04438], | ||
// which is implemented using halide_extended_exp. | ||
void two_pass_softmax_test(float scale) { | ||
Var x("x"); | ||
RDom r(0, 1024); | ||
|
||
Func input("input"); | ||
input(x) = 0.0f; | ||
input(r) = random_float() * scale; | ||
|
||
// Naive two pass algorithm. Doesn't work for large values or large size inputs. | ||
Func in_exp("in_exp"); | ||
in_exp(x) = halide_exp(input(x)); | ||
Func exp_sum("exp_sum"); | ||
exp_sum() = sum(in_exp(r)); | ||
|
||
Func naive_softmax("naive_softmax"); | ||
naive_softmax(x) = in_exp(x) / exp_sum(); | ||
|
||
// Three pass algorithm that works for all inputs. | ||
Func max_input("max_input"); | ||
max_input() = maximum(input(r)); | ||
Func biased_in_exp("biased_in_exp"); | ||
biased_in_exp(x) = halide_exp(input(x) - max_input()); | ||
Func biased_exp_sum("biased_exp_sum"); | ||
biased_exp_sum() = sum(biased_in_exp(r)); | ||
|
||
Func three_pass_softmax("three_pass_softmax"); | ||
three_pass_softmax(x) = biased_in_exp(x) / biased_exp_sum(); | ||
|
||
// Two pass extended exp algorithm. | ||
Func in_extended_exp("in_extended_exp"); | ||
in_extended_exp(x) = halide_extended_exp(input(x)); | ||
Expr mantissa = in_extended_exp(x)[0]; | ||
Expr exponent = in_extended_exp(x)[1]; | ||
|
||
Func extended_exp_sum("extended_exp_sum"); | ||
extended_exp_sum() = Tuple(0.0f, std::numeric_limits<float>::lowest()); // mantissa, exponent | ||
Expr max_exp = max(extended_exp_sum()[1], in_extended_exp(r)[1]); | ||
Expr mantissa_sum = in_extended_exp(r)[0] * pow(2, in_extended_exp(r)[1] - max_exp) + | ||
extended_exp_sum()[0] * pow(2, extended_exp_sum()[1] - max_exp); | ||
extended_exp_sum() = Tuple(mantissa_sum, max_exp); | ||
|
||
Expr lambda = 1 / extended_exp_sum()[0]; | ||
Func two_pass_softmax("two_pass_softmax"); | ||
two_pass_softmax(x) = in_extended_exp(x)[0] * lambda * pow(2, in_extended_exp(x)[1] - extended_exp_sum()[1]); | ||
|
||
Func relative_error("relative_error"); | ||
relative_error(x) = abs(three_pass_softmax(x) - two_pass_softmax(x)) / max(.000001f, three_pass_softmax(x)); | ||
Func max_relative_error("max_relative_error"); | ||
max_relative_error() = maximum(relative_error(r)); | ||
Func max_prob("max_prob"); | ||
max_prob() = maximum(two_pass_softmax(r)); | ||
Func min_prob("min_prob"); | ||
min_prob() = minimum(two_pass_softmax(r)); | ||
Func sum_prob("sum_prob"); | ||
sum_prob() = sum(two_pass_softmax(r)); | ||
|
||
Func result("result"); | ||
result() = Tuple(max_relative_error(), max_prob(), min_prob(), sum_prob()); | ||
exp_sum.compute_root(); | ||
biased_exp_sum.compute_root(); | ||
extended_exp_sum.compute_root(); | ||
naive_softmax.compute_root(); | ||
three_pass_softmax.compute_root(); | ||
two_pass_softmax.compute_root(); | ||
|
||
auto output = result.realize(); | ||
|
||
float max_relative_error_result = ((Buffer<float> &)output[0])(); | ||
float max_probability = ((Buffer<float> &)output[1])(); | ||
float min_probability = ((Buffer<float> &)output[2])(); | ||
float sum_probability = ((Buffer<float> &)output[3])(); | ||
|
||
if (max_relative_error_result > .0001f) { | ||
std::cout << "Failed: Softmax results do not match.\n"; | ||
exit(1); | ||
} | ||
|
||
if (max_probability > 1.0f) { | ||
std::cout << "Failed: Softmax probability is greater than 1.0f.\n"; | ||
exit(1); | ||
} | ||
|
||
if (min_probability < 0.0f) { | ||
std::cout << "Failed: Softmax probability is negative.\n"; | ||
exit(1); | ||
} | ||
|
||
if (sum_probability > 1.0001f) { | ||
std::cout << "Failed: Softmax probability sum is too large.\n"; | ||
exit(1); | ||
} | ||
} | ||
|
||
void expect(float x, float mantissa, float exponent) { | ||
float computed_mantissa; | ||
float computed_exponent; | ||
evaluate(halide_extended_exp(x), &computed_mantissa, &computed_exponent); | ||
if (fabs(computed_mantissa) > exp(1.0f)) { | ||
std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa | ||
<< " exponent " << computed_exponent << "\n"; | ||
} | ||
if (fabs(mantissa - computed_mantissa) > .00001 || | ||
fabs(exponent - computed_exponent) > .00001) { | ||
std::cout << "Falied: halide_extended_exp(" << x << ") == {" | ||
<< computed_mantissa << ", " << computed_exponent | ||
<< "} expected {" | ||
<< mantissa << ", " << exponent << "}\n"; | ||
exit(1); | ||
} | ||
} | ||
|
||
int main(int argc, char **argv) { | ||
std::cout << std::hexfloat; | ||
expect(0, 1, 0); | ||
expect(1, exp(1.0f) / 2, 1); | ||
expect(88, 1.94149, 126); | ||
expect(0x1.62e43p+23f, 0x1.085012p+0, 0x1p+24); | ||
expect(std::numeric_limits<float>::lowest(), 1.0f, -std::numeric_limits<float>::infinity()); | ||
expect(std::numeric_limits<float>::max(), 1.0f, std::numeric_limits<float>::infinity()); | ||
two_pass_softmax_test(1.0f); | ||
two_pass_softmax_test(10000.0f); | ||
two_pass_softmax_test(-10000.0f); | ||
two_pass_softmax_test(std::numeric_limits<float>::max()); | ||
two_pass_softmax_test(std::numeric_limits<float>::lowest()); | ||
std::cout << "Success!\n"; | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pseudo