diff --git a/test/ERC20.t.sol b/test/ERC20.t.sol index 32c70d9..e23e01b 100644 --- a/test/ERC20.t.sol +++ b/test/ERC20.t.sol @@ -8,6 +8,7 @@ import {ERC20OpenZeppelin} from "@src/ERC20OpenZeppelin.sol"; import {ERC20Solady} from "@src/ERC20Solady.sol"; import {ERC20Solmate} from "@src/ERC20Solmate.sol"; +/// @custom:halmos --storage-layout=generic --loop 6 contract ERC20Test is Test, SymTest { ERC20OpenZeppelin public openzeppelin; ERC20Solady public solady; @@ -17,9 +18,9 @@ contract ERC20Test is Test, SymTest { bytes4[] callSelectors; function setUp() public { - openzeppelin = new ERC20OpenZeppelin("Token", "TOK", 6, msg.sender, 123e18); - solady = new ERC20Solady("Token", "TOK", 6, msg.sender, 123e18); - solmate = new ERC20Solmate("Token", "TOK", 6, msg.sender, 123e18); + openzeppelin = new ERC20OpenZeppelin("Token", "TOK", 6, address(0x1337), 123e18); + solady = new ERC20Solady("Token", "TOK", 6, address(0x1337), 123e18); + solmate = new ERC20Solmate("Token", "TOK", 6, address(0x1337), 123e18); staticcallSelectors = [ IERC20.balanceOf.selector, @@ -51,9 +52,12 @@ contract ERC20Test is Test, SymTest { } } - function check_differential_call(bytes[] memory calls, bytes[] memory staticcalls) public { + function check_differential_call(address[] memory senders, bytes[] memory calls, bytes[] memory staticcalls) + public + { vm.assume(isValidSelectors(callSelectors, calls)); vm.assume(isValidSelectors(staticcallSelectors, staticcalls)); + vm.assume(senders.length == calls.length); address[3] memory contracts = [address(openzeppelin), address(solady), address(solmate)]; @@ -62,30 +66,31 @@ contract ERC20Test is Test, SymTest { successes = new bool[](calls.length); results = new bytes[](calls.length); - for (uint256 i = 0; i < contracts.length; i++) { - for (uint256 j = 0; j < calls.length; j++) { + for (uint256 j = 0; j < calls.length; j++) { + for (uint256 i = 0; i < contracts.length; i++) { + vm.prank(senders[j]); (bool _success, bytes memory _result) = contracts[i].call(calls[j]); if (i == 0) { successes[i] = _success; results[i] = _result; } else { - assertEq(successes[j], _success); - assertEq(results[j], _result); + assertEq(successes[i], _success); + assertEq(results[i], _result); } } } successes = new bool[](staticcalls.length); results = new bytes[](staticcalls.length); - for (uint256 i = 0; i < contracts.length; i++) { - for (uint256 j = 0; j < staticcalls.length; j++) { - (bool _success, bytes memory _result) = contracts[i].call(staticcalls[j]); + for (uint256 j = 0; j < staticcalls.length; j++) { + for (uint256 i = 0; i < contracts.length; i++) { + (bool _success, bytes memory _result) = contracts[i].staticcall(staticcalls[j]); if (i == 0) { successes[i] = _success; results[i] = _result; } else { - assertEq(successes[j], _success); - assertEq(results[j], _result); + assertEq(successes[i], _success); + assertEq(results[i], _result); } } }