Skip to content

Commit

Permalink
fix: address best practices
Browse files Browse the repository at this point in the history
  • Loading branch information
jaypaik committed Sep 20, 2023
1 parent 5ee69ed commit 581f73a
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 48 deletions.
60 changes: 42 additions & 18 deletions src/CustomSlotInitializable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

pragma solidity ^0.8.21;

import {Address} from "@openzeppelin/contracts/utils/Address.sol";

/**
* @dev Identical to OpenZeppelin's `Initializable`, except that its state variables are kept at a custom storage slot
* instead of at the start of storage.
Expand Down Expand Up @@ -73,6 +71,16 @@ abstract contract CustomSlotInitializable {
bool initializing;
}

/**
* @dev The contract is already initialized.
*/
error InvalidInitialization();

/**
* @dev The contract is not initializing.
*/
error NotInitializing();

/**
* @dev Triggered when the contract has been initialized or reinitialized.
*/
Expand All @@ -92,13 +100,23 @@ abstract contract CustomSlotInitializable {
* Emits an {Initialized} event.
*/
modifier initializer() {
CustomSlotInitializableStorage storage _storage = _getInitialiazableStorage();
CustomSlotInitializableStorage storage _storage = _getInitializableStorage();

// Cache values to avoid duplicated sloads
bool isTopLevelCall = !_storage.initializing;
require(
(isTopLevelCall && _storage.initialized < 1)
|| (!Address.isContract(address(this)) && _storage.initialized == 1),
"Initializable: contract is already initialized"
);
uint64 initialized = _storage.initialized;

// Allowed calls:
// - initialSetup: the contract is not in the initializing state and no previous version was
// initialized
// - construction: the contract is initialized at version 1 (no reininitialization) and the
// current contract is just being deployed
bool initialSetup = initialized == 0 && isTopLevelCall;
bool construction = initialized == 1 && address(this).code.length == 0;

if (!initialSetup && !construction) {
revert InvalidInitialization();
}
_storage.initialized = 1;
if (isTopLevelCall) {
_storage.initializing = true;
Expand Down Expand Up @@ -129,10 +147,11 @@ abstract contract CustomSlotInitializable {
* Emits an {Initialized} event.
*/
modifier reinitializer(uint8 version) {
CustomSlotInitializableStorage storage _storage = _getInitialiazableStorage();
require(
!_storage.initializing && _storage.initialized < version, "Initializable: contract is already initialized"
);
CustomSlotInitializableStorage storage _storage = _getInitializableStorage();

if (_storage.initializing || _storage.initialized >= version) {
revert InvalidInitialization();
}
_storage.initialized = version;
_storage.initializing = true;
_;
Expand All @@ -145,7 +164,9 @@ abstract contract CustomSlotInitializable {
* {initializer} and {reinitializer} modifiers, directly or indirectly.
*/
modifier onlyInitializing() {
require(_getInitialiazableStorage().initializing, "Initializable: contract is not initializing");
if (!_isInitializing()) {
revert NotInitializing();
}
_;
}

Expand All @@ -158,8 +179,11 @@ abstract contract CustomSlotInitializable {
* Emits an {Initialized} event the first time it is successfully executed.
*/
function _disableInitializers() internal virtual {
CustomSlotInitializableStorage storage _storage = _getInitialiazableStorage();
require(!_storage.initializing, "Initializable: contract is initializing");
CustomSlotInitializableStorage storage _storage = _getInitializableStorage();

if (_storage.initializing) {
revert InvalidInitialization();
}
if (_storage.initialized != type(uint64).max) {
_storage.initialized = type(uint64).max;
emit Initialized(type(uint64).max);
Expand All @@ -170,17 +194,17 @@ abstract contract CustomSlotInitializable {
* @dev Returns the highest version that has been initialized. See {reinitializer}.
*/
function _getInitializedVersion() internal view returns (uint64) {
return _getInitialiazableStorage().initialized;
return _getInitializableStorage().initialized;
}

/**
* @dev Returns `true` if the contract is currently initializing. See {onlyInitializing}.
*/
function _isInitializing() internal view returns (bool) {
return _getInitialiazableStorage().initializing;
return _getInitializableStorage().initializing;
}

function _getInitialiazableStorage() private view returns (CustomSlotInitializableStorage storage _storage) {
function _getInitializableStorage() private view returns (CustomSlotInitializableStorage storage _storage) {
bytes32 position = _storagePosition;
assembly {
_storage.slot := position
Expand Down
71 changes: 53 additions & 18 deletions src/LightAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ import {CustomSlotInitializable} from "./CustomSlotInitializable.sol";
* user operations through a bundler.
*
* 4. Event `SimpleAccountInitialized` renamed to `LightAccountInitialized`.
*
* 5. Uses custom errors.
*/
contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, CustomSlotInitializable, IERC1271 {
using ECDSA for bytes32;
Expand Down Expand Up @@ -74,6 +76,22 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus
*/
event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);

/**
* @dev The length of the array does not match the expected length.
*/
error ArrayLengthMismatch();

/**
* @dev The new owner is not a valid owner (e.g., `address(0)` or the
* account itself).
*/
error InvalidOwner(address owner);

/**
* @dev The caller is not authorized.
*/
error NotAuthorized(address caller);

modifier onlyOwner() {
_onlyOwner();
_;
Expand Down Expand Up @@ -108,9 +126,15 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus
*/
function executeBatch(address[] calldata dest, bytes[] calldata func) external {
_requireFromEntryPointOrOwner();
require(dest.length == func.length, "wrong array lengths");
for (uint256 i = 0; i < dest.length; i++) {
if (dest.length != func.length) {
revert ArrayLengthMismatch();
}
uint256 length = dest.length;
for (uint256 i = 0; i < length;) {
_call(dest[i], 0, func[i]);
unchecked {
++i;
}
}
}

Expand All @@ -124,12 +148,31 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus
*/
function executeBatch(address[] calldata dest, uint256[] calldata value, bytes[] calldata func) external {
_requireFromEntryPointOrOwner();
require(dest.length == func.length && dest.length == value.length, "wrong array lengths");
for (uint256 i = 0; i < dest.length; i++) {
if (dest.length != func.length || dest.length != value.length) {
revert ArrayLengthMismatch();
}
uint256 length = dest.length;
for (uint256 i = 0; i < length;) {
_call(dest[i], value[i], func[i]);
unchecked {
++i;
}
}
}

/**
* @notice Transfers ownership of the contract to a new account (`newOwner`).
* Can only be called by the current owner or from the entry point via a
* user operation signed by the current owner.
* @param newOwner The new owner
*/
function transferOwnership(address newOwner) external virtual onlyOwner {
if (newOwner == address(0) || newOwner == address(this)) {
revert InvalidOwner(newOwner);
}
_transferOwnership(newOwner);
}

/**
* @notice Called once as part of initialization, either during initial deployment or when first upgrading to
* this contract.
Expand Down Expand Up @@ -158,18 +201,6 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus
entryPoint().withdrawTo(withdrawAddress, amount);
}

/**
* @notice Transfers ownership of the contract to a new account (`newOwner`).
* Can only be called by the current owner or from the entry point via a
* user operation signed by the current owner.
* @param newOwner The new owner
*/
function transferOwnership(address newOwner) public virtual onlyOwner {
require(newOwner != address(0), "account: new owner is the zero address");
require(newOwner != address(this), "account: new owner is self");
_transferOwnership(newOwner);
}

/// @inheritdoc BaseAccount
function entryPoint() public view virtual override returns (IEntryPoint) {
return _entryPoint;
Expand Down Expand Up @@ -252,12 +283,16 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus

function _onlyOwner() internal view {
//directly from EOA owner, or through the account itself (which gets redirected through execute())
require(msg.sender == owner() || msg.sender == address(this), "only owner");
if (msg.sender != address(this) && msg.sender != owner()) {
revert NotAuthorized(msg.sender);
}
}

// Require the function call went through EntryPoint or owner
function _requireFromEntryPointOrOwner() internal view {
require(msg.sender == address(entryPoint()) || msg.sender == owner(), "account: not Owner or EntryPoint");
if (msg.sender != address(entryPoint()) && msg.sender != owner()) {
revert NotAuthorized(msg.sender);
}
}

function _call(address target, uint256 value, bytes memory data) internal {
Expand Down
8 changes: 4 additions & 4 deletions test/CustomSlotInitializable.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@ contract CustomSlotInitializableTest is Test {
}

function testCannotReinitialize() public {
vm.expectRevert(bytes("Initializable: contract is already initialized"));
vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector);
v1Proxy.upgradeToAndCall(v1Impl, abi.encodeCall(V1.initialize, ()));
}

function testCannotUpgradeBackwards() public {
v1Proxy.upgradeToAndCall(v2Impl, abi.encodeCall(V2.initialize, ()));
V2 v2Proxy = V2(address(v1Proxy));
vm.expectRevert(bytes("Initializable: contract is already initialized"));
vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector);
v2Proxy.upgradeToAndCall(v1Impl, abi.encodeCall(V1.initialize, ()));
}

function testDisableInitializers() public {
v1Proxy.disableInitializers();
vm.expectRevert(bytes("Initializable: contract is already initialized"));
vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector);
v1Proxy.upgradeToAndCall(v2Impl, abi.encodeCall(V2.initialize, ()));
}

function testCannotCallDisableInitializersInInitializer() public {
DisablesInitializersWhileInitializing account = new DisablesInitializersWhileInitializing();
vm.expectRevert("Initializable: contract is initializing");
vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector);
account.initialize();
}

Expand Down
16 changes: 8 additions & 8 deletions test/LightAccount.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ contract LightAccountTest is Test {
}

function testExecuteCannotBeCalledByRandos() public {
vm.expectRevert(bytes("account: not Owner or EntryPoint"));
vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this))));
account.execute(address(lightSwitch), 0, abi.encodeCall(LightSwitch.turnOn, ()));
}

Expand Down Expand Up @@ -110,7 +110,7 @@ contract LightAccountTest is Test {
dest[1] = address(lightSwitch);
bytes[] memory func = new bytes[](1);
func[0] = abi.encodeCall(LightSwitch.turnOn, ());
vm.expectRevert(bytes("wrong array lengths"));
vm.expectRevert(LightAccount.ArrayLengthMismatch.selector);
account.executeBatch(dest, func);
}

Expand All @@ -136,7 +136,7 @@ contract LightAccountTest is Test {
value[1] = uint256(1 ether);
bytes[] memory func = new bytes[](1);
func[0] = abi.encodeCall(LightSwitch.turnOn, ());
vm.expectRevert(bytes("wrong array lengths"));
vm.expectRevert(LightAccount.ArrayLengthMismatch.selector);
account.executeBatch(dest, value, func);
}

Expand All @@ -163,7 +163,7 @@ contract LightAccountTest is Test {

function testWithdrawDepositToCannotBeCalledByRandos() public {
account.addDeposit{value: 10}();
vm.expectRevert(bytes("only owner"));
vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this))));
account.withdrawDepositTo(BENEFICIARY, 5);
}

Expand All @@ -189,19 +189,19 @@ contract LightAccountTest is Test {
}

function testRandosCannotTransferOwnership() public {
vm.expectRevert(bytes("only owner"));
vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this))));
account.transferOwnership(address(0x100));
}

function testCannotTransferOwnershipToZero() public {
vm.prank(eoaAddress);
vm.expectRevert(bytes("account: new owner is the zero address"));
vm.expectRevert(abi.encodeWithSelector(LightAccount.InvalidOwner.selector, (address(0))));
account.transferOwnership(address(0));
}

function testCannotTransferOwnershipToLightContractItself() public {
vm.prank(eoaAddress);
vm.expectRevert(bytes("account: new owner is self"));
vm.expectRevert(abi.encodeWithSelector(LightAccount.InvalidOwner.selector, (address(account))));
account.transferOwnership(address(account));
}

Expand Down Expand Up @@ -244,7 +244,7 @@ contract LightAccountTest is Test {
// Try to upgrade to a normal SimpleAccount with a different entry point.
IEntryPoint newEntryPoint = IEntryPoint(address(0x2000));
SimpleAccount newImplementation = new SimpleAccount(newEntryPoint);
vm.expectRevert(bytes("only owner"));
vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this))));
account.upgradeToAndCall(address(newImplementation), abi.encodeCall(SimpleAccount.initialize, (address(this))));
}

Expand Down

0 comments on commit 581f73a

Please sign in to comment.