Skip to content

Commit

Permalink
Support for Cairo builtins (#2718)
Browse files Browse the repository at this point in the history
This PR implements generic support for Cairo VM builtins. The calling
convention in the generated CASM code is changed to allow for passing
around the builtin pointers. Appropriate builtin initialization and
finalization code is added. Support for specific builtins (e.g. Poseidon
hash, range check, Elliptic Curve operation) still needs to be
implemented in separate PRs.

* Closes #2683
  • Loading branch information
lukaszcz authored Apr 16, 2024
1 parent 65176a3 commit ad76c7a
Show file tree
Hide file tree
Showing 15 changed files with 201 additions and 65 deletions.
2 changes: 1 addition & 1 deletion app/Commands/Dev/Casm/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ runCommand opts = do
runReader entryPoint
. runError @JuvixError
. casmToCairo
$ Casm.Result labi code
$ Casm.Result labi code []
res <- getRight r
liftIO $ JSON.encodeFile (toFilePath cairoFile) res
where
Expand Down
8 changes: 6 additions & 2 deletions runtime/src/casm/stdlib.casm
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ juvix_get_ap_reg:

-- [fp - 3]: closure
-- [fp - 4]: n = the number of arguments to extend with
-- [fp - 4 - k]: argument n - k - 1 (reverse order!)
-- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based)
juvix_extend_closure:
-- copy stored args reversing them;
-- to copy the stored args to the new closure
Expand Down Expand Up @@ -95,10 +95,14 @@ juvix_extend_closure:
[ap] = [fp + 15]; ap++
ret

-- [fp - 3]: closure; [fp - 3 - k]: argument k to closure call
-- [fp - 3]: closure;
-- [fp - 4 - k]: argument k to closure call (0-based)
-- [fp - 4 - n]: builtin pointer, where n = number of supplied args
juvix_call_closure:
-- jmp rel (9 - argsnum)
jmp rel [[fp - 3] + 2]
-- builtin ptr + args
[ap] = [fp - 12]; ap++
[ap] = [fp - 11]; ap++
[ap] = [fp - 10]; ap++
[ap] = [fp - 9]; ap++
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_cairo_vm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

BASE=`basename "$1" .json`

juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=small
juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=all_cairo
36 changes: 24 additions & 12 deletions src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,29 @@ import Juvix.Compiler.Backend.Cairo.Data.Result
import Juvix.Compiler.Backend.Cairo.Language
import Numeric

serialize :: [Element] -> Result
serialize elems =
serialize :: [Text] -> [Element] -> Result
serialize builtins elems =
Result
{ _resultData =
initializeOutput
initializeBuiltins
++ map toHexText (serialize' elems)
++ finalizeOutput
++ finalizeBuiltins
++ finalizeJump,
_resultStart = 0,
_resultEnd = length initializeOutput + length elems + length finalizeOutput,
_resultEnd = length initializeBuiltins + length elems + length finalizeBuiltins,
_resultMain = 0,
_resultHints = hints,
_resultBuiltins = ["output"]
_resultBuiltins = "output" : builtins
}
where
builtinsNum :: Natural
builtinsNum = fromIntegral (length builtins)

hints :: [(Int, Text)]
hints = catMaybes $ zipWith mkHint elems [0 ..]

pcShift :: Int
pcShift = length initializeOutput
pcShift = length initializeBuiltins

mkHint :: Element -> Int -> Maybe (Int, Text)
mkHint el pc = case el of
Expand All @@ -34,21 +37,30 @@ serialize elems =
toHexText :: Natural -> Text
toHexText n = "0x" <> fromString (showHex n "")

initializeOutput :: [Text]
initializeOutput =
initializeBuiltins :: [Text]
initializeBuiltins =
-- ap += allBuiltinsNum
[ "0x40480017fff7fff",
"0x1"
toHexText (builtinsNum + 1)
]

finalizeOutput :: [Text]
finalizeOutput =
finalizeBuiltins :: [Text]
finalizeBuiltins =
-- [[fp]] = [ap - 1] -- [output_ptr] = [ap - 1]
-- [ap] = [fp] + 1; ap++ -- output_ptr
[ "0x4002800080007fff",
"0x4826800180008000",
"0x1"
]
++
-- [ap] = [ap - builtinsNum - 2]; ap++
replicate
builtinsNum
(toHexText (0x48107ffe7fff8000 - shift builtinsNum 32))

finalizeJump :: [Text]
finalizeJump =
-- jmp rel 0
[ "0x10780017fff7fff",
"0x0"
]
Expand Down
27 changes: 27 additions & 0 deletions src/Juvix/Compiler/Casm/Data/Builtins.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module Juvix.Compiler.Casm.Data.Builtins where

import Juvix.Extra.Strings qualified as Str
import Juvix.Prelude

-- The order of the builtins must correspond to the "standard" builtin order in
-- the Cairo VM implementation. See:
-- https://github.com/lambdaclass/cairo-vm/blob/main/vm/src/vm/runners/cairo_runner.rs#L257
data Builtin
= BuiltinRangeCheck
| BuiltinEcOp
| BuiltinPoseidon
deriving stock (Show, Eq, Generic, Enum, Bounded)

instance Hashable Builtin

builtinsNum :: Int
builtinsNum = length (allElements @Builtin)

builtinName :: Builtin -> Text
builtinName = \case
BuiltinRangeCheck -> Str.cairoRangeCheck
BuiltinEcOp -> Str.cairoEcOp
BuiltinPoseidon -> Str.cairoPoseidon

builtinNames :: [Text]
builtinNames = map builtinName allElements
6 changes: 5 additions & 1 deletion src/Juvix/Compiler/Casm/Data/LabelInfoBuilder.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
module Juvix.Compiler.Casm.Data.LabelInfoBuilder where
module Juvix.Compiler.Casm.Data.LabelInfoBuilder
( module Juvix.Compiler.Casm.Data.LabelInfo,
module Juvix.Compiler.Casm.Data.LabelInfoBuilder,
)
where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Casm.Data.LabelInfo
Expand Down
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Casm/Data/Result.hs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
module Juvix.Compiler.Casm.Data.Result where

import Juvix.Compiler.Casm.Data.Builtins
import Juvix.Compiler.Casm.Data.LabelInfo
import Juvix.Compiler.Casm.Language

data Result = Result
{ _resultLabelInfo :: LabelInfo,
_resultCode :: [Instruction]
_resultCode :: [Instruction],
_resultBuiltins :: [Builtin]
}

makeLenses ''Result
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Casm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Data.HashMap.Strict qualified as HashMap
import Data.Vector qualified as Vec
import Data.Vector.Mutable qualified as MV
import GHC.IO qualified as GHC
import Juvix.Compiler.Casm.Data.Builtins
import Juvix.Compiler.Casm.Data.InputInfo
import Juvix.Compiler.Casm.Data.LabelInfo
import Juvix.Compiler.Casm.Error
Expand Down Expand Up @@ -39,7 +40,9 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
goCode :: ST s FField
goCode = do
mem <- MV.replicate initialMemSize Nothing
go 0 0 0 mem
forM_ [0 .. builtinsNum] $ \k ->
MV.write mem k (Just (fieldFromInteger cairoFieldSize 0))
go 0 (builtinsNum + 1) 0 mem

go ::
Address ->
Expand Down
7 changes: 6 additions & 1 deletion src/Juvix/Compiler/Casm/Translation/FromCairo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ import Juvix.Compiler.Casm.Data.Result
import Juvix.Compiler.Casm.Language

fromCairo :: [Cairo.Element] -> Result
fromCairo elems0 = Result mempty (go 0 [] elems0)
fromCairo elems0 =
Result
{ _resultLabelInfo = mempty,
_resultCode = go 0 [] elems0,
_resultBuiltins = mempty
}
where
errorMsg :: Address -> Text -> a
errorMsg addr msg = error ("error at address " <> show addr <> ": " <> msg)
Expand Down
Loading

0 comments on commit ad76c7a

Please sign in to comment.