Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon committed Jan 31, 2025
1 parent 00011ea commit 6a922a1
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 63 deletions.
4 changes: 2 additions & 2 deletions contract/p/gnoswap/gnsmath/erros.gno
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ var (
errSqrtPriceZero = errors.New("sqrtPX96 should not be zero")
errLiquidityZero = errors.New("liquidity should not be zero")
errSqrtRatioAX96NotPositive = errors.New("sqrtRatioAX96 must be greater than zero")
errAmount0DeltaOverflow = errors.New("SqrtPriceMathGetAmount0DeltaStr: overflow")
errAmount1DeltaOverflow = errors.New("SqrtPriceMathGetAmount1DeltaStr: overflow")
errAmount0DeltaOverflow = errors.New("GetAmount0DeltaStr: overflow")
errAmount1DeltaOverflow = errors.New("GetAmount1DeltaStr: overflow")
errMSBZeroInput = errors.New("input for MSB calculation should not be zero")
errLSBZeroInput = errors.New("input for LSB calculation should not be zero")
)
56 changes: 28 additions & 28 deletions contract/p/gnoswap/gnsmath/sqrt_price_math.gno
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func getNextPriceAmount0Remove(
return nextSqrtPrice
}

// sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp calculates the next square root price
// getNextSqrtPriceFromAmount0RoundingUp calculates the next square root price
// based on the amount of token0 added or removed from the pool.
// NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least
// far enough to get the desired output amount, and in the exact input case (decreasing price) we need to move the
Expand All @@ -81,7 +81,7 @@ func getNextPriceAmount0Remove(
// - When `add` is false, the function calculates the new square root price after removing `amount` of token0.
// - The function uses high-precision math (MulDivRoundingUp, DivRoundingUp) to handle division rounding issues.
// - The function validates input conditions and panics if the state is invalid.
func sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(
func getNextSqrtPriceFromAmount0RoundingUp(
sqrtPX96 *u256.Uint,
liquidity *u256.Uint,
amount *u256.Uint,
Expand Down Expand Up @@ -149,7 +149,7 @@ func getNextPriceAmount1Remove(
return res
}

// sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown calculates the next square root price
// getNextSqrtPriceFromAmount1RoundingDown calculates the next square root price
// based on the amount of token1 added or removed from the pool, with rounding down.
// NOTE: Always rounds down, because in the exact output case (decreasing price) we need to move the price at least
// far enough to get the desired output amount, and in the exact input case (increasing price) we need to move the
Expand All @@ -170,7 +170,7 @@ func getNextPriceAmount1Remove(
// - When `add` is false, the function calculates the new square root price after removing `amount` of token1.
// - The function uses high-precision math (MulDiv and DivRoundingUp) to handle division and prevent precision loss.
// - The function validates input conditions and panics if the state is invalid.
func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(
func getNextSqrtPriceFromAmount1RoundingDown(
sqrtPX96, liquidity, amount *u256.Uint,
add bool,
) *u256.Uint {
Expand All @@ -180,7 +180,7 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(
return getNextPriceAmount1Remove(sqrtPX96, liquidity, amount)
}

// sqrtPriceMathGetNextSqrtPriceFromInput calculates the next square root price
// getNextSqrtPriceFromInput calculates the next square root price
// based on the amount of token0 or token1 added to the pool.
// NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least
// far enough to get the desired output amount, and in the exact input case (decreasing price) we need to move the
Expand All @@ -196,7 +196,7 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(
//
// Returns:
// - The price after adding amountIn, depending on zeroForOne
func sqrtPriceMathGetNextSqrtPriceFromInput(
func getNextSqrtPriceFromInput(
sqrtPX96, liquidity, amountIn *u256.Uint,
zeroForOne bool,
) *u256.Uint {
Expand All @@ -209,13 +209,13 @@ func sqrtPriceMathGetNextSqrtPriceFromInput(
}

if zeroForOne {
return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true)
return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true)
}

return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true)
return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true)
}

// sqrtPriceMathGetNextSqrtPriceFromOutput calculates the next square root price
// getNextSqrtPriceFromOutput calculates the next square root price
// based on the amount of token0 or token1 removed from the pool.
//
// NOTE:
Expand All @@ -239,9 +239,9 @@ func sqrtPriceMathGetNextSqrtPriceFromInput(
// Notes:
// - Rounding direction depends on the swap direction (zeroForOne).
// - Relies on helper functions:
// - `sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown` for Token0 -> Token1.
// - `sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp` for Token1 -> Token0.
func sqrtPriceMathGetNextSqrtPriceFromOutput(
// - `getNextSqrtPriceFromAmount1RoundingDown` for Token0 -> Token1.
// - `getNextSqrtPriceFromAmount0RoundingUp` for Token1 -> Token0.
func getNextSqrtPriceFromOutput(
sqrtPX96, liquidity, amountOut *u256.Uint,
zeroForOne bool,
) *u256.Uint {
Expand All @@ -254,13 +254,13 @@ func sqrtPriceMathGetNextSqrtPriceFromOutput(
}

if zeroForOne {
return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false)
return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false)
}

return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false)
return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false)
}

// sqrtPriceMathGetAmount0DeltaHelper calculates the absolute difference between the amounts of token0 in two
// getAmount0DeltaHelper calculates the absolute difference between the amounts of token0 in two
// liquidity ranges defined by the square root prices sqrtRatioAX96 and sqrtRatioBX96. The difference is
// calculated relative to the range [sqrtRatioAX96, sqrtRatioBX96].
//
Expand All @@ -279,7 +279,7 @@ func sqrtPriceMathGetNextSqrtPriceFromOutput(
// - If sqrtRatioAX96 is zero or negative, the function panics.
// - The result is calculated using high-precision fixed-point arithmetic.
// - Rounding is applied based on the roundUp parameter.
func sqrtPriceMathGetAmount0DeltaHelper(
func getAmount0DeltaHelper(
sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint,
roundUp bool,
) *u256.Uint {
Expand All @@ -303,7 +303,7 @@ func sqrtPriceMathGetAmount0DeltaHelper(
return new(u256.Uint).Div(value, sqrtRatioAX96)
}

// sqrtPriceMathGetAmount1DeltaHelper calculates the absolute difference between the amounts of token1 in two
// getAmount1DeltaHelper calculates the absolute difference between the amounts of token1 in two
// liquidity ranges defined by the square root prices sqrtRatioAX96 and sqrtRatioBX96. The difference is
// calculated relative to the range [sqrtRatioAX96, sqrtRatioBX96].
//
Expand All @@ -321,7 +321,7 @@ func sqrtPriceMathGetAmount0DeltaHelper(
// Notes:
// - Rounding is applied based on the roundUp parameter.
// - The function swaps sqrtRatioAX96 and sqrtRatioBX96 if sqrtRatioAX96 > sqrtRatioBX96.
func sqrtPriceMathGetAmount1DeltaHelper(
func getAmount1DeltaHelper(
sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint,
roundUp bool,
) *u256.Uint {
Expand All @@ -337,7 +337,7 @@ func sqrtPriceMathGetAmount1DeltaHelper(
return u256.MulDiv(liquidity, diff, q96)
}

// SqrtPriceMathGetAmount0DeltaStr calculates the difference in the amount of token0
// GetAmount0DeltaStr calculates the difference in the amount of token0
// within a specified liquidity range defined by two square root prices (sqrtRatioAX96 and sqrtRatioBX96).
// This function returns the result as a string representation of an int256 value.
//
Expand All @@ -353,19 +353,19 @@ func sqrtPriceMathGetAmount1DeltaHelper(
// within the specified range. The value is negative if the liquidity is negative.
//
// Notes:
// - This function relies on the helper function `sqrtPriceMathGetAmount0DeltaHelper` to perform the core calculation.
// - This function relies on the helper function `getAmount0DeltaHelper` to perform the core calculation.
// - The helper function calculates the absolute difference between token0 amounts within the range.
// - If the computed result exceeds the maximum allowable value for int256 (2**255 - 1), the function will panic
// with an appropriate overflow error.
// - The rounding behavior of the result is controlled by the `roundUp` parameter passed to the helper function:
// - For negative liquidity, rounding is always down.
// - For positive liquidity, rounding is always up.
func SqrtPriceMathGetAmount0DeltaStr(
func GetAmount0DeltaStr(
sqrtRatioAX96, sqrtRatioBX96 *u256.Uint,
liquidity *i256.Int,
) string {
if liquidity.IsNeg() {
u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
// if u > (2**255 - 1), cannot cast to int256
panic(errAmount0DeltaOverflow)
Expand All @@ -374,7 +374,7 @@ func SqrtPriceMathGetAmount0DeltaStr(
return i256.Zero().Neg(i).ToString()
}

u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
// if u > (2**255 - 1), cannot cast to int256
panic(errAmount0DeltaOverflow)
Expand All @@ -383,7 +383,7 @@ func SqrtPriceMathGetAmount0DeltaStr(
return i256.FromUint256(u).ToString()
}

// SqrtPriceMathGetAmount1DeltaStr calculates the difference in the amount of token1
// GetAmount1DeltaStr calculates the difference in the amount of token1
// within a specified liquidity range defined by two square root prices (sqrtRatioAX96 and sqrtRatioBX96).
// This function returns the result as a string representation of an int256 value.
//
Expand All @@ -399,16 +399,16 @@ func SqrtPriceMathGetAmount0DeltaStr(
// within the specified range. The value is negative if the liquidity is negative.
//
// Notes:
// - This function relies on the helper function `sqrtPriceMathGetAmount1DeltaHelper` to perform the core calculation.
// - This function relies on the helper function `getAmount1DeltaHelper` to perform the core calculation.
// - The rounding behavior of the result is controlled by the `roundUp` parameter passed to the helper function:
// - For negative liquidity, rounding is always down.
// - For positive liquidity, rounding is always up.
func SqrtPriceMathGetAmount1DeltaStr(
func GetAmount1DeltaStr(
sqrtRatioAX96, sqrtRatioBX96 *u256.Uint,
liquidity *i256.Int,
) string {
if liquidity.IsNeg() {
u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
// if u > (2**255 - 1), cannot cast to int256
panic(errAmount1DeltaOverflow)
Expand All @@ -417,7 +417,7 @@ func SqrtPriceMathGetAmount1DeltaStr(
return i256.Zero().Neg(i).ToString()
}

u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
// if u > (2**255 - 1), cannot cast to int256
panic(errAmount1DeltaOverflow)
Expand Down
42 changes: 21 additions & 21 deletions contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestSqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(t *testing.T) {
liquidity := u256.MustFromDecimal("2000000")
amount := u256.Zero()

result := sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(
result := getNextSqrtPriceFromAmount0RoundingUp(
sqrtPX96,
liquidity,
amount,
Expand All @@ -32,7 +32,7 @@ func TestSqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(t *testing.T) {
liquidity := u256.MustFromDecimal("2000000")
amount := u256.MustFromDecimal("500000")

result := sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(
result := getNextSqrtPriceFromAmount0RoundingUp(
sqrtPX96,
liquidity,
amount,
Expand All @@ -51,7 +51,7 @@ func TestSqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(t *testing.T) {
liquidity := u256.MustFromDecimal("2000000")
amount := u256.MustFromDecimal("100000")

result := sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(
result := getNextSqrtPriceFromAmount1RoundingDown(
sqrtPX96,
liquidity,
amount,
Expand All @@ -70,7 +70,7 @@ func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) {
ratioB := u256.MustFromDecimal("2000000")
liquidity := i256.FromUint256(u256.MustFromDecimal("5000000"))

result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity)
result := GetAmount0DeltaStr(ratioA, ratioB, liquidity)

if result[0] == '-' {
t.Error("Result should be positive for positive liquidity")
Expand All @@ -82,7 +82,7 @@ func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) {
ratioB := u256.MustFromDecimal("2000000")
liquidity := i256.Zero().Neg(i256.FromUint256(u256.MustFromDecimal("5000000")))

result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity)
result := GetAmount0DeltaStr(ratioA, ratioB, liquidity)

if result[0] != '-' {
t.Error("Result should be negative for negative liquidity")
Expand All @@ -96,7 +96,7 @@ func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) {
liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935"))

uassert.PanicsWithMessage(t, errAmount0DeltaOverflow.Error(), func() {
SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
GetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
})
})

Expand All @@ -108,7 +108,7 @@ func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) {
liquidity = liquidity.Neg(liquidity) // Make liquidity negative

uassert.PanicsWithMessage(t, errAmount0DeltaOverflow.Error(), func() {
SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
GetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
})
})
}
Expand All @@ -119,7 +119,7 @@ func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) {
ratioB := u256.MustFromDecimal("2000000")
liquidity := i256.FromUint256(u256.MustFromDecimal("5000000"))

result := SqrtPriceMathGetAmount1DeltaStr(ratioA, ratioB, liquidity)
result := GetAmount1DeltaStr(ratioA, ratioB, liquidity)

if result[0] == '-' {
t.Error("Result should be positive for positive liquidity")
Expand All @@ -131,7 +131,7 @@ func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) {
ratioB := u256.MustFromDecimal("2000000")
liquidity := i256.Zero().Neg(i256.FromUint256(u256.MustFromDecimal("5000000")))

result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity)
result := GetAmount0DeltaStr(ratioA, ratioB, liquidity)

if result[0] != '-' {
t.Error("Result should be negative for negative liquidity")
Expand All @@ -145,7 +145,7 @@ func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) {
liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935"))

uassert.PanicsWithMessage(t, errAmount1DeltaOverflow.Error(), func() {
SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
GetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
})
})

Expand All @@ -157,7 +157,7 @@ func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) {
liquidity = liquidity.Neg(liquidity) // Make liquidity negative

uassert.PanicsWithMessage(t, errAmount1DeltaOverflow.Error(), func() {
SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
GetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity)
})
})
}
Expand Down Expand Up @@ -270,10 +270,10 @@ func TestSqrtPriceMathGetNextSqrtPriceFromInput(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.shouldPanic {
uassert.PanicsWithMessage(t, tt.panicMsg, func() {
sqrtPriceMathGetNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne)
getNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne)
})
} else {
actual := sqrtPriceMathGetNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne)
actual := getNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne)
uassert.Equal(t, tt.expectedSqrtPriceX96, actual.ToString())
}
})
Expand All @@ -288,7 +288,7 @@ func TestSqrtPriceMathGetNextSqrtPriceFromInput2(t *testing.T) {
}
}()

sqrtPriceMathGetNextSqrtPriceFromInput(
getNextSqrtPriceFromInput(
u256.Zero(),
u256.MustFromDecimal("1000000"),
u256.MustFromDecimal("500000"),
Expand All @@ -303,7 +303,7 @@ func TestSqrtPriceMathGetNextSqrtPriceFromInput2(t *testing.T) {
}
}()

sqrtPriceMathGetNextSqrtPriceFromInput(
getNextSqrtPriceFromInput(
u256.MustFromDecimal("1000000"),
u256.Zero(),
u256.MustFromDecimal("500000"),
Expand Down Expand Up @@ -444,10 +444,10 @@ func TestSqrtPriceMathGetNextSqrtPriceFromOutput(t *testing.T) {
t.Errorf("Expected panic for %s", tt.name)
}
}()
sqrtPriceMathGetNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne)
getNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne)

} else {
actual := sqrtPriceMathGetNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne)
actual := getNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne)
uassert.Equal(t, tt.expectedSqrtPriceX96, actual.ToString())
}
})
Expand Down Expand Up @@ -513,7 +513,7 @@ func TestSqrtPriceMathGetAmount0DeltaHelper(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := sqrtPriceMathGetAmount0DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp)
actual := getAmount0DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp)
uassert.Equal(t, tt.expectedAmount0Delta, actual.ToString())
})
}
Expand Down Expand Up @@ -562,7 +562,7 @@ func TestSqrtPriceMathGetAmount1DeltaHelper(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := sqrtPriceMathGetAmount1DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp)
actual := getAmount1DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp)
uassert.Equal(t, tt.expectedAmount1Delta, actual.ToString())
})
}
Expand All @@ -574,10 +574,10 @@ func TestSwapComputation_SqrtP_SqrtQ_Mul_Overflow(t *testing.T) {
amountIn := u256.MustFromDecimal("406")
zeroForOne := true

sqrtQ := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtP, liquidity, amountIn, zeroForOne)
sqrtQ := getNextSqrtPriceFromInput(sqrtP, liquidity, amountIn, zeroForOne)
uassert.Equal(t, "1025574284609383582644711336373707553698163132913", sqrtQ.ToString())

amount0Delta := sqrtPriceMathGetAmount0DeltaHelper(sqrtQ, sqrtP, liquidity, true)
amount0Delta := getAmount0DeltaHelper(sqrtQ, sqrtP, liquidity, true)
uassert.Equal(t, "406", amount0Delta.ToString())
}

Expand Down
Loading

0 comments on commit 6a922a1

Please sign in to comment.