Skip to content

Commit

Permalink
Registry checks for hooks (#67)
Browse files Browse the repository at this point in the history
* [#46] Implement Guard interface in SafeProtocolManager

* [#47] Create FunctionHandlerManager.sol and inherit in SafeProtocolManager

* [#47] Create BaseManager contract, rename modifier, rename error, check registry while adding function handler

* [#47] Update natspec doc

* [#46] Setup Safe

* [#46] Fix EOF

* [#46] User setupTest function

* [#46] Fix lint issue

* [#46] Add tests

* [#46] Add test with delegateCall for hooks flow

* [#46] User temporary variable for storing hooks address

* [#47] Implement logic for non-static calls to function handler manager, test to set function handler

* [#47] Add tests for Function Handler

* [#47] Pass sender address in handle function

* [#47] Fix test

* [#47] Use ZeroAddress from ethers

* [#46] Reset tempHooksAddress

* [#47] Test static call to function handler

* [#47] Fix lint issue

* [#47] Fix typo

* [#46] Refactor tests for SafeProtocolManager as Guard

* [#46] Fix failing test

* [#46] Update comment

* [#47] Update tests for Function Handler

* [#47] Update tests for function handler

* [#47] Remove test function handler from .solcover.js

* [#47] Return data from handle function

* [#47] Verify call data passed to handle(...)

* [#47] Update doc string

* [#47] Check if function handler is whitelisted

* [#47] Make fallback function non-payable, optimize codesize

* [#47] Fix lint issue

* [#53] Registry checks for hooks

* [#53] Remove chained assignments

* [#53] Fix failing tests

* [#53] Remove chained assignments

* [#53] Update comment on using temp hooks address

* [#53] Update data sent as parameter to pre-check hook in checkModuleTransaction, update natspec
  • Loading branch information
akshay-ap authored Aug 22, 2023
1 parent ea7f2c3 commit c8c494d
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 25 deletions.
33 changes: 23 additions & 10 deletions contracts/SafeProtocolManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
bytes memory signatures,
address msgSender
) external {
// Store hooks address in tempHooksAddress so that checkAfterExecution(...) and checkModuleTransaction(...) can access it.
address tempHooksAddressForSafe = tempHooksAddress[msg.sender] = enabledHooks[msg.sender];
// Store hooks address in tempHooksAddress so that checkAfterExecution(...) can access it.
// A temprary storage is required to use old hooks in checkAfterExecution if hooks get updated in between transaction
tempHooksAddress[msg.sender] = enabledHooks[msg.sender];
address tempHooksAddressForSafe = enabledHooks[msg.sender];

if (tempHooksAddressForSafe == address(0)) return;
bytes memory executionMetadata = abi.encode(
Expand Down Expand Up @@ -372,34 +374,45 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
tempHooksAddress[msg.sender] = address(0);
}

/**
* @notice This function is introduced in Safe contracts v1.5 and used for checking module transactions when a guard is enabled.
* This function will be called when executing a transaction from a module with Safe v1.5 and Manager enabled as Guard on Safe.
* @param to The address to which the transaction is intended.
* @param value The value of the transaction in Wei.
* @param data The transaction data.
* @param operation The type of operation of the transaction.
* @param module The module involved in the transaction.
* @return moduleTxHash The hash of the module transaction.
*/
function checkModuleTransaction(
address to,
uint256 value,
bytes memory data,
Enum.Operation operation,
address module /* onlyPermittedPlugin(module) uncomment this? */ // Use term plugin?
) external returns (bytes32 moduleTxHash) {
// Store hooks address in tempHooksAddress so that checkAfterExecution(...) and checkModuleTransaction(...) can access it.
address tempHooksAddressForSafe = tempHooksAddress[msg.sender] = enabledHooks[msg.sender];

bytes memory executionMetadata = abi.encode(to, value, data, operation, module);
// Store hooks address in tempHooksAddress so that checkAfterExecution(...) can access it.
// A temprary storage is required to use old hooks in checkAfterExecution if hooks get updated in between transaction
tempHooksAddress[msg.sender] = enabledHooks[msg.sender];
address tempHooksAddressForSafe = enabledHooks[msg.sender];

if (tempHooksAddressForSafe == address(0)) return keccak256(executionMetadata);
moduleTxHash = keccak256(abi.encode(to, value, data, operation, module));
if (tempHooksAddressForSafe == address(0)) return moduleTxHash;

if (operation == Enum.Operation.Call) {
SafeProtocolAction[] memory actions = new SafeProtocolAction[](1);
actions[0] = SafeProtocolAction(payable(to), value, data);
SafeTransaction memory safeTx = SafeTransaction(actions, 0, "");
ISafeProtocolHooks(tempHooksAddressForSafe).preCheck(ISafe(msg.sender), safeTx, 0, executionMetadata);
ISafeProtocolHooks(tempHooksAddressForSafe).preCheck(ISafe(msg.sender), safeTx, 1, abi.encode(module));
} else {
// Using else instead of "else if(operation == Enum.Operation.DelegateCall)" to reduce gas usage
// and Safe allows only Call and DelegateCall operations.
SafeProtocolAction memory action = SafeProtocolAction(payable(to), value, data);
SafeRootAccess memory safeTx = SafeRootAccess(action, 0, "");
ISafeProtocolHooks(tempHooksAddressForSafe).preCheckRootAccess(ISafe(msg.sender), safeTx, 0, executionMetadata);
ISafeProtocolHooks(tempHooksAddressForSafe).preCheckRootAccess(ISafe(msg.sender), safeTx, 1, abi.encode(module));
}

return keccak256(executionMetadata);
return moduleTxHash;
}

function supportsInterface(bytes4 interfaceId) external view virtual override returns (bool) {
Expand Down
13 changes: 7 additions & 6 deletions contracts/base/HooksManager.sol
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity ^0.8.18;
import {ISafeProtocolHooks} from "../interfaces/Integrations.sol";

import {RegistryManager} from "./RegistryManager.sol";
import {OnlyAccountCallable} from "./OnlyAccountCallable.sol";

contract HooksManager is OnlyAccountCallable {
abstract contract HooksManager is RegistryManager, OnlyAccountCallable {
mapping(address => address) public enabledHooks;

/// @notice This variable should store the address of the hooks contract whenever
Expand All @@ -14,9 +16,6 @@ contract HooksManager is OnlyAccountCallable {
// Events
event HooksChanged(address indexed safe, address indexed hooksAddress);

// Errors
error AddressDoesNotImplementHooksInterface(address hooksAddress);

/**
* @notice Returns the address of hooks for a Safe account provided as a fucntion parameter.
* Returns address(0) is no hooks are enabled.
Expand All @@ -32,8 +31,10 @@ contract HooksManager is OnlyAccountCallable {
* @param hooks Address of the hooks to be enabled for msg.sender.
*/
function setHooks(address hooks) external onlyAccount {
if (hooks != address(0) && !ISafeProtocolHooks(hooks).supportsInterface(type(ISafeProtocolHooks).interfaceId)) {
revert AddressDoesNotImplementHooksInterface(hooks);
if (hooks != address(0)) {
checkPermittedIntegration(hooks);
if (!ISafeProtocolHooks(hooks).supportsInterface(type(ISafeProtocolHooks).interfaceId))
revert AccountDoesNotImplementValidInterfaceId(hooks);
}
enabledHooks[msg.sender] = hooks;
emit HooksChanged(msg.sender, hooks);
Expand Down
1 change: 0 additions & 1 deletion contracts/test/TestExecutor.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity ^0.8.18;
import {ISafe} from "../interfaces/Accounts.sol";
import {MockContract} from "@safe-global/mock-contract/contracts/MockContract.sol";

contract TestExecutor is ISafe {
address public module;
Expand Down
64 changes: 61 additions & 3 deletions test/SafeProtocolManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { buildRootTx, buildSingleTx } from "./utils/builder";
import { getHooksWithFailingPrechecks, getHooksWithPassingChecks, getHooksWithFailingPostCheck } from "./utils/mockHooksBuilder";
import { IntegrationType } from "./utils/constants";
import { getInstance } from "./utils/contracts";
import { SafeProtocolManager } from "../typechain-types";
import { MockContract, SafeProtocolManager } from "../typechain-types";

describe("SafeProtocolManager", async () => {
let deployer: SignerWithAddress, owner: SignerWithAddress, user1: SignerWithAddress, user2: SignerWithAddress;
Expand Down Expand Up @@ -475,6 +475,7 @@ describe("SafeProtocolManager", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
// Enable hooks on a safe
const hooks = await getHooksWithPassingChecks();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);
const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]);
await safe.exec(safe.target, 0, dataSetHooks);

Expand Down Expand Up @@ -509,6 +510,7 @@ describe("SafeProtocolManager", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
// Enable hooks on a safe
const hooks = await getHooksWithFailingPrechecks();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);

const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]);
await safe.exec(safe.target, 0, dataSetHooks);
Expand All @@ -529,6 +531,7 @@ describe("SafeProtocolManager", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
// Enable hooks on a safe
const hooks = await getHooksWithFailingPostCheck();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);

const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]);
await safe.exec(safe.target, 0, dataSetHooks);
Expand Down Expand Up @@ -671,8 +674,11 @@ describe("SafeProtocolManager", async () => {
it("Should execute a transaction from root access enabled plugin with hooks enabled", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
const safeAddress = await safe.getAddress();

// Enable hooks on a safe
const hooks = await getHooksWithPassingChecks();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);

const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]);
await safe.exec(safe.target, 0, dataSetHooks);

Expand Down Expand Up @@ -716,6 +722,7 @@ describe("SafeProtocolManager", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
// Enable hooks on a safe
const hooks = await getHooksWithFailingPrechecks();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);

const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]);
await safe.exec(safe.target, 0, dataSetHooks);
Expand Down Expand Up @@ -748,6 +755,8 @@ describe("SafeProtocolManager", async () => {

// Enable hooks on a safe
const hooks = await getHooksWithFailingPostCheck();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);

const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]);
await safe.exec(safe.target, 0, dataSetHooks);

Expand Down Expand Up @@ -928,6 +937,7 @@ describe("SafeProtocolManager", async () => {

await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);
await safeProtocolRegistry.connect(owner).addIntegration(hooksWithFailingPreChecks.target, IntegrationType.Hooks);
await safeProtocolRegistry.connect(owner).addIntegration(hooksWithFailingPostCheck.target, IntegrationType.Hooks);

return { safe, safeProtocolManager, hooks, hooksWithFailingPreChecks, hooksWithFailingPostCheck };
});
Expand Down Expand Up @@ -1062,6 +1072,15 @@ describe("SafeProtocolManager", async () => {
]);

expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256));

// Check if temporary hooks related storage is cleared after tx
expect(await safeProtocolManager.tempHooksAddress.staticCall(safe.target)).to.deep.equal(ZeroAddress);

const mockHooks = await getInstance<MockContract>("MockContract", hooks.target);
// Pre-check hooks calls
expect(await mockHooks.invocationCountForMethod("0x176ae7b7")).to.equal(1);
const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]);
expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1);
});

it("Should pass hooks checks for module transaction with call operation", async () => {
Expand All @@ -1084,8 +1103,25 @@ describe("SafeProtocolManager", async () => {
hre.ethers.randomBytes(32),
true,
]);

expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256));

// Check if temporary hooks related storage is cleared after tx
expect(await safeProtocolManager.tempHooksAddress.staticCall(safe.target)).to.deep.equal(ZeroAddress);

const mockHooks = await getInstance<MockContract>("MockContract", hooks.target);
// preCheck hooks calls
const safeTx = buildSingleTx(user2.address, 0n, "0x", 0n, hre.ethers.ZeroHash);
const preCheckCalldata = hooks.interface.encodeFunctionData("preCheck", [
safe.target,
safeTx,
1,
hre.ethers.AbiCoder.defaultAbiCoder().encode(["address"], [ZeroAddress]),
]);
expect(await mockHooks.invocationCountForMethod("0x176ae7b7")).to.equal(1);
expect(await mockHooks.invocationCountForCalldata(preCheckCalldata)).to.equal(1);
const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]);

expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1);
});

it("Should pass hooks checks for module transaction with delegateCall operation", async () => {
Expand All @@ -1110,9 +1146,25 @@ describe("SafeProtocolManager", async () => {
]);

expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256));

const mockHooks = await getInstance<MockContract>("MockContract", hooks.target);
// preCheck hooks calls
const safeTx = buildRootTx(user2.address, 0n, "0x", 0n, hre.ethers.ZeroHash);
const preCheckCalldata = hooks.interface.encodeFunctionData("preCheckRootAccess", [
safe.target,
safeTx,
1,
hre.ethers.AbiCoder.defaultAbiCoder().encode(["address"], [ZeroAddress]),
]);
// 0x7359b742 -> preCheckRootAccess function signature
expect(await mockHooks.invocationCountForMethod("0x7359b742")).to.equal(1);
expect(await mockHooks.invocationCountForCalldata(preCheckCalldata)).to.equal(1);
const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]);

expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1);
});

it("Should execute pass hooks checks for delegateCall operation", async () => {
it("Should pass hooks checks for delegateCall operation", async () => {
const { safe, safeProtocolManager, hooks } = await setupTests();
// Set Hooks contract for the Safe
const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [hooks.target]);
Expand All @@ -1139,6 +1191,12 @@ describe("SafeProtocolManager", async () => {
]);

expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256));

const mockHooks = await getInstance<MockContract>("MockContract", hooks.target);
// preCheckRootAccess hooks calls
expect(await mockHooks.invocationCountForMethod("0x7359b742")).to.equal(1);
const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]);
expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1);
});

it("uses old hooks in checkAfterExecution if hooks get updated in between transactions", async () => {
Expand Down
14 changes: 9 additions & 5 deletions test/base/HooksManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { expect } from "chai";
import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers";
import { getHooksWithPassingChecks } from "../utils/mockHooksBuilder";
import { ZeroAddress } from "ethers";
import { IntegrationType } from "../utils/constants";

describe("HooksManager", async () => {
let deployer: SignerWithAddress, user1: SignerWithAddress, owner: SignerWithAddress;
Expand All @@ -21,8 +22,9 @@ describe("HooksManager", async () => {

const safe = await hre.ethers.deployContract("TestExecutor", [hooksManager.target], { signer: deployer });
const hooks = await getHooksWithPassingChecks();
await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks);

return { hooksManager, hooks, safe };
return { hooksManager, hooks, safe, safeProtocolRegistry };
});

it("Should emit HooksChanged event when hooks are enabled", async () => {
Expand Down Expand Up @@ -65,15 +67,17 @@ describe("HooksManager", async () => {
await expect(hooksManager.setHooks(hooksAddress)).to.be.reverted;
});

it("Should revert AddressDoesNotImplementHooksInterface if user attempts address does not implement Hooks interface", async () => {
const { hooksManager, safe } = await setupTests();
it("Should revert AccountDoesNotImplementValidInterfaceId if user attempts address does not implement Hooks interface", async () => {
const { hooksManager, safe, safeProtocolRegistry } = await setupTests();
const contractNotImplementingHooksInterface = await (await hre.ethers.getContractFactory("MockContract")).deploy();
await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", false);
await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", true);
await safeProtocolRegistry.connect(owner).addIntegration(contractNotImplementingHooksInterface.target, IntegrationType.Hooks);

await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", false);
const calldata = hooksManager.interface.encodeFunctionData("setHooks", [contractNotImplementingHooksInterface.target]);
await expect(safe.exec(safe.target, 0n, calldata)).to.be.revertedWithCustomError(
hooksManager,
"AddressDoesNotImplementHooksInterface",
"AccountDoesNotImplementValidInterfaceId",
);
});
});

0 comments on commit c8c494d

Please sign in to comment.