Skip to content

Commit

Permalink
remove provider arg from getFee
Browse files Browse the repository at this point in the history
  • Loading branch information
cctdaniel committed Jan 23, 2025
1 parent fb8a7cd commit 5812d78
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 42 deletions.
4 changes: 1 addition & 3 deletions target_chains/ethereum/contracts/contracts/pulse/IPulse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ interface IPulse is PulseEvents {
* @notice Calculates the total fee required for a price update request
* @dev Total fee = base Pyth protocol fee + gas costs for callback
* @param callbackGasLimit The amount of gas allocated for callback execution
* @param provider The provider to use for the fee calculation
* @return feeAmount The total fee in wei that must be provided as msg.value
*/
function getFee(
uint256 callbackGasLimit,
address provider
uint256 callbackGasLimit
) external view returns (uint128 feeAmount);

function getAccruedFees() external view returns (uint128 accruedFeesInWei);
Expand Down
12 changes: 5 additions & 7 deletions target_chains/ethereum/contracts/contracts/pulse/Pulse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ abstract contract Pulse is IPulse, PulseState {
}
requestSequenceNumber = _state.currentSequenceNumber++;

uint128 requiredFee = getFee(callbackGasLimit, provider);
uint128 requiredFee = getFee(callbackGasLimit);
if (msg.value < requiredFee) revert InsufficientFee();

Request storage req = allocRequest(requestSequenceNumber);
Expand Down Expand Up @@ -190,14 +190,12 @@ abstract contract Pulse is IPulse, PulseState {
}

function getFee(
uint256 callbackGasLimit,
address provider
uint256 callbackGasLimit
) public view override returns (uint128 feeAmount) {
if (provider == address(0)) {
provider = _state.defaultProvider;
}
uint128 baseFee = _state.pythFeeInWei;
uint128 providerFeeInWei = _state.providers[provider].feeInWei;
uint128 providerFeeInWei = _state
.providers[_state.defaultProvider]
.feeInWei;
uint256 gasFee = callbackGasLimit * providerFeeInWei;
feeAmount = baseFee + SafeCast.toUint128(gasFee);
}
Expand Down
62 changes: 30 additions & 32 deletions target_chains/ethereum/contracts/forge-test/Pulse.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,13 @@ contract PulseTest is Test, PulseEvents {
}

// Helper function to calculate total fee
function calculateTotalFee(
address provider
) internal view returns (uint128) {
return pulse.getFee(CALLBACK_GAS_LIMIT, provider);
function calculateTotalFee() internal view returns (uint128) {
return pulse.getFee(CALLBACK_GAS_LIMIT);
}

// Helper function to setup consumer request
function setupConsumerRequest(
address consumerAddress,
address provider
address consumerAddress
)
internal
returns (
Expand All @@ -181,7 +178,7 @@ contract PulseTest is Test, PulseEvents {
publishTime = block.timestamp;
vm.deal(consumerAddress, 1 gwei);

uint128 totalFee = calculateTotalFee(provider);
uint128 totalFee = calculateTotalFee();

vm.prank(consumerAddress);
sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}(
Expand All @@ -202,7 +199,7 @@ contract PulseTest is Test, PulseEvents {

// Fund the consumer contract with enough ETH for higher gas price
vm.deal(address(consumer), 1 ether);
uint128 totalFee = calculateTotalFee(defaultProvider);
uint128 totalFee = calculateTotalFee();

// Create the event data we expect to see
PulseState.Request memory expectedRequest = PulseState.Request({
Expand Down Expand Up @@ -276,7 +273,7 @@ contract PulseTest is Test, PulseEvents {

// Fund the consumer contract
vm.deal(address(consumer), 1 gwei);
uint128 totalFee = calculateTotalFee(defaultProvider);
uint128 totalFee = calculateTotalFee();

// Step 1: Make the request as consumer
vm.prank(address(consumer));
Expand Down Expand Up @@ -353,7 +350,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(failingConsumer), defaultProvider);
) = setupConsumerRequest(address(failingConsumer));

PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
publishTime
Expand Down Expand Up @@ -381,7 +378,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(failingConsumer), defaultProvider);
) = setupConsumerRequest(address(failingConsumer));

PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
publishTime
Expand All @@ -408,7 +405,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(consumer), defaultProvider);
) = setupConsumerRequest(address(consumer));

// Setup mock data
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
Expand All @@ -433,7 +430,7 @@ contract PulseTest is Test, PulseEvents {
uint256 futureTime = block.timestamp + 10; // 10 seconds in future
vm.deal(address(consumer), 1 gwei);

uint128 totalFee = calculateTotalFee(defaultProvider);
uint128 totalFee = calculateTotalFee();
vm.prank(address(consumer));
uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
value: totalFee
Expand Down Expand Up @@ -468,7 +465,7 @@ contract PulseTest is Test, PulseEvents {
uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute
vm.deal(address(consumer), 1 gwei);

uint128 totalFee = calculateTotalFee(defaultProvider);
uint128 totalFee = calculateTotalFee();
vm.prank(address(consumer));

vm.expectRevert("Too far in future");
Expand All @@ -484,7 +481,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(consumer), defaultProvider);
) = setupConsumerRequest(address(consumer));

PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
publishTime
Expand Down Expand Up @@ -514,7 +511,7 @@ contract PulseTest is Test, PulseEvents {
uint128 expectedFee = SafeCast.toUint128(
DEFAULT_PROVIDER_FEE * gasLimit
) + PYTH_FEE;
uint128 actualFee = pulse.getFee(gasLimit, defaultProvider);
uint128 actualFee = pulse.getFee(gasLimit);
assertEq(
actualFee,
expectedFee,
Expand All @@ -524,7 +521,7 @@ contract PulseTest is Test, PulseEvents {

// Test with zero gas limit
uint128 expectedMinFee = PYTH_FEE;
uint128 actualMinFee = pulse.getFee(0, defaultProvider);
uint128 actualMinFee = pulse.getFee(0);
assertEq(
actualMinFee,
expectedMinFee,
Expand All @@ -538,9 +535,11 @@ contract PulseTest is Test, PulseEvents {
vm.deal(address(consumer), 1 gwei);

vm.prank(address(consumer));
pulse.requestPriceUpdatesWithCallback{
value: calculateTotalFee(defaultProvider)
}(block.timestamp, priceIds, CALLBACK_GAS_LIMIT);
pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
block.timestamp,
priceIds,
CALLBACK_GAS_LIMIT
);

// Get admin's balance before withdrawal
uint256 adminBalanceBefore = admin.balance;
Expand Down Expand Up @@ -586,9 +585,11 @@ contract PulseTest is Test, PulseEvents {
vm.deal(address(consumer), 1 gwei);

vm.prank(address(consumer));
pulse.requestPriceUpdatesWithCallback{
value: calculateTotalFee(defaultProvider)
}(block.timestamp, priceIds, CALLBACK_GAS_LIMIT);
pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
block.timestamp,
priceIds,
CALLBACK_GAS_LIMIT
);

// Get provider's accrued fees instead of total fees
PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo(
Expand Down Expand Up @@ -645,10 +646,7 @@ contract PulseTest is Test, PulseEvents {
uint256 publishTime = block.timestamp;

// Setup request
(uint64 sequenceNumber, , ) = setupConsumerRequest(
address(consumer),
defaultProvider
);
(uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer));

// Create different priceIds
bytes32[] memory wrongPriceIds = new bytes32[](2);
Expand Down Expand Up @@ -682,7 +680,7 @@ contract PulseTest is Test, PulseEvents {
}

vm.deal(address(consumer), 1 gwei);
uint128 totalFee = calculateTotalFee(defaultProvider);
uint128 totalFee = calculateTotalFee();

vm.prank(address(consumer));
vm.expectRevert(
Expand Down Expand Up @@ -752,7 +750,7 @@ contract PulseTest is Test, PulseEvents {
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint256(1));

uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT, provider);
uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT);

vm.deal(address(consumer), totalFee);
vm.prank(address(consumer));
Expand Down Expand Up @@ -802,7 +800,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(consumer), defaultProvider);
) = setupConsumerRequest(address(consumer));

// Setup mock data
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
Expand Down Expand Up @@ -832,7 +830,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(consumer), defaultProvider);
) = setupConsumerRequest(address(consumer));

// Setup mock data
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
Expand Down Expand Up @@ -864,7 +862,7 @@ contract PulseTest is Test, PulseEvents {
uint64 sequenceNumber,
bytes32[] memory priceIds,
uint256 publishTime
) = setupConsumerRequest(address(consumer), defaultProvider);
) = setupConsumerRequest(address(consumer));

// Setup mock data
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
Expand Down

0 comments on commit 5812d78

Please sign in to comment.