diff --git a/pkg/solidity-utils/contracts/math/LogExpMath.sol b/pkg/solidity-utils/contracts/math/LogExpMath.sol index f4848b243a..1ecf94fa34 100644 --- a/pkg/solidity-utils/contracts/math/LogExpMath.sol +++ b/pkg/solidity-utils/contracts/math/LogExpMath.sol @@ -35,10 +35,8 @@ library LogExpMath { // All arguments and return values are 18 decimal fixed point numbers. int256 constant ONE_18 = 1e18; - // Internally, intermediate values are computed with higher precision as 20 decimal fixed point numbers, and in the - // case of ln36, 36 decimals. + // Internally, intermediate values are computed with higher precision as 20 decimal fixed point numbers. int256 constant ONE_20 = 1e20; - int256 constant ONE_36 = 1e36; // The domain of natural exponentiation is bound by the word size and number of decimals used. // @@ -50,41 +48,8 @@ library LogExpMath { int256 constant MAX_NATURAL_EXPONENT = 130e18; int256 constant MIN_NATURAL_EXPONENT = -41e18; - // Bounds for ln_36's argument. Both ln(0.9) and ln(1.1) can be represented with 36 decimal places in a fixed point - // 256 bit integer. - int256 constant LN_36_LOWER_BOUND = ONE_18 - 1e17; - int256 constant LN_36_UPPER_BOUND = ONE_18 + 1e17; - uint256 constant MILD_EXPONENT_BOUND = 2**254 / uint256(ONE_20); - // 18 decimal constants - int256 constant x0 = 128000000000000000000; // 2ˆ7 - int256 constant a0 = 38877084059945950922200000000000000000000000000000000000; // eˆ(x0) (no decimals) - int256 constant x1 = 64000000000000000000; // 2ˆ6 - int256 constant a1 = 6235149080811616882910000000; // eˆ(x1) (no decimals) - - // 20 decimal constants - int256 constant x2 = 3200000000000000000000; // 2ˆ5 - int256 constant a2 = 7896296018268069516100000000000000; // eˆ(x2) - int256 constant x3 = 1600000000000000000000; // 2ˆ4 - int256 constant a3 = 888611052050787263676000000; // eˆ(x3) - int256 constant x4 = 800000000000000000000; // 2ˆ3 - int256 constant a4 = 298095798704172827474000; // eˆ(x4) - int256 constant x5 = 400000000000000000000; // 2ˆ2 - int256 constant a5 = 5459815003314423907810; // eˆ(x5) - int256 constant x6 = 200000000000000000000; // 2ˆ1 - int256 constant a6 = 738905609893065022723; // eˆ(x6) - int256 constant x7 = 100000000000000000000; // 2ˆ0 - int256 constant a7 = 271828182845904523536; // eˆ(x7) - int256 constant x8 = 50000000000000000000; // 2ˆ-1 - int256 constant a8 = 164872127070012814685; // eˆ(x8) - int256 constant x9 = 25000000000000000000; // 2ˆ-2 - int256 constant a9 = 128402541668774148407; // eˆ(x9) - int256 constant x10 = 12500000000000000000; // 2ˆ-3 - int256 constant a10 = 113314845306682631683; // eˆ(x10) - int256 constant x11 = 6250000000000000000; // 2ˆ-4 - int256 constant a11 = 106449445891785942956; // eˆ(x11) - /** * @dev Exponentiation (x^y) with unsigned 18 decimal fixed point base and exponent. * @@ -115,19 +80,7 @@ library LogExpMath { _require(y < MILD_EXPONENT_BOUND, Errors.Y_OUT_OF_BOUNDS); int256 y_int256 = int256(y); - int256 logx_times_y; - if (LN_36_LOWER_BOUND < x_int256 && x_int256 < LN_36_UPPER_BOUND) { - int256 ln_36_x = _ln_36(x_int256); - - // ln_36_x has 36 decimal places, so multiplying by y_int256 isn't as straightforward, since we can't just - // bring y_int256 to 36 decimal places, as it might overflow. Instead, we perform two 18 decimal - // multiplications and add the results: one with the first 18 decimals of ln_36_x, and one with the - // (downscaled) last 18 decimals. - logx_times_y = ((ln_36_x / ONE_18) * y_int256 + ((ln_36_x % ONE_18) * y_int256) / ONE_18); - } else { - logx_times_y = _ln(x_int256) * y_int256; - } - logx_times_y /= ONE_18; + int256 logx_times_y = (ln(x_int256) * y_int256) / ONE_18; // Finally, we compute exp(y * ln(x)) to arrive at x^y _require( @@ -143,372 +96,129 @@ library LogExpMath { * * Reverts if `x` is smaller than MIN_NATURAL_EXPONENT, or larger than `MAX_NATURAL_EXPONENT`. */ - function exp(int256 x) internal pure returns (int256) { + function exp(int256 x) internal pure returns (int256 r) { _require(x >= MIN_NATURAL_EXPONENT && x <= MAX_NATURAL_EXPONENT, Errors.INVALID_EXPONENT); - if (x < 0) { - // We only handle positive exponents: e^(-x) is computed as 1 / e^x. We can safely make x positive since it - // fits in the signed 256 bit range (as it is larger than MIN_NATURAL_EXPONENT). - // Fixed point division requires multiplying by ONE_18. - return ((ONE_18 * ONE_18) / exp(-x)); - } - - // First, we use the fact that e^(x+y) = e^x * e^y to decompose x into a sum of powers of two, which we call x_n, - // where x_n == 2^(7 - n), and e^x_n = a_n has been precomputed. We choose the first x_n, x0, to equal 2^7 - // because all larger powers are larger than MAX_NATURAL_EXPONENT, and therefore not present in the - // decomposition. - // At the end of this process we will have the product of all e^x_n = a_n that apply, and the remainder of this - // decomposition, which will be lower than the smallest x_n. - // exp(x) = k_0 * a_0 * k_1 * a_1 * ... + k_n * a_n * exp(remainder), where each k_n equals either 0 or 1. - // We mutate x by subtracting x_n, making it the remainder of the decomposition. - - // The first two a_n (e^(2^7) and e^(2^6)) are too large if stored as 18 decimal numbers, and could cause - // intermediate overflows. Instead we store them as plain integers, with 0 decimals. - // Additionally, x0 + x1 is larger than MAX_NATURAL_EXPONENT, which means they will not both be present in the - // decomposition. - - // For each x_n, we test if that term is present in the decomposition (if x is larger than it), and if so deduct - // it and compute the accumulated product. - - int256 firstAN; - if (x >= x0) { - x -= x0; - firstAN = a0; - } else if (x >= x1) { - x -= x1; - firstAN = a1; - } else { - firstAN = 1; // One with no decimal places - } - - // We now transform x into a 20 decimal fixed point number, to have enhanced precision when computing the - // smaller terms. - x *= 100; - - // `product` is the accumulated product of all a_n (except a0 and a1), which starts at 20 decimal fixed point - // one. Recall that fixed point multiplication requires dividing by ONE_20. - int256 product = ONE_20; - - if (x >= x2) { - x -= x2; - product = (product * a2) / ONE_20; - } - if (x >= x3) { - x -= x3; - product = (product * a3) / ONE_20; - } - if (x >= x4) { - x -= x4; - product = (product * a4) / ONE_20; - } - if (x >= x5) { - x -= x5; - product = (product * a5) / ONE_20; - } - if (x >= x6) { - x -= x6; - product = (product * a6) / ONE_20; - } - if (x >= x7) { - x -= x7; - product = (product * a7) / ONE_20; - } - if (x >= x8) { - x -= x8; - product = (product * a8) / ONE_20; - } - if (x >= x9) { - x -= x9; - product = (product * a9) / ONE_20; - } - - // x10 and x11 are unnecessary here since we have high enough precision already. - - // Now we need to compute e^x, where x is small (in particular, it is smaller than x9). We use the Taylor series - // expansion for e^x: 1 + x + (x^2 / 2!) + (x^3 / 3!) + ... + (x^n / n!). - - int256 seriesSum = ONE_20; // The initial one in the sum, with 20 decimal places. - int256 term; // Each term in the sum, where the nth term is (x^n / n!). - - // The first term is simply x. - term = x; - seriesSum += term; - - // Each term (x^n / n!) equals the previous one times x, divided by n. Since x is a fixed point number, - // multiplying by it requires dividing by ONE_20, but dividing by the non-fixed point n values does not. - - term = ((term * x) / ONE_20) / 2; - seriesSum += term; - - term = ((term * x) / ONE_20) / 3; - seriesSum += term; - - term = ((term * x) / ONE_20) / 4; - seriesSum += term; - - term = ((term * x) / ONE_20) / 5; - seriesSum += term; - - term = ((term * x) / ONE_20) / 6; - seriesSum += term; - - term = ((term * x) / ONE_20) / 7; - seriesSum += term; - - term = ((term * x) / ONE_20) / 8; - seriesSum += term; - - term = ((term * x) / ONE_20) / 9; - seriesSum += term; - - term = ((term * x) / ONE_20) / 10; - seriesSum += term; - - term = ((term * x) / ONE_20) / 11; - seriesSum += term; - - term = ((term * x) / ONE_20) / 12; - seriesSum += term; - - // 12 Taylor terms are sufficient for 18 decimal precision. - - // We now have the first a_n (with no decimals), and the product of all other a_n present, and the Taylor - // approximation of the exponentiation of the remainder (both with 20 decimals). All that remains is to multiply - // all three (one 20 decimal fixed point multiplication, dividing by ONE_20, and one integer multiplication), - // and then drop two digits to return an 18 decimal value. - - return (((product * seriesSum) / ONE_20) * firstAN) / 100; + // x is now in the range (-42, 136) * 1e18. Convert to (-42, 136) * 2**96 + // for more intermediate precision and a binary basis. This base conversion + // is a multiplication by 1e18 / 2**96 = 5**18 / 2**78. + x = (x << 78) / 5**18; + + // Reduce range of x to (-½ ln 2, ½ ln 2) * 2**96 by factoring out powers + // of two such that exp(x) = exp(x') * 2**k, where k is an integer. + // Solving this gives k = round(x / log(2)) and x' = x - k * log(2). + int256 k = ((x << 96) / 54916777467707473351141471128 + 2**95) >> 96; + x = x - k * 54916777467707473351141471128; + + // k is in the range [-61, 195]. + + // Evaluate using a (6, 7)-term rational approximation. + // p is made monic, we'll multiply by a scale factor later. + int256 y = x + 1346386616545796478920950773328; + y = ((y * x) >> 96) + 57155421227552351082224309758442; + int256 p = y + x - 94201549194550492254356042504812; + p = ((p * y) >> 96) + 28719021644029726153956944680412240; + p = p * x + (4385272521454847904659076985693276 << 96); + + // We leave p in 2**192 basis so we don't need to scale it back up for the division. + int256 q = x - 2855989394907223263936484059900; + q = ((q * x) >> 96) + 50020603652535783019961831881945; + q = ((q * x) >> 96) - 533845033583426703283633433725380; + q = ((q * x) >> 96) + 3604857256930695427073651918091429; + q = ((q * x) >> 96) - 14423608567350463180887372962807573; + q = ((q * x) >> 96) + 26449188498355588339934803723976023; + + /// @solidity memory-safe-assembly + assembly { + // Div in assembly because solidity adds a zero check despite the unchecked. + // The q polynomial won't have zeros in the domain as all its roots are complex. + // No scaling is necessary because p is already 2**96 too large. + r := sdiv(p, q) + } + + // r should be in the range (0.09, 0.25) * 2**96. + + // We now need to multiply r by: + // * the scale factor s = ~6.031367120. + // * the 2**k factor from the range reduction. + // * the 1e18 / 2**96 factor for base conversion. + // We do this all at once, with an intermediate result in 2**213 + // basis, so the final right shift is always by a positive amount. + r = int256((uint256(r) * 3822833074963236453042738258902158003155416615667) >> uint256(195 - k)); } /** - * @dev Logarithm (log(arg, base), with signed 18 decimal fixed point base and argument. + * @dev Natural logarithm (ln(x)) with signed 18 decimal fixed point argument. */ - function log(int256 arg, int256 base) internal pure returns (int256) { - // This performs a simple base change: log(arg, base) = ln(arg) / ln(base). - - // Both logBase and logArg are computed as 36 decimal fixed point numbers, either by using ln_36, or by - // upscaling. - - int256 logBase; - if (LN_36_LOWER_BOUND < base && base < LN_36_UPPER_BOUND) { - logBase = _ln_36(base); - } else { - logBase = _ln(base) * ONE_18; - } - - int256 logArg; - if (LN_36_LOWER_BOUND < arg && arg < LN_36_UPPER_BOUND) { - logArg = _ln_36(arg); - } else { - logArg = _ln(arg) * ONE_18; - } - - // When dividing, we multiply by ONE_18 to arrive at a result with 18 decimal places - return (logArg * ONE_18) / logBase; - } - - /** - * @dev Natural logarithm (ln(a)) with signed 18 decimal fixed point argument. - */ - function ln(int256 a) internal pure returns (int256) { + function ln(int256 x) internal pure returns (int256 r) { // The real natural logarithm is not defined for negative numbers or zero. - _require(a > 0, Errors.OUT_OF_BOUNDS); - if (LN_36_LOWER_BOUND < a && a < LN_36_UPPER_BOUND) { - return _ln_36(a) / ONE_18; - } else { - return _ln(a); - } - } - - /** - * @dev Internal natural logarithm (ln(a)) with signed 18 decimal fixed point argument. - */ - function _ln(int256 a) private pure returns (int256) { - if (a < ONE_18) { - // Since ln(a^k) = k * ln(a), we can compute ln(a) as ln(a) = ln((1/a)^(-1)) = - ln((1/a)). If a is less - // than one, 1/a will be greater than one, and this if statement will not be entered in the recursive call. - // Fixed point division requires multiplying by ONE_18. - return (-_ln((ONE_18 * ONE_18) / a)); - } - - // First, we use the fact that ln^(a * b) = ln(a) + ln(b) to decompose ln(a) into a sum of powers of two, which - // we call x_n, where x_n == 2^(7 - n), which are the natural logarithm of precomputed quantities a_n (that is, - // ln(a_n) = x_n). We choose the first x_n, x0, to equal 2^7 because the exponential of all larger powers cannot - // be represented as 18 fixed point decimal numbers in 256 bits, and are therefore larger than a. - // At the end of this process we will have the sum of all x_n = ln(a_n) that apply, and the remainder of this - // decomposition, which will be lower than the smallest a_n. - // ln(a) = k_0 * x_0 + k_1 * x_1 + ... + k_n * x_n + ln(remainder), where each k_n equals either 0 or 1. - // We mutate a by subtracting a_n, making it the remainder of the decomposition. - - // For reasons related to how `exp` works, the first two a_n (e^(2^7) and e^(2^6)) are not stored as fixed point - // numbers with 18 decimals, but instead as plain integers with 0 decimals, so we need to multiply them by - // ONE_18 to convert them to fixed point. - // For each a_n, we test if that term is present in the decomposition (if a is larger than it), and if so divide - // by it and compute the accumulated sum. - - int256 sum = 0; - if (a >= a0 * ONE_18) { - a /= a0; // Integer, not fixed point division - sum += x0; - } - - if (a >= a1 * ONE_18) { - a /= a1; // Integer, not fixed point division - sum += x1; - } - - // All other a_n and x_n are stored as 20 digit fixed point numbers, so we convert the sum and a to this format. - sum *= 100; - a *= 100; - - // Because further a_n are 20 digit fixed point numbers, we multiply by ONE_20 when dividing by them. - - if (a >= a2) { - a = (a * ONE_20) / a2; - sum += x2; - } - - if (a >= a3) { - a = (a * ONE_20) / a3; - sum += x3; - } - - if (a >= a4) { - a = (a * ONE_20) / a4; - sum += x4; - } - - if (a >= a5) { - a = (a * ONE_20) / a5; - sum += x5; - } - - if (a >= a6) { - a = (a * ONE_20) / a6; - sum += x6; - } - - if (a >= a7) { - a = (a * ONE_20) / a7; - sum += x7; - } - - if (a >= a8) { - a = (a * ONE_20) / a8; - sum += x8; - } - - if (a >= a9) { - a = (a * ONE_20) / a9; - sum += x9; - } - - if (a >= a10) { - a = (a * ONE_20) / a10; - sum += x10; - } - - if (a >= a11) { - a = (a * ONE_20) / a11; - sum += x11; - } - - // a is now a small number (smaller than a_11, which roughly equals 1.06). This means we can use a Taylor series - // that converges rapidly for values of `a` close to one - the same one used in ln_36. - // Let z = (a - 1) / (a + 1). - // ln(a) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1)) - - // Recall that 20 digit fixed point division requires multiplying by ONE_20, and multiplication requires - // division by ONE_20. - int256 z = ((a - ONE_20) * ONE_20) / (a + ONE_20); - int256 z_squared = (z * z) / ONE_20; - - // num is the numerator of the series: the z^(2 * n + 1) term - int256 num = z; - - // seriesSum holds the accumulated sum of each term in the series, starting with the initial z - int256 seriesSum = num; - - // In each step, the numerator is multiplied by z^2 - num = (num * z_squared) / ONE_20; - seriesSum += num / 3; - - num = (num * z_squared) / ONE_20; - seriesSum += num / 5; - - num = (num * z_squared) / ONE_20; - seriesSum += num / 7; - - num = (num * z_squared) / ONE_20; - seriesSum += num / 9; - - num = (num * z_squared) / ONE_20; - seriesSum += num / 11; - - // 6 Taylor terms are sufficient for 36 decimal precision. - - // Finally, we multiply by 2 (non fixed point) to compute ln(remainder) - seriesSum *= 2; - - // We now have the sum of all x_n present, and the Taylor approximation of the logarithm of the remainder (both - // with 20 decimals). All that remains is to sum these two, and then drop two digits to return a 18 decimal - // value. - - return (sum + seriesSum) / 100; - } - - /** - * @dev Intrnal high precision (36 decimal places) natural logarithm (ln(x)) with signed 18 decimal fixed point argument, - * for x close to one. - * - * Should only be used if x is between LN_36_LOWER_BOUND and LN_36_UPPER_BOUND. - */ - function _ln_36(int256 x) private pure returns (int256) { - // Since ln(1) = 0, a value of x close to one will yield a very small result, which makes using 36 digits - // worthwhile. - - // First, we transform x to a 36 digit fixed point value. - x *= ONE_18; - - // We will use the following Taylor expansion, which converges very rapidly. Let z = (x - 1) / (x + 1). - // ln(x) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1)) - - // Recall that 36 digit fixed point division requires multiplying by ONE_36, and multiplication requires - // division by ONE_36. - int256 z = ((x - ONE_36) * ONE_36) / (x + ONE_36); - int256 z_squared = (z * z) / ONE_36; - - // num is the numerator of the series: the z^(2 * n + 1) term - int256 num = z; - - // seriesSum holds the accumulated sum of each term in the series, starting with the initial z - int256 seriesSum = num; - - // In each step, the numerator is multiplied by z^2 - num = (num * z_squared) / ONE_36; - seriesSum += num / 3; - - num = (num * z_squared) / ONE_36; - seriesSum += num / 5; - - num = (num * z_squared) / ONE_36; - seriesSum += num / 7; - - num = (num * z_squared) / ONE_36; - seriesSum += num / 9; - - num = (num * z_squared) / ONE_36; - seriesSum += num / 11; - - num = (num * z_squared) / ONE_36; - seriesSum += num / 13; - - num = (num * z_squared) / ONE_36; - seriesSum += num / 15; - - // 8 Taylor terms are sufficient for 36 decimal precision. - - // All that remains is multiplying by 2 (non fixed point). - return seriesSum * 2; + _require(x > 0, Errors.OUT_OF_BOUNDS); + + // We want to convert x from 10**18 fixed point to 2**96 fixed point. + // We do this by multiplying by 2**96 / 10**18. But since + // ln(x * C) = ln(x) + ln(C), we can simply do nothing here + // and add ln(2**96 / 10**18) at the end. + + /// @solidity memory-safe-assembly + assembly { + r := shl(7, lt(0xffffffffffffffffffffffffffffffff, x)) + r := or(r, shl(6, lt(0xffffffffffffffff, shr(r, x)))) + r := or(r, shl(5, lt(0xffffffff, shr(r, x)))) + r := or(r, shl(4, lt(0xffff, shr(r, x)))) + r := or(r, shl(3, lt(0xff, shr(r, x)))) + r := or(r, shl(2, lt(0xf, shr(r, x)))) + r := or(r, shl(1, lt(0x3, shr(r, x)))) + r := or(r, lt(0x1, shr(r, x))) + } + + // Reduce range of x to (1, 2) * 2**96 + // ln(2^k * x) = k * ln(2) + ln(x) + int256 k = r - 96; + x <<= uint256(159 - k); + x = int256(uint256(x) >> 159); + + // Evaluate using a (8, 8)-term rational approximation. + // p is made monic, we will multiply by a scale factor later. + int256 p = x + 3273285459638523848632254066296; + p = ((p * x) >> 96) + 24828157081833163892658089445524; + p = ((p * x) >> 96) + 43456485725739037958740375743393; + p = ((p * x) >> 96) - 11111509109440967052023855526967; + p = ((p * x) >> 96) - 45023709667254063763336534515857; + p = ((p * x) >> 96) - 14706773417378608786704636184526; + p = p * x - (795164235651350426258249787498 << 96); + + // We leave p in 2**192 basis so we don't need to scale it back up for the division. + // q is monic by convention. + int256 q = x + 5573035233440673466300451813936; + q = ((q * x) >> 96) + 71694874799317883764090561454958; + q = ((q * x) >> 96) + 283447036172924575727196451306956; + q = ((q * x) >> 96) + 401686690394027663651624208769553; + q = ((q * x) >> 96) + 204048457590392012362485061816622; + q = ((q * x) >> 96) + 31853899698501571402653359427138; + q = ((q * x) >> 96) + 909429971244387300277376558375; + /// @solidity memory-safe-assembly + assembly { + // Div in assembly because solidity adds a zero check despite the unchecked. + // The q polynomial is known not to have zeros in the domain. + // No scaling required because p is already 2**96 too large. + r := sdiv(p, q) + } + + // r is in the range (0, 0.125) * 2**96 + + // Finalization, we need to: + // * multiply by the scale factor s = 5.549… + // * add ln(2**96 / 10**18) + // * add k * ln(2) + // * multiply by 10**18 / 2**96 = 5**18 >> 78 + + // mul s * 5e18 * 2**96, base is now 5**18 * 2**192 + r *= 1677202110996718588342820967067443963516166; + // add ln(2) * k * 5e18 * 2**192 + r += 16597577552685614221487285958193947469193820559219878177908093499208371 * k; + // add ln(2**96 / 10**18) * 5e18 * 2**192 + r += 600920179829731861736702779321621459595472258049074101567377883020018308; + // base conversion: mul 2**18 / 2**192 + r >>= 174; } } diff --git a/pkg/solidity-utils/contracts/math/ReferenceLogExpMath.sol b/pkg/solidity-utils/contracts/math/ReferenceLogExpMath.sol new file mode 100644 index 0000000000..8382cca9bc --- /dev/null +++ b/pkg/solidity-utils/contracts/math/ReferenceLogExpMath.sol @@ -0,0 +1,514 @@ +// SPDX-License-Identifier: MIT +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +// documentation files (the “Software”), to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +// Software. + +// THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +pragma solidity ^0.7.0; + +import "@balancer-labs/v2-interfaces/contracts/solidity-utils/helpers/BalancerErrors.sol"; + +/* solhint-disable */ + +/** + * @dev Exponentiation and logarithm functions for 18 decimal fixed point numbers (both base and exponent/argument). + * + * Exponentiation and logarithm with arbitrary bases (x^y and log_x(y)) are implemented by conversion to natural + * exponentiation and logarithm (where the base is Euler's number). + * + * @author Fernando Martinelli - @fernandomartinelli + * @author Sergio Yuhjtman - @sergioyuhjtman + * @author Daniel Fernandez - @dmf7z + */ +library ReferenceLogExpMath { + // All fixed point multiplications and divisions are inlined. This means we need to divide by ONE when multiplying + // two numbers, and multiply by ONE when dividing them. + + // All arguments and return values are 18 decimal fixed point numbers. + int256 constant ONE_18 = 1e18; + + // Internally, intermediate values are computed with higher precision as 20 decimal fixed point numbers, and in the + // case of ln36, 36 decimals. + int256 constant ONE_20 = 1e20; + int256 constant ONE_36 = 1e36; + + // The domain of natural exponentiation is bound by the word size and number of decimals used. + // + // Because internally the result will be stored using 20 decimals, the largest possible result is + // (2^255 - 1) / 10^20, which makes the largest exponent ln((2^255 - 1) / 10^20) = 130.700829182905140221. + // The smallest possible result is 10^(-18), which makes largest negative argument + // ln(10^(-18)) = -41.446531673892822312. + // We use 130.0 and -41.0 to have some safety margin. + int256 constant MAX_NATURAL_EXPONENT = 130e18; + int256 constant MIN_NATURAL_EXPONENT = -41e18; + + // Bounds for ln_36's argument. Both ln(0.9) and ln(1.1) can be represented with 36 decimal places in a fixed point + // 256 bit integer. + int256 constant LN_36_LOWER_BOUND = ONE_18 - 1e17; + int256 constant LN_36_UPPER_BOUND = ONE_18 + 1e17; + + uint256 constant MILD_EXPONENT_BOUND = 2**254 / uint256(ONE_20); + + // 18 decimal constants + int256 constant x0 = 128000000000000000000; // 2ˆ7 + int256 constant a0 = 38877084059945950922200000000000000000000000000000000000; // eˆ(x0) (no decimals) + int256 constant x1 = 64000000000000000000; // 2ˆ6 + int256 constant a1 = 6235149080811616882910000000; // eˆ(x1) (no decimals) + + // 20 decimal constants + int256 constant x2 = 3200000000000000000000; // 2ˆ5 + int256 constant a2 = 7896296018268069516100000000000000; // eˆ(x2) + int256 constant x3 = 1600000000000000000000; // 2ˆ4 + int256 constant a3 = 888611052050787263676000000; // eˆ(x3) + int256 constant x4 = 800000000000000000000; // 2ˆ3 + int256 constant a4 = 298095798704172827474000; // eˆ(x4) + int256 constant x5 = 400000000000000000000; // 2ˆ2 + int256 constant a5 = 5459815003314423907810; // eˆ(x5) + int256 constant x6 = 200000000000000000000; // 2ˆ1 + int256 constant a6 = 738905609893065022723; // eˆ(x6) + int256 constant x7 = 100000000000000000000; // 2ˆ0 + int256 constant a7 = 271828182845904523536; // eˆ(x7) + int256 constant x8 = 50000000000000000000; // 2ˆ-1 + int256 constant a8 = 164872127070012814685; // eˆ(x8) + int256 constant x9 = 25000000000000000000; // 2ˆ-2 + int256 constant a9 = 128402541668774148407; // eˆ(x9) + int256 constant x10 = 12500000000000000000; // 2ˆ-3 + int256 constant a10 = 113314845306682631683; // eˆ(x10) + int256 constant x11 = 6250000000000000000; // 2ˆ-4 + int256 constant a11 = 106449445891785942956; // eˆ(x11) + + /** + * @dev Exponentiation (x^y) with unsigned 18 decimal fixed point base and exponent. + * + * Reverts if ln(x) * y is smaller than `MIN_NATURAL_EXPONENT`, or larger than `MAX_NATURAL_EXPONENT`. + */ + function pow(uint256 x, uint256 y) internal pure returns (uint256) { + if (y == 0) { + // We solve the 0^0 indetermination by making it equal one. + return uint256(ONE_18); + } + + if (x == 0) { + return 0; + } + + // Instead of computing x^y directly, we instead rely on the properties of logarithms and exponentiation to + // arrive at that result. In particular, exp(ln(x)) = x, and ln(x^y) = y * ln(x). This means + // x^y = exp(y * ln(x)). + + // The ln function takes a signed value, so we need to make sure x fits in the signed 256 bit range. + _require(x >> 255 == 0, Errors.X_OUT_OF_BOUNDS); + int256 x_int256 = int256(x); + + // We will compute y * ln(x) in a single step. Depending on the value of x, we can either use ln or ln_36. In + // both cases, we leave the division by ONE_18 (due to fixed point multiplication) to the end. + + // This prevents y * ln(x) from overflowing, and at the same time guarantees y fits in the signed 256 bit range. + _require(y < MILD_EXPONENT_BOUND, Errors.Y_OUT_OF_BOUNDS); + int256 y_int256 = int256(y); + + int256 logx_times_y; + if (LN_36_LOWER_BOUND < x_int256 && x_int256 < LN_36_UPPER_BOUND) { + int256 ln_36_x = _ln_36(x_int256); + + // ln_36_x has 36 decimal places, so multiplying by y_int256 isn't as straightforward, since we can't just + // bring y_int256 to 36 decimal places, as it might overflow. Instead, we perform two 18 decimal + // multiplications and add the results: one with the first 18 decimals of ln_36_x, and one with the + // (downscaled) last 18 decimals. + logx_times_y = ((ln_36_x / ONE_18) * y_int256 + ((ln_36_x % ONE_18) * y_int256) / ONE_18); + } else { + logx_times_y = _ln(x_int256) * y_int256; + } + logx_times_y /= ONE_18; + + // Finally, we compute exp(y * ln(x)) to arrive at x^y + _require( + MIN_NATURAL_EXPONENT <= logx_times_y && logx_times_y <= MAX_NATURAL_EXPONENT, + Errors.PRODUCT_OUT_OF_BOUNDS + ); + + return uint256(exp(logx_times_y)); + } + + /** + * @dev Natural exponentiation (e^x) with signed 18 decimal fixed point exponent. + * + * Reverts if `x` is smaller than MIN_NATURAL_EXPONENT, or larger than `MAX_NATURAL_EXPONENT`. + */ + function exp(int256 x) internal pure returns (int256) { + _require(x >= MIN_NATURAL_EXPONENT && x <= MAX_NATURAL_EXPONENT, Errors.INVALID_EXPONENT); + + if (x < 0) { + // We only handle positive exponents: e^(-x) is computed as 1 / e^x. We can safely make x positive since it + // fits in the signed 256 bit range (as it is larger than MIN_NATURAL_EXPONENT). + // Fixed point division requires multiplying by ONE_18. + return ((ONE_18 * ONE_18) / exp(-x)); + } + + // First, we use the fact that e^(x+y) = e^x * e^y to decompose x into a sum of powers of two, which we call x_n, + // where x_n == 2^(7 - n), and e^x_n = a_n has been precomputed. We choose the first x_n, x0, to equal 2^7 + // because all larger powers are larger than MAX_NATURAL_EXPONENT, and therefore not present in the + // decomposition. + // At the end of this process we will have the product of all e^x_n = a_n that apply, and the remainder of this + // decomposition, which will be lower than the smallest x_n. + // exp(x) = k_0 * a_0 * k_1 * a_1 * ... + k_n * a_n * exp(remainder), where each k_n equals either 0 or 1. + // We mutate x by subtracting x_n, making it the remainder of the decomposition. + + // The first two a_n (e^(2^7) and e^(2^6)) are too large if stored as 18 decimal numbers, and could cause + // intermediate overflows. Instead we store them as plain integers, with 0 decimals. + // Additionally, x0 + x1 is larger than MAX_NATURAL_EXPONENT, which means they will not both be present in the + // decomposition. + + // For each x_n, we test if that term is present in the decomposition (if x is larger than it), and if so deduct + // it and compute the accumulated product. + + int256 firstAN; + if (x >= x0) { + x -= x0; + firstAN = a0; + } else if (x >= x1) { + x -= x1; + firstAN = a1; + } else { + firstAN = 1; // One with no decimal places + } + + // We now transform x into a 20 decimal fixed point number, to have enhanced precision when computing the + // smaller terms. + x *= 100; + + // `product` is the accumulated product of all a_n (except a0 and a1), which starts at 20 decimal fixed point + // one. Recall that fixed point multiplication requires dividing by ONE_20. + int256 product = ONE_20; + + if (x >= x2) { + x -= x2; + product = (product * a2) / ONE_20; + } + if (x >= x3) { + x -= x3; + product = (product * a3) / ONE_20; + } + if (x >= x4) { + x -= x4; + product = (product * a4) / ONE_20; + } + if (x >= x5) { + x -= x5; + product = (product * a5) / ONE_20; + } + if (x >= x6) { + x -= x6; + product = (product * a6) / ONE_20; + } + if (x >= x7) { + x -= x7; + product = (product * a7) / ONE_20; + } + if (x >= x8) { + x -= x8; + product = (product * a8) / ONE_20; + } + if (x >= x9) { + x -= x9; + product = (product * a9) / ONE_20; + } + + // x10 and x11 are unnecessary here since we have high enough precision already. + + // Now we need to compute e^x, where x is small (in particular, it is smaller than x9). We use the Taylor series + // expansion for e^x: 1 + x + (x^2 / 2!) + (x^3 / 3!) + ... + (x^n / n!). + + int256 seriesSum = ONE_20; // The initial one in the sum, with 20 decimal places. + int256 term; // Each term in the sum, where the nth term is (x^n / n!). + + // The first term is simply x. + term = x; + seriesSum += term; + + // Each term (x^n / n!) equals the previous one times x, divided by n. Since x is a fixed point number, + // multiplying by it requires dividing by ONE_20, but dividing by the non-fixed point n values does not. + + term = ((term * x) / ONE_20) / 2; + seriesSum += term; + + term = ((term * x) / ONE_20) / 3; + seriesSum += term; + + term = ((term * x) / ONE_20) / 4; + seriesSum += term; + + term = ((term * x) / ONE_20) / 5; + seriesSum += term; + + term = ((term * x) / ONE_20) / 6; + seriesSum += term; + + term = ((term * x) / ONE_20) / 7; + seriesSum += term; + + term = ((term * x) / ONE_20) / 8; + seriesSum += term; + + term = ((term * x) / ONE_20) / 9; + seriesSum += term; + + term = ((term * x) / ONE_20) / 10; + seriesSum += term; + + term = ((term * x) / ONE_20) / 11; + seriesSum += term; + + term = ((term * x) / ONE_20) / 12; + seriesSum += term; + + // 12 Taylor terms are sufficient for 18 decimal precision. + + // We now have the first a_n (with no decimals), and the product of all other a_n present, and the Taylor + // approximation of the exponentiation of the remainder (both with 20 decimals). All that remains is to multiply + // all three (one 20 decimal fixed point multiplication, dividing by ONE_20, and one integer multiplication), + // and then drop two digits to return an 18 decimal value. + + return (((product * seriesSum) / ONE_20) * firstAN) / 100; + } + + /** + * @dev Logarithm (log(arg, base), with signed 18 decimal fixed point base and argument. + */ + function log(int256 arg, int256 base) internal pure returns (int256) { + // This performs a simple base change: log(arg, base) = ln(arg) / ln(base). + + // Both logBase and logArg are computed as 36 decimal fixed point numbers, either by using ln_36, or by + // upscaling. + + int256 logBase; + if (LN_36_LOWER_BOUND < base && base < LN_36_UPPER_BOUND) { + logBase = _ln_36(base); + } else { + logBase = _ln(base) * ONE_18; + } + + int256 logArg; + if (LN_36_LOWER_BOUND < arg && arg < LN_36_UPPER_BOUND) { + logArg = _ln_36(arg); + } else { + logArg = _ln(arg) * ONE_18; + } + + // When dividing, we multiply by ONE_18 to arrive at a result with 18 decimal places + return (logArg * ONE_18) / logBase; + } + + /** + * @dev Natural logarithm (ln(a)) with signed 18 decimal fixed point argument. + */ + function ln(int256 a) internal pure returns (int256) { + // The real natural logarithm is not defined for negative numbers or zero. + _require(a > 0, Errors.OUT_OF_BOUNDS); + if (LN_36_LOWER_BOUND < a && a < LN_36_UPPER_BOUND) { + return _ln_36(a) / ONE_18; + } else { + return _ln(a); + } + } + + /** + * @dev Internal natural logarithm (ln(a)) with signed 18 decimal fixed point argument. + */ + function _ln(int256 a) private pure returns (int256) { + if (a < ONE_18) { + // Since ln(a^k) = k * ln(a), we can compute ln(a) as ln(a) = ln((1/a)^(-1)) = - ln((1/a)). If a is less + // than one, 1/a will be greater than one, and this if statement will not be entered in the recursive call. + // Fixed point division requires multiplying by ONE_18. + return (-_ln((ONE_18 * ONE_18) / a)); + } + + // First, we use the fact that ln^(a * b) = ln(a) + ln(b) to decompose ln(a) into a sum of powers of two, which + // we call x_n, where x_n == 2^(7 - n), which are the natural logarithm of precomputed quantities a_n (that is, + // ln(a_n) = x_n). We choose the first x_n, x0, to equal 2^7 because the exponential of all larger powers cannot + // be represented as 18 fixed point decimal numbers in 256 bits, and are therefore larger than a. + // At the end of this process we will have the sum of all x_n = ln(a_n) that apply, and the remainder of this + // decomposition, which will be lower than the smallest a_n. + // ln(a) = k_0 * x_0 + k_1 * x_1 + ... + k_n * x_n + ln(remainder), where each k_n equals either 0 or 1. + // We mutate a by subtracting a_n, making it the remainder of the decomposition. + + // For reasons related to how `exp` works, the first two a_n (e^(2^7) and e^(2^6)) are not stored as fixed point + // numbers with 18 decimals, but instead as plain integers with 0 decimals, so we need to multiply them by + // ONE_18 to convert them to fixed point. + // For each a_n, we test if that term is present in the decomposition (if a is larger than it), and if so divide + // by it and compute the accumulated sum. + + int256 sum = 0; + if (a >= a0 * ONE_18) { + a /= a0; // Integer, not fixed point division + sum += x0; + } + + if (a >= a1 * ONE_18) { + a /= a1; // Integer, not fixed point division + sum += x1; + } + + // All other a_n and x_n are stored as 20 digit fixed point numbers, so we convert the sum and a to this format. + sum *= 100; + a *= 100; + + // Because further a_n are 20 digit fixed point numbers, we multiply by ONE_20 when dividing by them. + + if (a >= a2) { + a = (a * ONE_20) / a2; + sum += x2; + } + + if (a >= a3) { + a = (a * ONE_20) / a3; + sum += x3; + } + + if (a >= a4) { + a = (a * ONE_20) / a4; + sum += x4; + } + + if (a >= a5) { + a = (a * ONE_20) / a5; + sum += x5; + } + + if (a >= a6) { + a = (a * ONE_20) / a6; + sum += x6; + } + + if (a >= a7) { + a = (a * ONE_20) / a7; + sum += x7; + } + + if (a >= a8) { + a = (a * ONE_20) / a8; + sum += x8; + } + + if (a >= a9) { + a = (a * ONE_20) / a9; + sum += x9; + } + + if (a >= a10) { + a = (a * ONE_20) / a10; + sum += x10; + } + + if (a >= a11) { + a = (a * ONE_20) / a11; + sum += x11; + } + + // a is now a small number (smaller than a_11, which roughly equals 1.06). This means we can use a Taylor series + // that converges rapidly for values of `a` close to one - the same one used in ln_36. + // Let z = (a - 1) / (a + 1). + // ln(a) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1)) + + // Recall that 20 digit fixed point division requires multiplying by ONE_20, and multiplication requires + // division by ONE_20. + int256 z = ((a - ONE_20) * ONE_20) / (a + ONE_20); + int256 z_squared = (z * z) / ONE_20; + + // num is the numerator of the series: the z^(2 * n + 1) term + int256 num = z; + + // seriesSum holds the accumulated sum of each term in the series, starting with the initial z + int256 seriesSum = num; + + // In each step, the numerator is multiplied by z^2 + num = (num * z_squared) / ONE_20; + seriesSum += num / 3; + + num = (num * z_squared) / ONE_20; + seriesSum += num / 5; + + num = (num * z_squared) / ONE_20; + seriesSum += num / 7; + + num = (num * z_squared) / ONE_20; + seriesSum += num / 9; + + num = (num * z_squared) / ONE_20; + seriesSum += num / 11; + + // 6 Taylor terms are sufficient for 36 decimal precision. + + // Finally, we multiply by 2 (non fixed point) to compute ln(remainder) + seriesSum *= 2; + + // We now have the sum of all x_n present, and the Taylor approximation of the logarithm of the remainder (both + // with 20 decimals). All that remains is to sum these two, and then drop two digits to return a 18 decimal + // value. + + return (sum + seriesSum) / 100; + } + + /** + * @dev Intrnal high precision (36 decimal places) natural logarithm (ln(x)) with signed 18 decimal fixed point argument, + * for x close to one. + * + * Should only be used if x is between LN_36_LOWER_BOUND and LN_36_UPPER_BOUND. + */ + function _ln_36(int256 x) private pure returns (int256) { + // Since ln(1) = 0, a value of x close to one will yield a very small result, which makes using 36 digits + // worthwhile. + + // First, we transform x to a 36 digit fixed point value. + x *= ONE_18; + + // We will use the following Taylor expansion, which converges very rapidly. Let z = (x - 1) / (x + 1). + // ln(x) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1)) + + // Recall that 36 digit fixed point division requires multiplying by ONE_36, and multiplication requires + // division by ONE_36. + int256 z = ((x - ONE_36) * ONE_36) / (x + ONE_36); + int256 z_squared = (z * z) / ONE_36; + + // num is the numerator of the series: the z^(2 * n + 1) term + int256 num = z; + + // seriesSum holds the accumulated sum of each term in the series, starting with the initial z + int256 seriesSum = num; + + // In each step, the numerator is multiplied by z^2 + num = (num * z_squared) / ONE_36; // z**3 + seriesSum += num / 3; + + num = (num * z_squared) / ONE_36; // z**5 + seriesSum += num / 5; + + num = (num * z_squared) / ONE_36; // z**7 + seriesSum += num / 7; + + num = (num * z_squared) / ONE_36; // z**9 + seriesSum += num / 9; + + num = (num * z_squared) / ONE_36; // z**11 + seriesSum += num / 11; + + num = (num * z_squared) / ONE_36; // z**13 + seriesSum += num / 13; + + num = (num * z_squared) / ONE_36; // z**15 + seriesSum += num / 15; + + // 8 Taylor terms are sufficient for 36 decimal precision. + + // All that remains is multiplying by 2 (non fixed point). + return seriesSum * 2; + } +} diff --git a/pkg/solidity-utils/test/foundry/LogExpMath.t.sol b/pkg/solidity-utils/test/foundry/LogExpMath.t.sol new file mode 100644 index 0000000000..cdcb48042a --- /dev/null +++ b/pkg/solidity-utils/test/foundry/LogExpMath.t.sol @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.7.0; + +import "forge-std/Test.sol"; + +import "../../contracts/math/LogExpMath.sol"; +import "../../contracts/math/ReferenceLogExpMath.sol"; + +contract LogExpMathsTest is Test { + function testExpEquivalence(int256 a) external { + vm.assume(a > ReferenceLogExpMath.MIN_NATURAL_EXPONENT); + vm.assume(a < ReferenceLogExpMath.MAX_NATURAL_EXPONENT); + + int256 expectedExp = ReferenceLogExpMath.exp(a); + int256 actualExp = LogExpMath.exp(a); + + assertApproxEqRel(actualExp, expectedExp, 1e2); + } + + function testLnEquivalence(int256 a) external { + vm.assume(a > 0); + + int256 expectedLn = ReferenceLogExpMath.ln(a); + int256 actualLn = LogExpMath.ln(a); + + assertApproxEqAbs(actualLn, expectedLn, 2); + } +}