Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hint): widemul128 implementation #137

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cairo_programs/cairo/hints/wide_mul128.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn main() {
let _ = 0x123456_u128 * 0xFEDCBA_u128;
}
5 changes: 5 additions & 0 deletions src/hints/hintHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import {
shouldSkipSquashLoop,
} from './dict/shouldSkipSquashLoop';
import { TestLessThan, testLessThan } from './math/testLessThan';
import { WideMul128, wideMul128 } from './math/wideMul128';

/**
* Map hint names to the function executing their logic.
Expand Down Expand Up @@ -125,4 +126,8 @@ export const handlers: HintHandler = {
const h = hint as TestLessThan;
testLessThan(vm, h.lhs, h.rhs, h.dst);
},
[HintName.WideMul128]: (vm, hint) => {
const h = hint as WideMul128;
wideMul128(vm, h.lhs, h.rhs, h.high, h.low);
},
};
1 change: 1 addition & 0 deletions src/hints/hintName.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ export enum HintName {
ShouldContinueSquashLoop = 'ShouldContinueSquashLoop',
ShouldSkipSquashLoop = 'ShouldSkipSquashLoop',
TestLessThan = 'TestLessThan',
WideMul128 = 'WideMul128',
}
2 changes: 2 additions & 0 deletions src/hints/hintSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { initSquashDataParser } from './dict/initSquashData';
import { shouldContinueSquashLoopParser } from './dict/shouldContinueSquashLoop';
import { shouldSkipSquashLoopParser } from './dict/shouldSkipSquashLoop';
import { testLessThanParser } from './math/testLessThan';
import { wideMul128Parser } from './math/wideMul128';

/** Zod object to parse any implemented hints */
const hint = z.union([
Expand All @@ -33,6 +34,7 @@ const hint = z.union([
shouldContinueSquashLoopParser,
shouldSkipSquashLoopParser,
testLessThanParser,
wideMul128Parser,
]);

/** Zod object to parse an array of hints grouped on a given PC */
Expand Down
96 changes: 96 additions & 0 deletions src/hints/math/wideMul128.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { describe, expect, test } from 'bun:test';
import { Felt } from 'primitives/felt';
import { VirtualMachine } from 'vm/virtualMachine';
import { HintName } from 'hints/hintName';
import { OpType } from 'hints/hintParamsSchema';
import { wideMul128Parser } from './wideMul128';
import { wideMul128 } from './wideMul128';
import { Register } from 'vm/instruction';

const WIDE_MUL_128_HINT = {
WideMul128: {
lhs: {
Deref: {
register: 'AP',
offset: 0,
},
},
rhs: {
Deref: {
register: 'AP',
offset: 1,
},
},
high: {
register: 'AP',
offset: 2,
},
low: {
register: 'AP',
offset: 3,
},
},
};

describe('WideMul128', () => {
test('should properly parse WideMul128 hint', () => {
const hint = wideMul128Parser.parse(WIDE_MUL_128_HINT);
expect(hint).toEqual({
type: HintName.WideMul128,
lhs: {
type: OpType.Deref,
cell: {
register: Register.Ap,
offset: 0,
},
},
rhs: {
type: OpType.Deref,
cell: {
register: Register.Ap,
offset: 1,
},
},
high: {
register: Register.Ap,
offset: 2,
},
low: {
register: Register.Ap,
offset: 3,
},
});
});

test.each([
[new Felt(2n), new Felt(3n), new Felt(0n), new Felt(6n)],
[new Felt(1n << 64n), new Felt(1n << 64n), new Felt(1n), new Felt(0n)],
[
new Felt((1n << 63n) - 1n),
new Felt(2n),
new Felt(0n),
new Felt((1n << 64n) - 2n),
],
[new Felt(1n << 127n), new Felt(2n), new Felt(1n), new Felt(0n)],
])(
'should properly execute WideMul128 hint',
(lhsValue, rhsValue, highExpected, lowExpected) => {
const hint = wideMul128Parser.parse(WIDE_MUL_128_HINT);
const vm = new VirtualMachine();
vm.memory.addSegment();
vm.memory.addSegment();

vm.memory.assertEq(vm.ap, lhsValue);
vm.memory.assertEq(vm.ap.add(1), rhsValue);

wideMul128(vm, hint.lhs, hint.rhs, hint.high, hint.low);

expect(vm.memory.get(vm.cellRefToRelocatable(hint.high))).toEqual(
highExpected
);
expect(vm.memory.get(vm.cellRefToRelocatable(hint.low))).toEqual(
lowExpected
);
}
);
});
64 changes: 64 additions & 0 deletions src/hints/math/wideMul128.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { z } from 'zod';

import { VirtualMachine } from 'vm/virtualMachine';

import {
resOperand,
ResOperand,
cellRef,
CellRef,
} from 'hints/hintParamsSchema';
import { HintName } from 'hints/hintName';
import { Felt } from 'primitives/felt';

/** Zod object to parse WideMul128 hint */
export const wideMul128Parser = z
.object({
WideMul128: z.object({
lhs: resOperand,
rhs: resOperand,
high: cellRef,
low: cellRef,
}),
})
.transform(({ WideMul128: { lhs, rhs, high, low } }) => ({
type: HintName.WideMul128,
lhs,
rhs,
high,
low,
}));

/**
* WideMul128 hint type
*/
export type WideMul128 = z.infer<typeof wideMul128Parser>;

/**
* Perform 128-bit multiplication and store high and low parts.
*
* @param {VirtualMachine} vm - The virtual machine instance
* @param {ResOperand} lhs - The left-hand side operand
* @param {ResOperand} rhs - The right-hand side operand
* @param {CellRef} high - The address to store the high part of the product
* @param {CellRef} low - The address to store the low part of the product
*/
export const wideMul128 = (
vm: VirtualMachine,
lhs: ResOperand,
rhs: ResOperand,
high: CellRef,
low: CellRef
) => {
const lhsValue = vm.getResOperandValue(lhs).toBigInt();
const rhsValue = vm.getResOperandValue(rhs).toBigInt();

const product = lhsValue * rhsValue;
const mask = (1n << 128n) - 1n;

const highValue = new Felt(product >> 128n);
const lowValue = new Felt(product & mask);

vm.memory.assertEq(vm.cellRefToRelocatable(high), highValue);
vm.memory.assertEq(vm.cellRefToRelocatable(low), lowValue);
};