Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement message size validation to prevent excessive payloads #1197

Merged
merged 7 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/sdk/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ export const EXTENSION_EVENTS = {
CONNECT: 'connect',
CONNECTED: 'connected',
};

export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit
abretonc7s marked this conversation as resolved.
Show resolved Hide resolved
47 changes: 47 additions & 0 deletions packages/sdk/src/services/MobilePortStream/write.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Buffer } from 'buffer';
import { MAX_MESSAGE_LENGTH } from '../../config';
import { write } from './write';

describe('write function', () => {
Expand Down Expand Up @@ -77,4 +78,50 @@ describe('write function', () => {
new Error('MobilePortStream - disconnected'),
);
});

describe('Message Size Validation', () => {
beforeEach(() => {
jest.clearAllMocks();
global.window = {
location: { href: 'http://example.com' },
ReactNativeWebView: { postMessage: mockPostMessage },
} as any;
});

it('should reject messages exceeding MAX_MESSAGE_LENGTH', () => {
const largeData = {
data: {
jsonrpc: '2.0',
method: 'test_method',
params: ['x'.repeat(MAX_MESSAGE_LENGTH)],
},
};

write(largeData, 'utf-8', cb);

expect(cb).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringMatching(
/Message size \d+ exceeds maximum allowed size of \d+ bytes/u,
),
}),
);
expect(mockPostMessage).not.toHaveBeenCalled();
});

it('should accept messages within MAX_MESSAGE_LENGTH', () => {
const validData = {
data: {
jsonrpc: '2.0',
method: 'test_method',
params: ['x'.repeat(100)],
},
};

write(validData, 'utf-8', cb);

expect(cb).toHaveBeenCalledWith();
expect(mockPostMessage).toHaveBeenCalled();
});
});
});
24 changes: 19 additions & 5 deletions packages/sdk/src/services/MobilePortStream/write.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Buffer } from 'buffer';
import { MAX_MESSAGE_LENGTH } from '../../config';

/**
* Handles communication between the in-app browser and MetaMask mobile application.
Expand All @@ -15,6 +16,7 @@ export function write(
cb: (error?: Error | null) => void,
) {
try {
let stringifiedData: string;
if (Buffer.isBuffer(chunk)) {
const data: {
type: 'Buffer';
Expand All @@ -23,18 +25,30 @@ export function write(
} = chunk.toJSON();

data._isBuffer = true;
window.ReactNativeWebView?.postMessage(
JSON.stringify({ ...data, origin: window.location.href }),
);
stringifiedData = JSON.stringify({
...data,
origin: window.location.href,
});
} else {
if (chunk.data) {
chunk.data.toNative = true;
}

window.ReactNativeWebView?.postMessage(
JSON.stringify({ ...chunk, origin: window.location.href }),
stringifiedData = JSON.stringify({
...chunk,
origin: window.location.href,
});
}

if (stringifiedData.length > MAX_MESSAGE_LENGTH) {
return cb(
new Error(
`Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`,
),
);
}

window.ReactNativeWebView?.postMessage(stringifiedData);
} catch (err) {
return cb(new Error('MobilePortStream - disconnected'));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure
import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream'; // Adjust the import based on your project structure
import { METHODS_TO_REDIRECT } from '../../config';
import { MAX_MESSAGE_LENGTH, METHODS_TO_REDIRECT } from '../../config';
import * as loggerModule from '../../utils/logger'; // Adjust the import based on your project structure
import { write } from './write'; // Adjust the import based on your project structure
import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure
import { extractMethod } from './extractMethod';
import { write } from './write'; // Adjust the import based on your project structure

jest.mock('./extractMethod');
jest.mock('../Ethereum');
Expand Down Expand Up @@ -162,11 +162,22 @@ describe('write function', () => {
mockIsMobileWeb.mockReturnValue(false);
mockIsSecure.mockReturnValue(true);
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsMetaMaskInstalled.mockReturnValue(true);
mockGetKeyInfo.mockReturnValue({ ecies: { public: 'test_public_key' } });
mockHasDeeplinkProtocol.mockReturnValue(false);
});

it('should redirect if method exists in METHODS_TO_REDIRECT', async () => {
mockExtractMethod.mockReturnValue({
method: Object.keys(METHODS_TO_REDIRECT)[0],
data: {
data: {
jsonrpc: '2.0',
method: Object.keys(METHODS_TO_REDIRECT)[0],
params: [],
},
},
triggeredInstaller: false,
});

await write(
Expand Down Expand Up @@ -239,4 +250,71 @@ describe('write function', () => {
expect(spyLogger).toHaveBeenCalled();
});
});

describe('Message Size Validation', () => {
it('should reject messages exceeding MAX_MESSAGE_LENGTH', async () => {
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsReady.mockReturnValue(true);
mockIsConnected.mockReturnValue(true);

// Mock extractMethod to return large data
const largeData = {
jsonrpc: '2.0',
method: 'eth_call',
params: ['x'.repeat(MAX_MESSAGE_LENGTH + 1)],
};

mockExtractMethod.mockReturnValue({
method: 'eth_call',
data: {
data: largeData,
},
});

await write(
mockRemoteCommunicationPostMessageStream,
{ jsonrpc: '2.0', method: 'eth_call' },
'utf8',
callback,
);

// Don't test for exact error message, just verify it contains the key parts
expect(callback).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringMatching(
/Message size \d+ exceeds maximum allowed size of \d+ bytes/u,
),
}),
);
expect(mockSendMessage).not.toHaveBeenCalled();
});

it('should accept messages within MAX_MESSAGE_LENGTH', async () => {
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsReady.mockReturnValue(true);
mockIsConnected.mockReturnValue(true);

// Mock extractMethod to return valid-sized data
mockExtractMethod.mockReturnValue({
method: 'eth_call',
data: {
data: {
jsonrpc: '2.0',
method: 'eth_call',
params: ['x'.repeat(100)],
},
},
});

await write(
mockRemoteCommunicationPostMessageStream,
{ jsonrpc: '2.0', method: 'eth_call' },
'utf8',
callback,
);

expect(callback).toHaveBeenCalledWith();
expect(mockSendMessage).toHaveBeenCalled();
});
});
});
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream';
import { METHODS_TO_REDIRECT, RPC_METHODS } from '../../config';
import {
METHODS_TO_REDIRECT,
RPC_METHODS,
MAX_MESSAGE_LENGTH,
} from '../../config';
import {
METAMASK_CONNECT_BASE_URL,
METAMASK_DEEPLINK_BASE,
Expand Down Expand Up @@ -57,11 +61,17 @@ export async function write(
deeplinkProtocolAvailable && mobileWeb && authorized;

try {
console.warn(
`[RCPMS: _write()] triggeredInstaller=${triggeredInstaller} activeDeeplinkProtocol=${activeDeeplinkProtocol}`,
);

if (!triggeredInstaller) {
// Check message size before sending
const stringifiedData = JSON.stringify(data?.data);
if (stringifiedData.length > MAX_MESSAGE_LENGTH) {
return callback(
new Error(
`Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`,
),
);
}

// The only reason not to send via network is because the rpc call will be sent in the deeplink
instance.state.remote
?.sendMessage(data?.data)
Expand Down
Loading