Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
ScreamingHawk committed Jul 16, 2024
1 parent de0dbe3 commit 9655ac1
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 33 deletions.
5 changes: 4 additions & 1 deletion src/payments/IPayments.sol
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ interface IPaymentsFunctions {
* @param chainedCallDetails The chained call details.
* @return callHash The hash of the chained call for signing.
*/
function hashChainedCallDetails(ChainedCallDetails calldata chainedCallDetails) external view returns (bytes32 callHash);
function hashChainedCallDetails(ChainedCallDetails calldata chainedCallDetails)
external
view
returns (bytes32 callHash);

/**
* Complete a chained call.
Expand Down
13 changes: 6 additions & 7 deletions src/payments/Payments.sol
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ contract Payments is Ownable, IPayments, IERC165 {

/// @inheritdoc IPaymentsFunctions
/// @dev As the signer can validate any payment (including zero) this function does not increase the security surface.
function performChainedCall(ChainedCallDetails calldata chainedCallDetails, bytes calldata signature) external override {
function performChainedCall(ChainedCallDetails calldata chainedCallDetails, bytes calldata signature)
external
override
{
if (!isValidChainedCallSignature(chainedCallDetails, signature)) {
revert InvalidSignature();
}
Expand All @@ -125,19 +128,15 @@ contract Payments is Ownable, IPayments, IERC165 {
/// @dev This hash includes the chain ID.
function hashChainedCallDetails(ChainedCallDetails calldata chainedCallDetails) public view returns (bytes32) {
return keccak256(
abi.encode(
block.chainid,
chainedCallDetails.chainedCallAddress,
chainedCallDetails.chainedCallData
)
abi.encode(block.chainid, chainedCallDetails.chainedCallAddress, chainedCallDetails.chainedCallData)
);
}

/**
* Perform a chained call and revert on error.
*/
function _performChainedCall(ChainedCallDetails calldata chainedCallDetails) internal {
(bool success, ) = chainedCallDetails.chainedCallAddress.call{value: 0}(chainedCallDetails.chainedCallData);
(bool success,) = chainedCallDetails.chainedCallAddress.call{value: 0}(chainedCallDetails.chainedCallData);
if (!success) {
revert ChainedCallFailed();
}
Expand Down
61 changes: 36 additions & 25 deletions test/payments/Payments.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ contract PaymentsTest is Test, IPaymentsSignals {
{
uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));
IPaymentsFunctions.TokenType tokenType = _toTokenType(input.tokenType);
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand All @@ -89,10 +90,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
IPaymentsFunctions.ChainedCallDetails(
address(0),
""
)
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Mint required tokens
Expand Down Expand Up @@ -126,7 +124,8 @@ contract PaymentsTest is Test, IPaymentsSignals {
{
uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));
IPaymentsFunctions.TokenType tokenType = _toTokenType(input.tokenType);
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand All @@ -140,8 +139,10 @@ contract PaymentsTest is Test, IPaymentsSignals {
} else {
chainedTokenType = IPaymentsFunctions.TokenType.ERC20;
}
(address chainedTokenAddr, uint256 chainedTokenId, uint256 chainedAmount) = _validTokenParams(chainedTokenType, input.tokenId, input.paymentRecipient.amount);
bytes memory chainedData = abi.encodeWithSelector(IGenericToken.mint.selector, input.productRecipient, chainedTokenId, chainedAmount);
(address chainedTokenAddr, uint256 chainedTokenId, uint256 chainedAmount) =
_validTokenParams(chainedTokenType, input.tokenId, input.paymentRecipient.amount);
bytes memory chainedData =
abi.encodeWithSelector(IGenericToken.mint.selector, input.productRecipient, chainedTokenId, chainedAmount);

IPaymentsFunctions.PaymentDetails memory details = IPaymentsFunctions.PaymentDetails(
input.purchaseId,
Expand Down Expand Up @@ -187,7 +188,8 @@ contract PaymentsTest is Test, IPaymentsSignals {

uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));

(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](2);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand Down Expand Up @@ -229,7 +231,8 @@ contract PaymentsTest is Test, IPaymentsSignals {
{
uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));
IPaymentsFunctions.TokenType tokenType = _toTokenType(input.tokenType);
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand Down Expand Up @@ -261,7 +264,8 @@ contract PaymentsTest is Test, IPaymentsSignals {
vm.warp(blockTimestamp);

IPaymentsFunctions.TokenType tokenType = _toTokenType(input.tokenType);
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand Down Expand Up @@ -300,7 +304,8 @@ contract PaymentsTest is Test, IPaymentsSignals {
{
uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));
IPaymentsFunctions.TokenType tokenType = _toTokenType(input.tokenType);
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand Down Expand Up @@ -372,10 +377,10 @@ contract PaymentsTest is Test, IPaymentsSignals {
safeAddress(caller)
safeAddress(input.paymentRecipient.recipient)
{

uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));
IPaymentsFunctions.TokenType tokenType = IPaymentsFunctions.TokenType.ERC721;
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = _bound(amount, 2, type(uint256).max); // Invalid amount
Expand Down Expand Up @@ -415,7 +420,8 @@ contract PaymentsTest is Test, IPaymentsSignals {

uint64 expiration = uint64(_bound(input.expiration, block.timestamp, type(uint64).max));
IPaymentsFunctions.TokenType tokenType = _toTokenType(input.tokenType);
(address tokenAddr, uint256 tokenId, uint256 amount) = _validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
(address tokenAddr, uint256 tokenId, uint256 amount) =
_validTokenParams(tokenType, input.tokenId, input.paymentRecipient.amount);
IPaymentsFunctions.PaymentRecipient[] memory paymentRecipients = new IPaymentsFunctions.PaymentRecipient[](1);
paymentRecipients[0] = input.paymentRecipient;
paymentRecipients[0].amount = amount;
Expand Down Expand Up @@ -458,7 +464,8 @@ contract PaymentsTest is Test, IPaymentsSignals {
(tokenAddr, tokenId, amount) = _validTokenParams(tokenType, tokenId, amount);

bytes memory callData = abi.encodeWithSelector(IGenericToken.mint.selector, recipient, tokenId, amount);
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails = IPaymentsFunctions.ChainedCallDetails(tokenAddr, callData);
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails =
IPaymentsFunctions.ChainedCallDetails(tokenAddr, callData);

// Sign it
bytes32 messageHash = payments.hashChainedCallDetails(chainedCallDetails);
Expand All @@ -472,29 +479,33 @@ contract PaymentsTest is Test, IPaymentsSignals {
assertEq(IGenericToken(tokenAddr).balanceOf(recipient, tokenId), amount);
}

function testPerformChainedCallInvalidSignature(address caller, uint8 tokenTypeInt, uint256 tokenId, uint256 amount, address recipient, bytes calldata sig)
public
safeAddress(recipient)
{
function testPerformChainedCallInvalidSignature(
address caller,
uint8 tokenTypeInt,
uint256 tokenId,
uint256 amount,
address recipient,
bytes calldata sig
) public safeAddress(recipient) {
vm.assume(caller != signer);

IPaymentsFunctions.TokenType tokenType = _toTokenType(tokenTypeInt);
address tokenAddr;
(tokenAddr, tokenId, amount) = _validTokenParams(tokenType, tokenId, amount);

bytes memory callData = abi.encodeWithSelector(IGenericToken.mint.selector, recipient, tokenId, amount);
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails = IPaymentsFunctions.ChainedCallDetails(tokenAddr, callData);
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails =
IPaymentsFunctions.ChainedCallDetails(tokenAddr, callData);

// Send it
vm.expectRevert(InvalidSignature.selector);
vm.prank(caller);
payments.performChainedCall(chainedCallDetails, sig);
}

function testPerformChainedCallInvalidCall(bytes calldata chainedCallData)
public
{
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails = IPaymentsFunctions.ChainedCallDetails(address(this), chainedCallData);
function testPerformChainedCallInvalidCall(bytes calldata chainedCallData) public {
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails =
IPaymentsFunctions.ChainedCallDetails(address(this), chainedCallData);
// Check the call will fail
(bool success,) = chainedCallDetails.chainedCallAddress.call(chainedCallDetails.chainedCallData);
vm.assume(!success);
Expand Down

0 comments on commit 9655ac1

Please sign in to comment.