diff --git a/script/gemforge/utils/index.js b/script/gemforge/utils/index.js index 07dd8e73..4706b622 100644 --- a/script/gemforge/utils/index.js +++ b/script/gemforge/utils/index.js @@ -32,10 +32,16 @@ const getProxyAddress = (exports.getProxyAddress = (targetId) => { exports.calculateUpgradeId = async (cutFile) => { const cutData = require(cutFile); - const encodedData = ethers.utils.defaultAbiCoder.encode( - ["tuple(address facetAddress, uint8 action, bytes4[] functionSelectors)[]", "address", "bytes"], - [cutData.cuts, cutData.initContractAddress, cutData.initData] + + const codeHashes = await Promise.all( + cutData.cuts.map(async (cut) => { + const code = await provider.getCode(cut.facetAddress); + return ethers.utils.keccak256(code); + }) ); + + const encodedData = ethers.utils.defaultAbiCoder.encode(["bytes32[]", "address", "bytes"], [codeHashes, cutData.initContractAddress, cutData.initData]); + return ethers.utils.keccak256(encodedData); }; diff --git a/src/facets/GovernanceFacet.sol b/src/facets/GovernanceFacet.sol index 1787b138..89ee0eba 100644 --- a/src/facets/GovernanceFacet.sol +++ b/src/facets/GovernanceFacet.sol @@ -25,12 +25,12 @@ contract GovernanceFacet is Modifiers { /** * @notice Calcuate upgrade hash: `id` * @dev calucate the upgrade hash by hashing all the inputs - * @param _diamondCut the array of FacetCut struct, IDiamondCut.FacetCut[] to be used for upgrade + * @param _codeHashes the array of contract bytecode hashes * @param _init address of the init diamond to be used for upgrade * @param _calldata bytes to be passed as call data for upgrade */ - function calculateUpgradeId(IDiamondCut.FacetCut[] calldata _diamondCut, address _init, bytes calldata _calldata) external pure returns (bytes32) { - return LibGovernance._calculateUpgradeId(_diamondCut, _init, _calldata); + function calculateUpgradeId(bytes32[] calldata _codeHashes, address _init, bytes calldata _calldata) external pure returns (bytes32) { + return LibGovernance._calculateUpgradeId(_codeHashes, _init, _calldata); } /** diff --git a/src/facets/PhasedDiamondCutFacet.sol b/src/facets/PhasedDiamondCutFacet.sol index be1f544c..785e9e61 100644 --- a/src/facets/PhasedDiamondCutFacet.sol +++ b/src/facets/PhasedDiamondCutFacet.sol @@ -24,7 +24,14 @@ contract PhasedDiamondCutFacet is IDiamondCut { { AppStorage storage s = LibAppStorage.diamondStorage(); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(_diamondCut, _init, _calldata); + bytes32[] memory codeHashes = new bytes32[](_diamondCut.length); + for (uint256 i; i < _diamondCut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(_diamondCut[i].facetAddress); + } + + // Calculate upgradeId (hash of codeHashes, _init, _calldata + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, _init, _calldata); + if (s.upgradeScheduled[upgradeId] < block.timestamp) { revert PhasedDiamondCutUpgradeFailed(upgradeId, block.timestamp); } diff --git a/src/libs/LibGovernance.sol b/src/libs/LibGovernance.sol index 92f699ab..c6bb2075 100644 --- a/src/libs/LibGovernance.sol +++ b/src/libs/LibGovernance.sol @@ -1,11 +1,15 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.20; -import { IDiamondCut } from "lib/diamond-2-hardhat/contracts/interfaces/IDiamondCut.sol"; - /// @notice Contains internal methods for upgrade functionality library LibGovernance { - function _calculateUpgradeId(IDiamondCut.FacetCut[] memory _diamondCut, address _init, bytes memory _calldata) internal pure returns (bytes32) { - return keccak256(abi.encode(_diamondCut, _init, _calldata)); + function _calculateUpgradeId(bytes32[] memory _codeHashes, address _init, bytes memory _calldata) internal pure returns (bytes32) { + return keccak256(abi.encode(_codeHashes, _init, _calldata)); + } + + function _getCodeHash(address contractAddress) internal view returns (bytes32 codehash_) { + assembly { + codehash_ := extcodehash(contractAddress) + } } } diff --git a/test/T01Deployment.t.sol b/test/T01Deployment.t.sol index 915d2f94..4a131b3e 100644 --- a/test/T01Deployment.t.sol +++ b/test/T01Deployment.t.sol @@ -60,8 +60,8 @@ contract T01DeploymentTest is D03ProtocolDefaults { function testCallInitDiamondTwice() public skipWhenForking { // note: Cannot use the InitDiamond contract more than once to initialize a diamond. IDiamondCut.FacetCut[] memory cut; - - bytes32 upgradeHash = LibGovernance._calculateUpgradeId(cut, address(initDiamond), abi.encodeCall(initDiamond.init, (systemAdmin))); + bytes32[] memory codeHashes = new bytes32[](cut.length); + bytes32 upgradeHash = LibGovernance._calculateUpgradeId(codeHashes, address(initDiamond), abi.encodeCall(initDiamond.init, (systemAdmin))); changePrank(systemAdmin); nayms.createUpgrade(upgradeHash); diff --git a/test/T01GovernanceUpgrades.t.sol b/test/T01GovernanceUpgrades.t.sol index 3827060a..81d68f2f 100644 --- a/test/T01GovernanceUpgrades.t.sol +++ b/test/T01GovernanceUpgrades.t.sol @@ -17,7 +17,8 @@ contract TestFacet { function sayHello() external pure returns (string memory greeting) { greeting = "hello"; } - +} +contract TestFacet2 { function sayHello2() external pure returns (string memory greeting) { greeting = "hello2"; } @@ -28,6 +29,7 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { uint256 public constant STARTING_BLOCK_TIMESTAMP = 100; address public testFacetAddress; + address public testFacet2Address; function setUp() public { // note: The diamond starts with the PhasedDiamondCutFacet insteaad of the original DiamondCutFacet @@ -35,6 +37,7 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { IDiamondCut.FacetCut[] memory cut = new IDiamondCut.FacetCut[](1); testFacetAddress = address(new TestFacet()); + testFacet2Address = address(new TestFacet2()); bytes4[] memory f0 = new bytes4[](1); f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); @@ -50,8 +53,12 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } // try to call diamondCut() without scheduling - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); vm.expectRevert(abi.encodeWithSelector(PhasedDiamondCutUpgradeFailed.selector, upgradeId, block.timestamp)); nayms.diamondCut(cut, address(0), ""); } @@ -62,10 +69,15 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + vm.warp(7 days + STARTING_BLOCK_TIMESTAMP + 1); // try to call diamondCut() without scheduling - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); vm.expectRevert(abi.encodeWithSelector(PhasedDiamondCutUpgradeFailed.selector, upgradeId, block.timestamp)); nayms.diamondCut(cut, address(0), ""); } @@ -76,7 +88,12 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId); changePrank(owner); @@ -94,7 +111,12 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId); changePrank(address(0xAAAAAAAAA)); @@ -108,7 +130,12 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); vm.expectRevert("invalid upgrade ID"); nayms.cancelUpgrade(upgradeId); @@ -127,7 +154,12 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId); vm.expectRevert("Upgrade has already been scheduled"); @@ -145,16 +177,26 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId); // cut in the method sayHello2() IDiamondCut.FacetCut[] memory cut2 = new IDiamondCut.FacetCut[](1); bytes4[] memory f1 = new bytes4[](1); - f1[0] = TestFacet.sayHello2.selector; - cut2[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f1 }); + f1[0] = TestFacet2.sayHello2.selector; + cut2[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacet2Address), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f1 }); + + codeHashes = new bytes32[](cut2.length); + for (uint i; i < cut2.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut2[i].facetAddress); + } - bytes32 upgradeId2 = LibGovernance._calculateUpgradeId(cut2, address(0), ""); + bytes32 upgradeId2 = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId2); changePrank(owner); @@ -170,7 +212,12 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { f0[0] = TestFacet.sayHello.selector; cut[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f0 }); - bytes32 upgradeId = LibGovernance._calculateUpgradeId(cut, address(0), ""); + bytes32[] memory codeHashes = new bytes32[](cut.length); + for (uint i; i < cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut[i].facetAddress); + } + + bytes32 upgradeId = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId); changePrank(address(0xAAAAAAAAA)); @@ -190,10 +237,15 @@ contract T01GovernanceUpgrades is D03ProtocolDefaults, MockAccounts { IDiamondCut.FacetCut[] memory cut2 = new IDiamondCut.FacetCut[](1); bytes4[] memory f1 = new bytes4[](1); - f1[0] = TestFacet.sayHello2.selector; - cut2[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacetAddress), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f1 }); + f1[0] = TestFacet2.sayHello2.selector; + cut2[0] = IDiamondCut.FacetCut({ facetAddress: address(testFacet2Address), action: IDiamondCut.FacetCutAction.Add, functionSelectors: f1 }); + + codeHashes = new bytes32[](cut2.length); + for (uint i; i < cut2.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(cut2[i].facetAddress); + } - bytes32 upgradeId2 = LibGovernance._calculateUpgradeId(cut2, address(0), ""); + bytes32 upgradeId2 = LibGovernance._calculateUpgradeId(codeHashes, address(0), ""); nayms.createUpgrade(upgradeId2); assertEq(block.timestamp + 1 days, nayms.getUpgrade(upgradeId2)); } diff --git a/test/defaults/D01Deployment.sol b/test/defaults/D01Deployment.sol index 03f480ed..37c38ce9 100644 --- a/test/defaults/D01Deployment.sol +++ b/test/defaults/D01Deployment.sol @@ -136,15 +136,19 @@ abstract contract D01Deployment is D00GlobalDefaults, Test { } function scheduleAndUpgradeDiamond(IDiamondCut.FacetCut[] memory _cut, address _init, bytes memory _calldata) internal { - // 1. schedule upgrade - // 2. upgrade - bytes32 upgradeHash = LibGovernance._calculateUpgradeId(_cut, _init, _calldata); + bytes32[] memory codeHashes = new bytes32[](_cut.length); + for (uint i; i < _cut.length; i++) { + codeHashes[i] = LibGovernance._getCodeHash(_cut[i].facetAddress); + } + bytes32 upgradeHash = LibGovernance._calculateUpgradeId(codeHashes, _init, _calldata); if (upgradeHash == 0xc597f3eb22d11c46f626cd856bd65e9127b04623d83e442686776a2e3b670bbf) { c.log("There are no facets to upgrade. This hash is the keccak256 hash of an empty IDiamondCut.FacetCut[]"); } else { changePrank(systemAdmin); + // 1. schedule upgrade nayms.createUpgrade(upgradeHash); changePrank(owner); + // 2. upgrade nayms.diamondCut(_cut, _init, _calldata); changePrank(systemAdmin); }