diff --git a/src/math/PercentageMath.sol b/src/math/PercentageMath.sol index 49a9bfb..9be3837 100644 --- a/src/math/PercentageMath.sol +++ b/src/math/PercentageMath.sol @@ -13,6 +13,11 @@ library PercentageMath { uint256 internal constant MAX_UINT256 = 2**256 - 1; uint256 internal constant MAX_UINT256_MINUS_HALF_PERCENTAGE = 2**256 - 1 - 0.5e4; + /// ERRORS /// + + // Thrown when percentage is above 100%. + error PercentageTooHigh(); + /// INTERNAL /// /// @notice Executes a percentage multiplication. @@ -51,4 +56,19 @@ library PercentageMath { y := div(add(mul(x, PERCENTAGE_FACTOR), y), percentage) } } + + /// @notice Executes a weighted average, given an interval [x, y] and a percent p: x * (1 - p) + y * p + /// @param x The value at the start of the interval (included). + /// @param y The value at the end of the interval (included). + /// @param percentage The percentage of the interval to be calculated. + /// @return the average of x and y, weighted by percentage. + function weightedAvg( + uint256 x, + uint256 y, + uint256 percentage + ) internal pure returns (uint256) { + if (percentage > PERCENTAGE_FACTOR) revert PercentageTooHigh(); + + return percentMul(x, PERCENTAGE_FACTOR - percentage) + percentMul(y, percentage); + } } diff --git a/test/TestPercentageMath.sol b/test/TestPercentageMath.sol index 46783d3..59a9fc3 100644 --- a/test/TestPercentageMath.sol +++ b/test/TestPercentageMath.sol @@ -14,9 +14,19 @@ contract PercentageMathFunctions { function percentDiv(uint256 x, uint256 y) public pure returns (uint256) { return PercentageMath.percentDiv(x, y); } + + function weightedAvg( + uint256 x, + uint256 y, + uint256 percentage + ) public pure returns (uint256) { + return PercentageMath.weightedAvg(x, y, percentage); + } } contract PercentageMathFunctionsRef { + error PercentageTooHigh(); + function percentMul(uint256 x, uint256 y) public pure returns (uint256) { return PercentageMathRef.percentMul(x, y); } @@ -24,6 +34,18 @@ contract PercentageMathFunctionsRef { function percentDiv(uint256 x, uint256 y) public pure returns (uint256) { return PercentageMathRef.percentDiv(x, y); } + + function weightedAvg( + uint256 x, + uint256 y, + uint256 percentage + ) public pure returns (uint256) { + if (percentage > PercentageMath.PERCENTAGE_FACTOR) revert PercentageTooHigh(); + + return + PercentageMathRef.percentMul(x, PercentageMathRef.PERCENTAGE_FACTOR - percentage) + + PercentageMathRef.percentMul(y, percentage); + } } contract TestPercentageMath is Test { @@ -75,6 +97,30 @@ contract TestPercentageMath is Test { PercentageMath.percentDiv(x, y); } + function testWeightedAvg( + uint256 x, + uint256 y, + uint16 percentage + ) public { + vm.assume(percentage <= PERCENTAGE_FACTOR); + if (percentage > 0) vm.assume(y <= MAX_UINT256_MINUS_HALF_PERCENTAGE / percentage); + if (percentage < PERCENTAGE_FACTOR) + vm.assume(x <= MAX_UINT256_MINUS_HALF_PERCENTAGE / (PERCENTAGE_FACTOR - percentage)); + + assertEq(PercentageMath.weightedAvg(x, y, percentage), mathRef.weightedAvg(x, y, percentage)); + } + + function testWeightedAvgRevertWhenPercentageTooHigh( + uint256 x, + uint256 y, + uint256 percentage + ) public { + vm.assume(percentage > PERCENTAGE_FACTOR); + + vm.expectRevert(abi.encodeWithSignature("PercentageTooHigh()")); + PercentageMath.weightedAvg(x, y, percentage); + } + /// GAS COMPARISONS /// function testGasPercentageMul() public view { @@ -86,4 +132,9 @@ contract TestPercentageMath is Test { math.percentDiv(1 ether, 1_000); mathRef.percentDiv(1 ether, 1_000); } + + function testGasPercentageAvg() public view { + math.weightedAvg(1 ether, 2 ether, 5_000); + mathRef.weightedAvg(1 ether, 2 ether, 5_000); + } }