Skip to content

Commit

Permalink
Add an if instruction to JuvixReg (#2855)
Browse files Browse the repository at this point in the history
* Closes #2829
* Adds a transformation which converts `br` to `if` when the variable
branched on was assigned in the previous instruction. The transformation
itself doesn't check liveness and doesn't remove the assignment. Dead
code elimination should be run afterwards to remove the assignment.
* For Cairo, it only makes sense to convert `br` to `if` for equality
comparisons against zero. The assignment before `br` will always become
dead after converting `br` to `if`, because we convert to SSA before.
  • Loading branch information
lukaszcz authored Jun 26, 2024
1 parent 7cfddcf commit 4dcbb00
Show file tree
Hide file tree
Showing 41 changed files with 421 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ repos:
types_or: [json]

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v17.0.2
rev: v18.1.4
hooks:
- id: clang-format
files: runtime/.+\.(c|h)$
Expand Down
18 changes: 9 additions & 9 deletions runtime/c/src/juvix/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
DECL_TAIL_APPLY_3; \
juvix_program_start:

#define JUVIX_EPILOGUE \
juvix_program_end: \
STACK_POPT; \
IO_INTERPRET; \
#define JUVIX_EPILOGUE \
juvix_program_end : STACK_POPT; \
IO_INTERPRET; \
io_print_toplevel(juvix_result);

// Temporary / local vars
Expand All @@ -45,9 +44,7 @@

// Begin a function definition. `max_stack` is the maximum stack allocation in
// the function.
#define JUVIX_FUNCTION(label, max_stack) \
label: \
STACK_ENTER((max_stack))
#define JUVIX_FUNCTION(label, max_stack) label : STACK_ENTER((max_stack))

/*
Macro sequence for function definition:
Expand All @@ -67,8 +64,7 @@
*/

// Begin a function with no stack allocation.
#define JUVIX_FUNCTION_NS(label) \
label:
#define JUVIX_FUNCTION_NS(label) label:

#define JUVIX_INT_ADD(var0, var1, var2) (var0 = smallint_add(var1, var2))
#define JUVIX_INT_SUB(var0, var1, var2) (var0 = smallint_sub(var1, var2))
Expand All @@ -83,6 +79,10 @@
#define JUVIX_VAL_EQ(var0, var1, var2) \
(var0 = make_bool(juvix_equal(var1, var2)))

#define JUVIX_BOOL_INT_LT(var1, var2) (smallint_lt(var1, var2))
#define JUVIX_BOOL_INT_LE(var1, var2) (smallint_le(var1, var2))
#define JUVIX_BOOL_VAL_EQ(var1, var2) (make_bool(juvix_equal(var1, var2)))

#define JUVIX_STR_CONCAT(var0, var1, var2) CONCAT_CSTRINGS(var0, var1, var2)

#define JUVIX_STR_TO_INT(var0, var1) \
Expand Down
3 changes: 2 additions & 1 deletion runtime/c/src/juvix/defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ static inline void *palign(void *ptr, uintptr_t alignment) {
return (void *)align((uintptr_t)ptr, alignment);
}
// `y` must be a power of 2
#define ASSERT_ALIGNED(x, y) ASSERT(((uintptr_t)(x) & ((uintptr_t)(y)-1)) == 0)
#define ASSERT_ALIGNED(x, y) \
ASSERT(((uintptr_t)(x) & ((uintptr_t)(y) - 1)) == 0)

#if defined(API_LIBC) && defined(DEBUG)
#define LOG(...) fprintf(stderr, __VA_ARGS__)
Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ recurse' sig = go True
fixMemIntOp mem
OpIntMod ->
fixMemIntOp mem
OpIntLt ->
OpBool OpIntLt ->
fixMemBinOp' mem mkTypeInteger mkTypeInteger mkTypeBool
OpIntLe ->
OpBool OpIntLe ->
fixMemBinOp' mem mkTypeInteger mkTypeInteger mkTypeBool
OpFieldAdd ->
fixMemFieldOp mem
Expand All @@ -168,7 +168,7 @@ recurse' sig = go True
fixMemFieldOp mem
OpFieldDiv ->
fixMemFieldOp mem
OpEq ->
OpBool OpEq ->
fixMemBinOp' mem TyDynamic TyDynamic mkTypeBool
OpStrConcat ->
fixMemBinOp' mem TyString TyString TyString
Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Asm/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ command = do
"mod" ->
return $ mkBinop' loc OpIntMod
"lt" ->
return $ mkBinop' loc OpIntLt
return $ mkBinop' loc (OpBool OpIntLt)
"le" ->
return $ mkBinop' loc OpIntLe
return $ mkBinop' loc (OpBool OpIntLe)
"eq" ->
return $ mkBinop' loc OpEq
return $ mkBinop' loc (OpBool OpEq)
"fadd" ->
return $ mkBinop' loc OpFieldAdd
"fsub" ->
Expand Down
34 changes: 31 additions & 3 deletions src/Juvix/Compiler/Backend/C/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ fromRegInstr bNoStack info = \case
fromCallClosures x
Reg.Return x ->
return $ fromReturn x
Reg.If x ->
fromIf x
Reg.Branch x ->
fromBranch x
Reg.Case x ->
Expand All @@ -271,16 +273,22 @@ fromRegInstr bNoStack info = \case
fromValue _instrBinopArg2
]

getBoolOpMacro :: Reg.BoolOp -> Text
getBoolOpMacro = \case
Reg.OpIntLt -> "JUVIX_BOOL_INT_LT"
Reg.OpIntLe -> "JUVIX_BOOL_INT_LE"
Reg.OpEq -> "JUVIX_BOOL_VAL_EQ"

getBinaryOpMacro :: Reg.BinaryOp -> Text
getBinaryOpMacro = \case
Reg.OpIntAdd -> "JUVIX_INT_ADD"
Reg.OpIntSub -> "JUVIX_INT_SUB"
Reg.OpIntMul -> "JUVIX_INT_MUL"
Reg.OpIntDiv -> "JUVIX_INT_DIV"
Reg.OpIntMod -> "JUVIX_INT_MOD"
Reg.OpIntLt -> "JUVIX_INT_LT"
Reg.OpIntLe -> "JUVIX_INT_LE"
Reg.OpEq -> "JUVIX_VAL_EQ"
Reg.OpBool Reg.OpIntLt -> "JUVIX_INT_LT"
Reg.OpBool Reg.OpIntLe -> "JUVIX_INT_LE"
Reg.OpBool Reg.OpEq -> "JUVIX_VAL_EQ"
Reg.OpStrConcat -> "JUVIX_STR_CONCAT"
Reg.OpFieldAdd -> unsupported "field type"
Reg.OpFieldSub -> unsupported "field type"
Expand Down Expand Up @@ -504,6 +512,26 @@ fromRegInstr bNoStack info = \case
integer argsNum,
ExpressionVar lab
]
fromIf :: Reg.InstrIf -> Sem r [Statement]
fromIf Reg.InstrIf {..} = do
br1 <- fromRegCode bNoStack info _instrIfTrue
br2 <- fromRegCode bNoStack info _instrIfFalse
return
[ StatementIf $
If
{ _ifCondition =
macroCall
"is_true"
[ macroCall
(getBoolOpMacro _instrIfOp)
[ fromValue _instrIfArg1,
fromValue _instrIfArg2
]
],
_ifThen = StatementCompound br1,
_ifElse = Just (StatementCompound br2)
}
]

fromBranch :: Reg.InstrBranch -> Sem r [Statement]
fromBranch Reg.InstrBranch {..} = do
Expand Down
29 changes: 26 additions & 3 deletions src/Juvix/Compiler/Backend/Rust/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ fromRegInstr info = \case
fromCallClosures x
Reg.Return x ->
fromReturn x
Reg.If x ->
fromIf x
Reg.Branch x ->
fromBranch x
Reg.Case x ->
Expand All @@ -173,16 +175,20 @@ fromRegInstr info = \case
[fromValue _instrBinopArg1, fromValue _instrBinopArg2]
)

getBoolOpName :: Reg.BoolOp -> Text
getBoolOpName = \case
Reg.OpIntLt -> "smallint_lt"
Reg.OpIntLe -> "smallint_le"
Reg.OpEq -> "juvix_equal"

getBinaryOpName :: Reg.BinaryOp -> Text
getBinaryOpName = \case
Reg.OpBool x -> getBoolOpName x
Reg.OpIntAdd -> "smallint_add"
Reg.OpIntSub -> "smallint_sub"
Reg.OpIntMul -> "smallint_mul"
Reg.OpIntDiv -> "smallint_div"
Reg.OpIntMod -> "smallint_mod"
Reg.OpIntLt -> "smallint_lt"
Reg.OpIntLe -> "smallint_le"
Reg.OpEq -> "juvix_equal"
Reg.OpStrConcat -> unsupported "strings"
Reg.OpFieldAdd -> unsupported "field type"
Reg.OpFieldSub -> unsupported "field type"
Expand Down Expand Up @@ -349,6 +355,23 @@ fromRegInstr info = \case
]
]

fromIf :: Reg.InstrIf -> [Statement]
fromIf Reg.InstrIf {..} =
stmtsIf
( mkCall
"word_to_bool"
[ ( mkCall
(getBoolOpName _instrIfOp)
[fromValue _instrIfArg1, fromValue _instrIfArg2]
)
]
)
br1
br2
where
br1 = fromRegCode info _instrIfTrue
br2 = fromRegCode info _instrIfFalse

fromBranch :: Reg.InstrBranch -> [Statement]
fromBranch Reg.InstrBranch {..} =
stmtsIf (mkCall "word_to_bool" [fromValue _instrBranchValue]) br1 br2
Expand Down
71 changes: 43 additions & 28 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,9 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
goExtraBinop IntDiv res arg1 arg2
Reg.OpIntMod ->
goExtraBinop IntMod res arg1 arg2
Reg.OpIntLt ->
Reg.OpBool Reg.OpIntLt ->
goExtraBinop IntLt res arg1 arg2
Reg.OpIntLe ->
Reg.OpBool Reg.OpIntLe ->
goIntLe res arg1 arg2
Reg.OpFieldAdd ->
goNativeBinop FieldAdd res arg1 arg2
Expand All @@ -399,7 +399,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
goNativeBinop FieldMul res arg1 arg2
Reg.OpFieldDiv ->
goExtraBinop FieldDiv res arg1 arg2
Reg.OpEq ->
Reg.OpBool Reg.OpEq ->
goEq res arg1 arg2
Reg.OpStrConcat ->
unsupported "strings"
Expand Down Expand Up @@ -527,6 +527,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Reg.Call x -> goCall liveVars x
Reg.TailCall x -> goTailCall x
Reg.Return x -> goReturn x
Reg.If x -> goIf liveVars x
Reg.Branch x -> goBranch liveVars x
Reg.Case x -> goCase liveVars x

Expand Down Expand Up @@ -573,34 +574,48 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
goAssignApValue _instrReturnValue
output'' Return

goIf :: HashSet Reg.VarRef -> Reg.InstrIf -> Sem r ()
goIf liveVars Reg.InstrIf {..} = case _instrIfOp of
Reg.OpEq
| Reg.ValConst (Reg.ConstInt 0) <- _instrIfArg1 -> do
v <- goValue _instrIfArg2
goBranch' liveVars _instrIfOutVar _instrIfTrue _instrIfFalse v
| Reg.ValConst (Reg.ConstInt 0) <- _instrIfArg2 -> do
v <- goValue _instrIfArg1
goBranch' liveVars _instrIfOutVar _instrIfTrue _instrIfFalse v
_ -> impossible

goBranch :: HashSet Reg.VarRef -> Reg.InstrBranch -> Sem r ()
goBranch liveVars Reg.InstrBranch {..} = do
v <- goValue _instrBranchValue
case v of
Imm c
| c == 0 -> goBlock blts failLab liveVars _instrBranchOutVar _instrBranchTrue
| otherwise -> goBlock blts failLab liveVars _instrBranchOutVar _instrBranchFalse
Ref r -> do
symFalse <- freshSymbol
symEnd <- freshSymbol
let labFalse = LabelRef symFalse Nothing
labEnd = LabelRef symEnd Nothing
output'' $ mkJumpIf (Lab labFalse) r
ap0 <- getAP
vars <- getVars
bltOff <- getBuiltinOffset
goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchTrue
-- _instrBranchOutVar is Nothing iff the branch returns
when (isJust _instrBranchOutVar) $
output'' (mkJumpRel (Val $ Lab labEnd))
addrFalse <- getPC
registerLabelAddress symFalse addrFalse
output'' $ Label labFalse
goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchFalse
addrEnd <- getPC
registerLabelAddress symEnd addrEnd
output'' $ Label labEnd
Lab {} -> impossible
goBranch' liveVars _instrBranchOutVar _instrBranchTrue _instrBranchFalse v

goBranch' :: HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Reg.Block -> Value -> Sem r ()
goBranch' liveVars outVar branchTrue branchFalse = \case
Imm c
| c == 0 -> goBlock blts failLab liveVars outVar branchTrue
| otherwise -> goBlock blts failLab liveVars outVar branchFalse
Ref r -> do
symFalse <- freshSymbol
symEnd <- freshSymbol
let labFalse = LabelRef symFalse Nothing
labEnd = LabelRef symEnd Nothing
output'' $ mkJumpIf (Lab labFalse) r
ap0 <- getAP
vars <- getVars
bltOff <- getBuiltinOffset
goLocalBlock ap0 vars bltOff liveVars outVar branchTrue
-- outVar is Nothing iff the branch returns
when (isJust outVar) $
output'' (mkJumpRel (Val $ Lab labEnd))
addrFalse <- getPC
registerLabelAddress symFalse addrFalse
output'' $ Label labFalse
goLocalBlock ap0 vars bltOff liveVars outVar branchFalse
addrEnd <- getPC
registerLabelAddress symEnd addrEnd
output'' $ Label labEnd
Lab {} -> impossible

goLoad :: Reg.Value -> Offset -> Sem r RValue
goLoad val off = do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

-- | Checks if `node` is a case tree such that all leaves are constructor
-- applications, and each constructor `C` matched on in `c` either occurs as a
-- leaf in `node` at most once or the branch body in `c` associated with `C` is
-- an immediate value.
isConstructorTree :: Module -> Case -> Node -> Bool
isConstructorTree md c node = case run $ runFail $ go mempty node of
Just ctrsMap ->
Expand All @@ -29,6 +33,8 @@ isConstructorTree md c node = case run $ runFail $ go mempty node of
tags' = HashSet.fromList tags
Nothing -> True

-- Returns the map from tags to their number of occurrences in the leaves of
-- the case tree.
go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int)
go ctrs = \case
NCtr Constr {..} ->
Expand All @@ -39,6 +45,9 @@ isConstructorTree md c node = case run $ runFail $ go mempty node of
_ ->
fail

-- | Convert e.g. `case (if A C1 C2) of C1 := X | C2 := Y` to
-- `if A (case C1 of C1 := X | C2 := Y) (case C2 of C1 := X | C2 := Y)`
-- See: https://github.com/anoma/juvix/issues/2440
convertNode :: Module -> Node -> Node
convertNode md = dmap go
where
Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Nockma/Translation/FromTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,9 @@ compile = \case
Tree.OpIntMul -> return (callStdlib StdlibMul args)
Tree.OpIntDiv -> return (callStdlib StdlibDiv args)
Tree.OpIntMod -> return (callStdlib StdlibMod args)
Tree.OpIntLt -> return (callStdlib StdlibLt args)
Tree.OpIntLe -> return (callStdlib StdlibLe args)
Tree.OpEq -> testEq _nodeBinopArg1 _nodeBinopArg2
Tree.OpBool Tree.OpIntLt -> return (callStdlib StdlibLt args)
Tree.OpBool Tree.OpIntLe -> return (callStdlib StdlibLe args)
Tree.OpBool Tree.OpEq -> testEq _nodeBinopArg1 _nodeBinopArg2
Tree.OpStrConcat -> return (callStdlib StdlibCatBytes args)
Tree.OpFieldAdd -> fieldErr
Tree.OpFieldSub -> fieldErr
Expand Down
Loading

0 comments on commit 4dcbb00

Please sign in to comment.