Skip to content

Commit

Permalink
Parity for chained calls
Browse files Browse the repository at this point in the history
  • Loading branch information
ScreamingHawk committed Jul 16, 2024
1 parent be90c68 commit de0dbe3
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 63 deletions.
66 changes: 43 additions & 23 deletions src/payments/IPayments.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ interface IPaymentsFunctions {
uint256 amount;
}

struct ChainedCallDetails {
// Address for chained call
address chainedCallAddress;
// Data for chained call
bytes chainedCallData;
}

struct PaymentDetails {
// Unique ID for this purchase
uint256 purchaseId;
Expand All @@ -32,44 +39,34 @@ interface IPaymentsFunctions {
uint64 expiration;
// ID of the product
string productId;
// Address for chained call
address chainedCallAddress;
// Data for chained call
bytes chainedCallData;
// Chained call details
ChainedCallDetails chainedCallDetails;
}

/**
* Make a payment for a product.
* Returns the hash of the payment details.
* @param paymentDetails The payment details.
* @param signature The signature of the payment.
*/
function makePayment(PaymentDetails calldata paymentDetails, bytes calldata signature) external payable;

/**
* Complete a chained call.
* @param chainedCallAddress The address of the chained call.
* @param chainedCallData The data for the chained call.
* @notice This is only callable by an authorised party.
* @return paymentHash The hash of the payment details for signing.
*/
function performChainedCall(address chainedCallAddress, bytes calldata chainedCallData) external;
function hashPaymentDetails(PaymentDetails calldata paymentDetails) external view returns (bytes32 paymentHash);

/**
* Check is a signature is valid.
* Check is a payment signature is valid.
* @param paymentDetails The payment details.
* @param signature The signature of the payment.
* @return isValid True if the signature is valid.
*/
function isValidSignature(PaymentDetails calldata paymentDetails, bytes calldata signature)
function isValidPaymentSignature(PaymentDetails calldata paymentDetails, bytes calldata signature)
external
view
returns (bool isValid);

/**
* Returns the hash of the payment details.
* Make a payment for a product.
* @param paymentDetails The payment details.
* @return paymentHash The hash of the payment details for signing.
* @param signature The signature of the payment.
*/
function hashPaymentDetails(PaymentDetails calldata paymentDetails) external view returns (bytes32 paymentHash);
function makePayment(PaymentDetails calldata paymentDetails, bytes calldata signature) external payable;

/**
* Check if a payment has been accepted.
Expand All @@ -78,6 +75,32 @@ interface IPaymentsFunctions {
*/
function paymentAccepted(uint256 purchaseId) external view returns (bool);

/**
* Returns the hash of the chained call.
* @param chainedCallDetails The chained call details.
* @return callHash The hash of the chained call for signing.
*/
function hashChainedCallDetails(ChainedCallDetails calldata chainedCallDetails) external view returns (bytes32 callHash);

/**
* Complete a chained call.
* @param chainedCallDetails The chained call details.
* @param signature The signature of the chained call.
* @dev This is called when a payment is accepted off/cross chain.
*/
function performChainedCall(ChainedCallDetails calldata chainedCallDetails, bytes calldata signature) external;

/**
* Check is a chained call signature is valid.
* @param chainedCallDetails The chained call details.
* @param signature The signature of the chained call.
* @return isValid True if the signature is valid.
*/
function isValidChainedCallSignature(ChainedCallDetails calldata chainedCallDetails, bytes calldata signature)
external
view
returns (bool isValid);

/**
* Get the signer address.
* @return signer The signer address.
Expand All @@ -89,9 +112,6 @@ interface IPaymentsSignals {
/// @notice Emitted when a payment is already accepted. This prevents double spending.
error PaymentAlreadyAccepted();

/// @notice Emitted when a sender is invalid.
error InvalidSender();

/// @notice Emitted when a signature is invalid.
error InvalidSignature();

Expand Down
48 changes: 34 additions & 14 deletions src/payments/Payments.sol
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ contract Payments is Ownable, IPayments, IERC165 {
if (paymentAccepted[paymentDetails.purchaseId]) {
revert PaymentAlreadyAccepted();
}
if (!isValidSignature(paymentDetails, signature)) {
if (!isValidPaymentSignature(paymentDetails, signature)) {
revert InvalidSignature();
}
if (block.timestamp > paymentDetails.expiration) {
Expand All @@ -65,14 +65,14 @@ contract Payments is Ownable, IPayments, IERC165 {
emit PaymentMade(spender, paymentDetails.productRecipient, paymentDetails.purchaseId, paymentDetails.productId);

// Perform chained call
if (paymentDetails.chainedCallAddress != address(0)) {
_performChainedCall(paymentDetails.chainedCallAddress, paymentDetails.chainedCallData);
if (paymentDetails.chainedCallDetails.chainedCallAddress != address(0)) {
_performChainedCall(paymentDetails.chainedCallDetails);
}
}

/// @inheritdoc IPaymentsFunctions
/// @notice A valid signature does not guarantee that the payment will be accepted.
function isValidSignature(PaymentDetails calldata paymentDetails, bytes calldata signature)
function isValidPaymentSignature(PaymentDetails calldata paymentDetails, bytes calldata signature)
public
view
returns (bool)
Expand All @@ -96,28 +96,48 @@ contract Payments is Ownable, IPayments, IERC165 {
paymentDetails.paymentRecipients,
paymentDetails.expiration,
paymentDetails.productId,
paymentDetails.chainedCallAddress,
paymentDetails.chainedCallData
paymentDetails.chainedCallDetails
)
);
}

/// @inheritdoc IPaymentsFunctions
/// @notice This can only be called by the signer.
/// @dev As the signer can validate any payment (including zero) this function does not increase the security surface.
function performChainedCall(address chainedCallAddress, bytes calldata chainedCallData) external override {
// Check authorization
if (msg.sender != signer) {
revert InvalidSender();
function performChainedCall(ChainedCallDetails calldata chainedCallDetails, bytes calldata signature) external override {
if (!isValidChainedCallSignature(chainedCallDetails, signature)) {
revert InvalidSignature();
}
_performChainedCall(chainedCallAddress, chainedCallData);
_performChainedCall(chainedCallDetails);
}

/// @inheritdoc IPaymentsFunctions
function isValidChainedCallSignature(ChainedCallDetails calldata chainedCallDetails, bytes calldata signature)
public
view
returns (bool)
{
bytes32 messageHash = hashChainedCallDetails(chainedCallDetails);
address sigSigner = messageHash.recoverCalldata(signature);
return sigSigner == signer;
}

/// @inheritdoc IPaymentsFunctions
/// @dev This hash includes the chain ID.
function hashChainedCallDetails(ChainedCallDetails calldata chainedCallDetails) public view returns (bytes32) {
return keccak256(
abi.encode(
block.chainid,
chainedCallDetails.chainedCallAddress,
chainedCallDetails.chainedCallData
)
);
}

/**
* Perform a chained call and revert on error.
*/
function _performChainedCall(address chainedCallAddress, bytes calldata chainedCallData) internal {
(bool success, ) = chainedCallAddress.call{value: 0}(chainedCallData);
function _performChainedCall(ChainedCallDetails calldata chainedCallDetails) internal {
(bool success, ) = chainedCallDetails.chainedCallAddress.call{value: 0}(chainedCallDetails.chainedCallData);
if (!success) {
revert ChainedCallFailed();
}
Expand Down
58 changes: 32 additions & 26 deletions test/payments/Payments.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(
address(0),
""
)
);

// Mint required tokens
Expand Down Expand Up @@ -150,8 +152,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
chainedTokenAddr,
chainedData
IPaymentsFunctions.ChainedCallDetails(chainedTokenAddr, chainedData)
);

// Mint required tokens
Expand Down Expand Up @@ -201,8 +202,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Mint required tokens
Expand Down Expand Up @@ -243,8 +243,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Send it
Expand Down Expand Up @@ -276,8 +275,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
input.expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Mint required tokens
Expand Down Expand Up @@ -316,8 +314,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Do not mint required tokens
Expand Down Expand Up @@ -356,8 +353,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Sign it
Expand Down Expand Up @@ -393,8 +389,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(0),
""
IPaymentsFunctions.ChainedCallDetails(address(0), "")
);

// Sign it
Expand Down Expand Up @@ -434,8 +429,7 @@ contract PaymentsTest is Test, IPaymentsSignals {
paymentRecipients,
expiration,
input.productId,
address(payments), // Chained call to payments will fail
chainedCallData
IPaymentsFunctions.ChainedCallDetails(address(payments), chainedCallData)
);

// Mint required tokens
Expand Down Expand Up @@ -464,15 +458,21 @@ contract PaymentsTest is Test, IPaymentsSignals {
(tokenAddr, tokenId, amount) = _validTokenParams(tokenType, tokenId, amount);

bytes memory callData = abi.encodeWithSelector(IGenericToken.mint.selector, recipient, tokenId, amount);
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails = IPaymentsFunctions.ChainedCallDetails(tokenAddr, callData);

// Sign it
bytes32 messageHash = payments.hashChainedCallDetails(chainedCallDetails);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(signerPk, messageHash);
bytes memory sig = abi.encodePacked(r, s, v);

// Send it
vm.prank(signer);
payments.performChainedCall(tokenAddr, callData);
payments.performChainedCall(chainedCallDetails, sig);

assertEq(IGenericToken(tokenAddr).balanceOf(recipient, tokenId), amount);
}

function testPerformChainedCallInvalidCaller(address caller, uint8 tokenTypeInt, uint256 tokenId, uint256 amount, address recipient)
function testPerformChainedCallInvalidSignature(address caller, uint8 tokenTypeInt, uint256 tokenId, uint256 amount, address recipient, bytes calldata sig)
public
safeAddress(recipient)
{
Expand All @@ -483,24 +483,30 @@ contract PaymentsTest is Test, IPaymentsSignals {
(tokenAddr, tokenId, amount) = _validTokenParams(tokenType, tokenId, amount);

bytes memory callData = abi.encodeWithSelector(IGenericToken.mint.selector, recipient, tokenId, amount);
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails = IPaymentsFunctions.ChainedCallDetails(tokenAddr, callData);

// Send it
vm.expectRevert(InvalidSender.selector);
vm.expectRevert(InvalidSignature.selector);
vm.prank(caller);
payments.performChainedCall(tokenAddr, callData);
payments.performChainedCall(chainedCallDetails, sig);
}

function testPerformChainedCallInvalidCall(bytes memory chainedCallData)
function testPerformChainedCallInvalidCall(bytes calldata chainedCallData)
public
{
IPaymentsFunctions.ChainedCallDetails memory chainedCallDetails = IPaymentsFunctions.ChainedCallDetails(address(this), chainedCallData);
// Check the call will fail
(bool success,) = address(payments).call(chainedCallData);
(bool success,) = chainedCallDetails.chainedCallAddress.call(chainedCallDetails.chainedCallData);
vm.assume(!success);

// Sign it
bytes32 messageHash = payments.hashChainedCallDetails(chainedCallDetails);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(signerPk, messageHash);
bytes memory sig = abi.encodePacked(r, s, v);

vm.expectRevert(ChainedCallFailed.selector);
vm.prank(signer);
// Chained call to payments will fail
payments.performChainedCall(address(payments), chainedCallData);
payments.performChainedCall(chainedCallDetails, sig);
}

// Update signer
Expand Down

0 comments on commit de0dbe3

Please sign in to comment.