From 1b5785847fb5df29bb34a94c366bc8f24cf737b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Sat, 2 Sep 2023 00:57:08 +0200 Subject: [PATCH] Constrain generic AST --- src/codegen/codegenEnv.ml | 30 +- src/codegen/codegenUtils.ml | 20 +- src/codegen/codegen_.ml | 103 ++++--- src/common/builtins.ml | 2 +- src/common/env.ml | 6 +- src/common/ppCommon.ml | 8 +- src/common/sailModule.ml | 2 +- src/common/typesCommon.ml | 49 +-- src/parsing/astParser.ml | 16 +- src/parsing/parser.mly | 26 +- src/passes/ir/sailHir/astHir.ml | 95 ------ src/passes/ir/sailHir/hir.ml | 142 +++++---- src/passes/ir/sailHir/hirAst.ml | 123 ++++++++ src/passes/ir/sailHir/hirMonad.ml | 5 +- src/passes/ir/sailHir/hirUtils.ml | 248 ++++++++------- src/passes/ir/sailHir/pp_hir.ml | 60 ++-- src/passes/ir/sailMir/mir.ml | 120 ++++---- .../ir/sailMir/{astMir.ml => mirAst.ml} | 4 +- src/passes/ir/sailMir/mirMonad.ml | 2 +- src/passes/ir/sailMir/mirUtils.ml | 7 +- src/passes/ir/sailMir/pp_mir.ml | 28 +- src/passes/ir/sailThir/thir.ml | 284 +++++++++--------- src/passes/ir/sailThir/thirUtils.ml | 95 +++--- src/passes/misc/cfg_analysis.ml | 4 +- .../monomorphization/monomorphization.ml | 127 ++++---- .../monomorphization/monomorphizationUtils.ml | 56 ++-- src/passes/process/process.ml | 43 +-- src/passes/process/processMonad.ml | 10 +- src/passes/process/processUtils.ml | 12 +- 29 files changed, 923 insertions(+), 804 deletions(-) delete mode 100644 src/passes/ir/sailHir/astHir.ml create mode 100644 src/passes/ir/sailHir/hirAst.ml rename src/passes/ir/sailMir/{astMir.ml => mirAst.ml} (96%) diff --git a/src/codegen/codegenEnv.ml b/src/codegen/codegenEnv.ml index 6bf9e9e..f1dc565 100644 --- a/src/codegen/codegenEnv.ml +++ b/src/codegen/codegenEnv.ml @@ -12,7 +12,7 @@ open MakeOrderedFunctions(ImportCmp) module Declarations = struct include SailModule.Declarations type process_decl = unit - type method_decl = {defn : AstMir.mir_function method_defn ; llval : llvalue ; extern : bool} + type method_decl = {defn : MirAst.mir_function method_defn ; llval : llvalue ; extern : bool} type struct_decl = {defn : struct_proto; ty : lltype} type enum_decl = unit end @@ -34,9 +34,9 @@ open Declarations type in_body = Monomorphization.Pass.out_body -let getLLVMBasicType f t llc llm : lltype E.t = - let rec aux t = - match snd t with +let getLLVMBasicType f ty llc llm : lltype E.t = + let rec aux ty = + match ty.value with | Bool -> i1_type llc |> return | Int n -> integer_type llc n |> return | Float -> double_type llc |> return @@ -44,13 +44,13 @@ let getLLVMBasicType f t llc llm : lltype E.t = | String -> i8_type llc |> pointer_type |> return | ArrayType (t,s) -> let+ t = aux t in array_type t s | Box t | RefType (t,_) -> aux t <&> pointer_type - | GenericType _ -> E.throw Logging.(make_msg (fst t) "no generic type in codegen") - | CompoundType {name=(_,name); _} when name = "_value" -> i64_type llc |> return (* for extern functions *) + | GenericType _ -> E.throw Logging.(make_msg ty.loc "no generic type in codegen") + | CompoundType {name; _} when name.value = "_value" -> i64_type llc |> return (* for extern functions *) | CompoundType {origin=None;_} - | CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg (fst t) "compound type with no origin or decl_ty") - | CompoundType {origin=Some (_,mname); name=(_,name); decl_ty=Some d;_} -> - f (mname,name,d) llc llm aux - in aux t + | CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg ty.loc "compound type with no origin or decl_ty") + | CompoundType {origin=Some mname; name; decl_ty=Some d;_} -> + f (mname.value,name.value,d) llc llm aux + in aux ty let handle_compound_type_codegen env (mname,name,d) llc _llm (aux : sailtype -> lltype E.t) : lltype E.t = @@ -80,7 +80,7 @@ let getLLVMBasicType f t llc llm : lltype E.t = | Some (E _enum) -> failwith "todo enum" | Some (S (_,defn)) -> let _,f_types = List.split defn.fields in - let* elts = ListM.map (fun (_,t,_) -> aux t) f_types <&> Array.of_list in + let* elts = ListM.map (fun ty -> aux (fst ty.value)) f_types <&> Array.of_list in begin match type_by_name llm ("struct." ^ name) with | Some ty -> return ty @@ -129,7 +129,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = ); let valueify_method_sig (m:method_sig) : method_sig = - let value = fun pos -> dummy_pos,CompoundType{origin=None;name=(pos,"_value");generic_instances=[];decl_ty=None} in + let value = fun pos -> mk_locatable dummy_pos @@ CompoundType{origin=None;name=(mk_locatable pos "_value");generic_instances=[];decl_ty=None} in let rtype = m.rtype in (* keep the current type *) let params = List.map (fun (p:param) -> {p with ty=(value p.loc)}) m.params in {m with params; rtype} @@ -137,8 +137,8 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = (* because the imports are at the mir stage, we also have to do some codegen for them *) - let load_methods (methods: IrMir.AstMir.mir_function method_defn list) is_import env = - ListM.fold_left ( fun d (m:IrMir.AstMir.mir_function method_defn) -> + let load_methods (methods: IrMir.MirAst.mir_function method_defn list) is_import env = + ListM.fold_left ( fun d (m:IrMir.MirAst.mir_function method_defn) -> let extern,proto = if (Either.is_left m.m_body) then (* extern method, all parameters must be of type value *) true,valueify_method_sig m.m_proto @@ -167,7 +167,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = let load_structs structs write_env = SEnv.fold (fun acc (name,(_,defn)) -> let _,f_types = List.split defn.fields in - let* elts = ListM.map (fun (_,t,_) -> _getLLVMType sm.declEnv t llc llm) f_types <&> Array.of_list in + let* elts = ListM.map (fun ty-> _getLLVMType sm.declEnv (fst ty.value) llc llm) f_types <&> Array.of_list in let ty = match type_by_name llm ("struct." ^ name) with | Some ty -> ty | None -> let ty = named_struct_type llc ("struct." ^ name) in diff --git a/src/codegen/codegenUtils.ml b/src/codegen/codegenUtils.ml index d0d9d44..8cf2fe4 100644 --- a/src/codegen/codegenUtils.ml +++ b/src/codegen/codegenUtils.ml @@ -20,16 +20,16 @@ let getLLVMLiteral (l:literal) (llvm:llvm_args) : llvalue = | LChar c -> const_int (i8_type llvm.c) (Char.code c) | LString s -> build_global_stringptr s ".str" llvm.b -let ty_of_alias(t:sailtype) env : sailtype = - match snd t with - | CompoundType {origin=Some (_,mname); name=(_,name);decl_ty=Some T ();_} -> +let ty_of_alias(ty:sailtype) env : sailtype = + match ty.value with + | CompoundType {origin=Some mname; name;decl_ty=Some T ();_} -> begin - match DeclEnv.find_decl name (Specific (mname,Type)) env with + match DeclEnv.find_decl name.value (Specific (mname.value,Type)) env with | Some {ty=Some t';_} -> t' - | Some {ty=None;_} -> t - | None -> failwith @@ Fmt.str "ty_of_alias : '%s' not found in %s" (string_of_sailtype (Some t)) mname + | Some {ty=None;_} -> ty + | None -> failwith @@ Fmt.str "ty_of_alias : '%s' not found in %s" (string_of_sailtype (Some ty)) mname.value end - | _ -> t + | _ -> ty let unary (op:unOp) (t,v) : llbuilder -> llvalue = let f = @@ -37,7 +37,7 @@ let unary (op:unOp) (t,v) : llbuilder -> llvalue = | Float,Neg -> build_fneg | Int _,Neg -> build_neg | _,Not -> build_not - | _ -> Printf.sprintf "bad unary operand type : '%s'. Only double and int are supported" (string_of_sailtype (Some t)) |> failwith + | _ -> Printf.sprintf "bad unary operand type : '%s'. Only double and int are supported" (string_of_sailtype (Some (mk_locatable (fst t) (snd t)))) |> failwith in f v "" @@ -76,8 +76,8 @@ let binary (op:binOp) (t:sailtype) (l1:llvalue) (l2:llvalue) : llbuilder -> llva | And -> "and" | Or -> "or" | Le -> "le" | Lt -> "lt" | Ge -> "ge" | Gt -> "gt" | Mul -> "mul" | NEq -> "neq" | Div -> "div" in - let t = if snd t = Bool then fst t,Int 1 else t in (* thir will have checked for correctness *) - let l = operators (snd t) in + let t = if t.value = Bool then mk_locatable t.loc @@ Int 1 else t in (* thir will have checked for correctness *) + let l = operators t.value in let open Common.Monad.MonadOperator(Common.MonadOption.M) in match l >>| List.assoc_opt op |> Option.join with | Some oper -> oper l1 l2 "" diff --git a/src/codegen/codegen_.ml b/src/codegen/codegen_.ml index 871ec93..93285e3 100644 --- a/src/codegen/codegen_.ml +++ b/src/codegen/codegen_.ml @@ -6,10 +6,10 @@ open IrMir open Monad.UseMonad(E) module L = Llvm module E = Logging.Logger -let get_type (e:AstMir.expression) = snd e.info +let get_type (e:MirAst.expression) = e.tag.ty -let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x: AstMir.expression) : L.llvalue E.t = - match x.exp with +let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp: MirAst.expression) : L.llvalue E.t = + match exp.node with | Variable x -> let+ _,v = match (SailEnv.get_var x venv) with | Some (_,n) -> return n @@ -18,50 +18,57 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x: | Deref x -> eval_r env llvm x - | ArrayRead (array_exp, index_exp) -> - let* array_val = eval_l env llvm array_exp in - let+ index = eval_r env llvm index_exp in + | ArrayRead a -> + let* array_val = eval_l env llvm a.array in + let+ index = eval_r env llvm a.idx in let llvm_array = L.build_in_bounds_gep array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in llvm_array - | StructRead ((_,mname),struct_exp,(_,field)) -> - let* st = eval_l env llvm struct_exp in - let+ st_type_name = Env.TypeEnv.get_from_id struct_exp.info tenv >>= function _,CompoundType c -> return (snd c.name) | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") in - let fields = (SailEnv.get_decl st_type_name (Specific (mname,Struct)) venv |> Option.get).defn.fields in - let _,_,idx = List.assoc field fields in + | StructRead2 s -> + let* st = eval_l env llvm s.value.strct in + let* st_type_name = Env.TypeEnv.get_from_id (mk_locatable s.value.strct.tag.loc s.value.strct.tag.ty) tenv >>= function + | {value=CompoundType c;_} -> return c.name.value + | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") + in + let+ decl = (SailEnv.get_decl st_type_name (Specific (s.import.value,Struct)) venv + |> E.throw_if_none Logging.(make_msg exp.tag.loc @@ Fmt.str "compiler error : no decl '%s' found" st_type_name)) in + + let fields = decl.defn.fields in + let {value=_,idx;_} = List.assoc s.value.field.value fields in L.build_struct_gep st idx "" llvm.b - | StructAlloc (_,(_,name),fields) -> - let _,fieldlist = fields |> List.split in - let* strct_ty = match L.type_by_name llvm.m ("struct." ^ name) with + | StructAlloc2 s -> + let _,fieldlist = s.value.fields |> List.split in + let* strct_ty = match L.type_by_name llvm.m ("struct." ^ s.value.name.value) with | Some s -> return s | None -> - E.throw Logging.(make_msg (fst x.info) @@ "unknown structure : " ^ ("struct." ^ name)) + E.throw Logging.(make_msg exp.tag.loc @@ "unknown structure : " ^ ("struct." ^ s.value.name.value)) in let struct_v = L.build_alloca strct_ty "" llvm.b in - let+ () = ListM.iteri ( fun i (_,f) -> - let+ v = eval_r env llvm f in + let+ () = ListM.iteri ( fun i f -> + let+ v = eval_r env llvm f.value in let v_f = L.build_struct_gep struct_v i "" llvm.b in L.build_store v v_f llvm.b |> ignore ) fieldlist in struct_v - | _ -> E.throw Logging.(make_msg (fst x.info) "unexpected rvalue for codegen") + | Literal _ | UnOp _ | BinOp _ | Ref _ | ArrayStatic _ | EnumAlloc _ -> + E.throw Logging.(make_msg exp.tag.loc "unexpected lvalue for codegen") -and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMir.expression) : L.llvalue E.t = - let* ty = Env.TypeEnv.get_from_id x.info tenv in - match x.exp with - | Variable _ | StructRead _ | ArrayRead _ | StructAlloc _ -> let+ v = eval_l env llvm x in L.build_load v "" llvm.b +and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:MirAst.expression) : L.llvalue E.t = + let* ty = Env.TypeEnv.get_from_id (mk_locatable exp.tag.loc exp.tag.ty) tenv in + match exp.node with + | Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load v "" llvm.b | Literal l -> return @@ getLLVMLiteral l llvm - | UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (ty_of_alias ty (snd venv),l) llvm.b + | UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (let ty = ty_of_alias ty (snd venv) in (ty.loc,ty.value),l) llvm.b - | BinOp (op,e1, e2) -> - let+ l1 = eval_r env llvm e1 - and* l2 = eval_r env llvm e2 - in binary op (ty_of_alias ty (snd venv)) l1 l2 llvm.b + | BinOp bop -> + let+ l1 = eval_r env llvm bop.left + and* l2 = eval_r env llvm bop.right + in binary bop.op (ty_of_alias ty (snd venv)) l1 l2 llvm.b | Ref (_,e) -> eval_l env llvm e | Deref e -> let+ v = eval_l env llvm e in L.build_load v "" llvm.b @@ -80,32 +87,31 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMi L.build_load array "" llvm.b end - | EnumAlloc _ -> E.throw Logging.(make_msg (fst x.info) "enum allocation unimplemented") + | EnumAlloc _ -> E.throw Logging.(make_msg exp.tag.loc "enum allocation unimplemented") - | _ -> E.throw Logging.(make_msg (fst x.info) "problem with thir") -and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t = - let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.info,r) args >>| List.split +and construct_call (name:string) (mname:l_str) (args:MirAst.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t = + let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.tag,r) args >>| List.split in (* let mname = mangle_method_name name origin.mname args_type in *) - let mangled_name = "_" ^ mname ^ "_" ^ name in + let mangled_name = "_" ^ mname.value ^ "_" ^ name in Logs.debug (fun m -> m "constructing call to %s" name); - let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname,Method)) venv with + let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname.value,Method)) venv with | None -> begin - match SailEnv.get_decl name (Specific (mname,Method)) venv with + match SailEnv.get_decl name (Specific (mname.value,Method)) venv with | Some {llval;extern;_} -> return (llval,extern) - | None -> E.throw Logging.(make_msg loc @@ Printf.sprintf "implementation of %s not found" mangled_name ) + | None -> E.throw Logging.(make_msg mname.loc @@ Printf.sprintf "implementation of %s not found" mangled_name ) end | Some {llval;extern;_} -> return (llval,extern) in let+ args = if ext then - ListM.map2 (fun t v -> - let+ t = Env.TypeEnv.get_from_id t tenv in + ListM.map2 (fun (t:IrThir.ThirUtils.exp_tag) v -> + let+ t = Env.TypeEnv.get_from_id (mk_locatable t.loc t.ty) tenv in let builder = - match snd (ty_of_alias t (snd venv)) with + match (ty_of_alias t (snd venv)).value with | Bool | Int _ | Char -> L.build_zext | Float -> L.build_bitcast | CompoundType _ -> fun v _ _ _ -> v @@ -118,16 +124,16 @@ and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression lis in L.build_call llval (Array.of_list args) "" llvm.b -open AstMir +open MirAst let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,tenv : SailEnv.t*Env.TypeEnv.t) : unit E.t = - let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:AstMir.expression option) (venv : SailEnv.t) : SailEnv.t E.t = + let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:MirAst.expression option) (venv : SailEnv.t) : SailEnv.t E.t = let _ = mut in (* todo manage mutable types *) let entry_b = L.(entry_block proto |> instr_begin |> builder_at llvm.c) in let* v = match exp with | Some e -> - let* t = Env.TypeEnv.get_from_id e.info tenv + let* t = Env.TypeEnv.get_from_id (mk_locatable e.tag.loc e.tag.ty) tenv and* v = eval_r (venv,tenv) llvm e in let+ ty = getLLVMType (snd venv) t llvm.c llvm.m in let x = L.build_alloca ty name entry_b in @@ -154,20 +160,21 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t let llvm_bbs = BlockMap.add lbl llvm_bb llvm_bbs in L.position_at_end llvm_bb llvm.b; let* () = ListM.iter (fun x -> assign_var x.target x.expression (venv,tenv)) bb.assignments in - match bb.terminator with - | Some (Return e) -> + let* terminator = E.throw_if_none Logging.(make_msg bb.location "no terminator : mir is broken") bb.terminator in + match terminator with + | Return e -> let+ ret = match e with | Some r -> let+ v = eval_r (venv,tenv) llvm r in L.build_ret v | None -> return L.build_ret_void in ret llvm.b |> ignore; llvm_bbs - | Some (Goto lbl) -> + | Goto lbl -> let+ llvm_bbs = aux lbl llvm_bbs venv in L.position_at_end llvm_bb llvm.b; let _ = L.build_br (BlockMap.find lbl llvm_bbs) llvm.b in llvm_bbs - | Some (Invoke f) -> + | Invoke f -> let* c = construct_call f.id f.origin f.params (venv,tenv) llvm in begin match f.target with @@ -179,7 +186,7 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t L.build_br (BlockMap.find f.next llvm_bbs) llvm.b |> ignore; llvm_bbs - | Some (SwitchInt si) -> + | SwitchInt si -> let* sw_val = eval_r (venv,tenv) llvm si.choice in let sw_val = L.build_intcast sw_val (L.i32_type llvm.c) "" llvm.b in (* for condition, expression val will be bool *) let* llvm_bbs = aux si.default llvm_bbs venv in @@ -192,9 +199,7 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t in L.add_case sw n (BlockMap.find lbl bm); bm ) llvm_bbs si.paths - - | None -> E.throw Logging.(make_msg bb.location "no terminator : mir is broken") - | Some Break -> E.throw Logging.(make_msg bb.location "no break should be there") + | Break -> E.throw Logging.(make_msg bb.location "no break should be there") end in ( diff --git a/src/common/builtins.ml b/src/common/builtins.ml index 8f00dab..0a507a2 100644 --- a/src/common/builtins.ml +++ b/src/common/builtins.ml @@ -5,4 +5,4 @@ let register_builtin name generics p rtype variadic l: method_sig list = let get_builtins () : method_sig list = [] - |> register_builtin "box" ["T"] [dummy_pos,GenericType "T"] (Some (dummy_pos,Box (dummy_pos,GenericType "T"))) false \ No newline at end of file + |> register_builtin "box" ["T"] [mk_locatable dummy_pos (GenericType "T")] (Some (mk_locatable dummy_pos (Box (mk_locatable dummy_pos (GenericType "T"))))) false \ No newline at end of file diff --git a/src/common/env.ml b/src/common/env.ml index 6f4368a..40eae86 100644 --- a/src/common/env.ml +++ b/src/common/env.ml @@ -388,7 +388,7 @@ module TypeEnv = struct let empty = FieldMap.empty let get_id ty (te :t) : string * t = let add_if_no_exists s = FieldMap.update s (Option.fold ~none:(Some ty) ~some:Option.some) in - let s = match snd ty with + let s = match ty.value with | Bool -> "bool" | Float -> "float" | Char -> "char" @@ -396,12 +396,12 @@ module TypeEnv = struct | Int n -> "int" ^ string_of_int n | ArrayType _ -> "array" | GenericType t -> t - | CompoundType t -> snd t.name + | CompoundType t -> t.name.value | Box _ -> "box" | RefType _ -> "ref" in s,add_if_no_exists s te - let get_from_id (lid,id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg lid @@ Fmt.str "id '%s' not found" id) (FieldMap.find_opt id te) + let get_from_id (id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg id.loc @@ Fmt.str "id '%s' not found" id.value) (FieldMap.find_opt id.value te) end diff --git a/src/common/ppCommon.ml b/src/common/ppCommon.ml index f6064b8..adee878 100644 --- a/src/common/ppCommon.ml +++ b/src/common/ppCommon.ml @@ -40,16 +40,16 @@ let pp_binop pf b = - let rec pp_type (pf : formatter) (_,t : sailtype) : unit = - match t with + let rec pp_type (pf : formatter) (t : sailtype) : unit = + match t.value with Bool -> pp_print_string pf "bool" | Int n -> Format.fprintf pf "i%i" n | Float -> pp_print_string pf "float" | Char -> pp_print_string pf "char" | String -> pp_print_string pf "string" | ArrayType (t,s) -> Format.fprintf pf "array<%a;%d>" pp_type t s - | CompoundType {name=(_,x); generic_instances;_} -> - Format.fprintf pf "%s<%a>" x (pp_print_list ~pp_sep:pp_comma pp_type) generic_instances + | CompoundType {name; generic_instances;_} -> + Format.fprintf pf "%s<%a>" name.value (pp_print_list ~pp_sep:pp_comma pp_type) generic_instances | Box(t) -> Format.fprintf pf "ref<%a>" pp_type t | RefType (t,b) -> if b then Format.fprintf pf "&mut %a" pp_type t diff --git a/src/common/sailModule.ml b/src/common/sailModule.ml index fd0dc1d..78f9605 100644 --- a/src/common/sailModule.ml +++ b/src/common/sailModule.ml @@ -49,4 +49,4 @@ let method_decl_of_defn (d : 'a method_defn) : Declarations.method_decl = and args = d.m_proto.params and generics = d.m_proto.generics and variadic = d.m_proto.variadic in - ((pos,name),{ret;args;generics;variadic}) \ No newline at end of file + (mk_locatable pos name,{ret;args;generics;variadic}) \ No newline at end of file diff --git a/src/common/typesCommon.ml b/src/common/typesCommon.ml index bfdae09..3573149 100644 --- a/src/common/typesCommon.ml +++ b/src/common/typesCommon.ml @@ -26,9 +26,19 @@ module FieldSet = Set.Make (String) type loc = Lexing.position * Lexing.position let dummy_pos : loc = Lexing.dummy_pos,Lexing.dummy_pos +type ('i,'n) tagged_node = {tag : 'i; node : 'n} + +type ('i,'v) importable = {import : 'i; value : 'v} +let mk_importable import value = {import;value} + +type 'v locatable = {loc : loc; value : 'v} +let mk_locatable loc value = {loc;value} + + + type 'a dict = (string * 'a) list -type l_str = loc * string +type l_str = string locatable type ('m,'p,'s,'e,'t) decl_sum = M of 'm | P of 'p | S of 's | E of 'e | T of 't @@ -50,7 +60,7 @@ let string_of_decl : (_,_,_,_,_) decl_sum -> string = function | T _ -> "type" -type sailtype = loc * sailtype_ and sailtype_ = +type sailtype = sailtype_ locatable and sailtype_ = | Bool | Int of int | Float @@ -75,27 +85,27 @@ type literal = | LString of string let sailtype_of_literal = function -| LBool _ -> dummy_pos,Bool -| LFloat _ -> dummy_pos,Float -| LInt l -> dummy_pos,Int l.size -| LChar _ -> dummy_pos,Char -| LString _ -> dummy_pos,String +| LBool _ -> mk_locatable dummy_pos Bool +| LFloat _ -> mk_locatable dummy_pos Float +| LInt l -> mk_locatable dummy_pos @@ Int l.size +| LChar _ -> mk_locatable dummy_pos Char +| LString _ -> mk_locatable dummy_pos String let rec string_of_sailtype (t : sailtype option) : string = let open Printf in match t with - | Some (_,t) -> + | Some t -> begin - match t with + match t.value with | Bool -> "bool" | Int size -> "i" ^ string_of_int size | Float -> "float" | Char -> "char" | String -> "string" | ArrayType (t,s) -> sprintf "array<%s;%d>" (string_of_sailtype (Some t)) s - | CompoundType {name=(_,x); generic_instances=[];_} -> (* empty compound type -> lookup what it binds to *) sprintf "%s" x - | CompoundType {name=(_,x); generic_instances;_} -> sprintf "%s<%s>" x (String.concat ", " (List.map (fun t -> string_of_sailtype (Some t)) generic_instances)) + | CompoundType {name; generic_instances=[];_} -> (* empty compound type -> lookup what it binds to *) sprintf "%s" name.value + | CompoundType {name; generic_instances;_} -> sprintf "%s<%s>" name.value (String.concat ", " (List.map (fun t -> string_of_sailtype (Some t)) generic_instances)) | Box(t) -> sprintf "ref<%s>" (string_of_sailtype (Some t)) | RefType (t,b) -> if b then sprintf "&mut %s" (string_of_sailtype (Some t)) @@ -136,7 +146,11 @@ type enum_defn = -type interface = {p_params: param list ; p_shared_vars: ((loc * (string * sailtype)) list * (loc * (string * sailtype)) list)} +type interface = +{ + p_params: param list ; + p_shared_vars: (string * sailtype) locatable list * (string * sailtype) locatable list +} type 'a process_defn = { @@ -148,7 +162,7 @@ type 'a process_defn = } type 'e proc_init = { - mloc : l_str option; + mloc : l_str option; id : string; proc : string; params : 'e list; @@ -189,7 +203,7 @@ type enum_proto = type struct_proto = { generics : string list; - fields : (loc * sailtype * int) dict + fields : (sailtype * int) locatable dict } type method_proto = @@ -202,8 +216,8 @@ type method_proto = type process_proto = { - read : (loc * (string * sailtype)) list; - write : (loc * (string * sailtype)) list; + read : (string * sailtype) locatable list; + write :(string * sailtype) locatable list; params : param list; generics : string list; } @@ -226,7 +240,7 @@ let defn_to_proto (type proto) (decl: proto decl) : proto = match decl with and generics = d.p_generics and params = d.p_interface.p_params in {read;write;generics;params} -| Struct d -> {generics=d.s_generics;fields=List.mapi (fun i ((l,n),t) -> n,(l,t,i)) d.s_fields} +| Struct d -> {generics=d.s_generics;fields=List.mapi (fun i (n,t) -> n.value,mk_locatable n.loc (t,i)) d.s_fields} | Enum d -> {generics=d.e_generics;injections=d.e_injections} type import = @@ -237,7 +251,6 @@ type import = proc_order: int; } - module ImportCmp = struct type t = import let compare i1 i2 = String.compare i1.mname i2.mname end module ImportSet = Set.Make(ImportCmp) diff --git a/src/parsing/astParser.ml b/src/parsing/astParser.ml index 93e94c9..6bd08bd 100644 --- a/src/parsing/astParser.ml +++ b/src/parsing/astParser.ml @@ -24,7 +24,7 @@ open Common open TypesCommon (* expressions are control free *) -type expression = loc * expression_ and expression_ = +type expression = expression_ locatable and expression_ = Variable of string | Deref of expression | StructRead of expression * l_str @@ -34,7 +34,7 @@ type expression = loc * expression_ and expression_ = | BinOp of binOp * expression * expression | Ref of bool * expression | ArrayStatic of expression list - | StructAlloc of l_str option * l_str * (loc * expression) dict + | StructAlloc of l_str option * l_str * expression locatable dict | EnumAlloc of l_str * expression list | MethodCall of l_str option * l_str * expression list @@ -43,10 +43,10 @@ type pattern = | PVar of string | PCons of string * pattern list -type statement = loc * statement_ and statement_ = +type statement = statement_ locatable and statement_ = | DeclVar of bool * string * sailtype option * expression option | Skip - | Assign of expression * expression + | Assign of {path:expression; value:expression} | Seq of statement * statement | If of expression * statement * statement option | While of expression * statement @@ -61,7 +61,7 @@ type statement = loc * statement_ and statement_ = type pgroup_ty = Sequence | Parallel -type ('s,'e) p_statement = loc * ('s,'e) p_statement_ and ('s,'e) p_statement_ = +type ('s,'e) p_statement = ('s,'e) p_statement_ locatable and ('s,'e) p_statement_ = | Run of l_str * 'e option | Statement of 's * 'e option | PGroup of {p_ty : pgroup_ty ; cond : 'e option ; children : ('s,'e) p_statement list} @@ -69,7 +69,7 @@ type ('s,'e) p_statement = loc * ('s,'e) p_statement_ and ('s,'e) p_statement_ = type ('s,'e) process_body = { locals : (l_str * sailtype) list; init : 's; - proc_init : (loc * 'e proc_init) list; + proc_init : ('e proc_init locatable) list; loop : ('s,'e) p_statement; } @@ -100,7 +100,7 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement in (env,m,p) | Struct d -> - let s_fields = List.sort_uniq (fun ((_,s1),_) ((_,s2),_) -> String.compare s1 s2) d.s_fields in + let s_fields = List.sort_uniq (fun (s1,_) (s2,_) -> String.compare s1.value s2.value) d.s_fields in E.throw_if Logging.(make_msg d.s_pos "duplicate fields" ) (List.(length s_fields <> length d.s_fields)) >>= fun () -> let+ env = DeclEnv.add_decl d.s_name (d.s_pos, defn_to_proto (Struct {d with s_fields})) Struct e |> rethrow d.s_pos @@ -116,7 +116,7 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement ListM.fold_left (fun (e,f) d -> let* () = E.throw_if Logging.(make_msg d.m_proto.pos "calling a method 'main' is not allowed") (d.m_proto.name = "main") in let true_name = (match d.m_body with Left (sname,_) -> sname | Right _ -> d.m_proto.name) in - let+ env = DeclEnv.add_decl d.m_proto.name ((d.m_proto.pos,true_name),defn_to_proto (Method d)) Method e + let+ env = DeclEnv.add_decl d.m_proto.name (mk_locatable d.m_proto.pos true_name,defn_to_proto (Method d)) Method e in (env,d::f) ) (e,m) d in (env,funs,p) diff --git a/src/parsing/parser.mly b/src/parsing/parser.mly index 71b8fb7..8df4866 100644 --- a/src/parsing/parser.mly +++ b/src/parsing/parser.mly @@ -24,6 +24,7 @@ open Common open TypesCommon open AstParser + module SailParser = struct end %} %token TYPE_BOOL TYPE_FLOAT TYPE_CHAR TYPE_STRING %token TYPE_INT @@ -131,9 +132,10 @@ let process_body := proc_init = midrule(P_PROC_INIT ; ":" ; ~ = list(located(proc_init)) ; <>)? ; loop = midrule(P_LOOP ; ":" ; loop?)? ; { - let init = Option.(join init |> value ~default:($loc,Skip)) in + + let init = Option.(join init |> value ~default:(mk_locatable $loc Skip)) in let proc_init = Option.value proc_init ~default:[] in - let loop = Option.join loop |> function Some x -> x | None -> $loc,(Statement (($loc,Skip),None)) in + let loop = Option.join loop |> function Some x -> x | None -> mk_locatable $loc (Statement (mk_locatable $loc Skip,None)) in {locals;init;proc_init;loop} } @@ -167,7 +169,7 @@ let generic := loption(delimited("<", separated_list(",", UID), ">")) let returnType := preceded(":", sailtype)? -let mutable_var(X) := (loc,id) = located(ID) ; ":" ; mut = mut ; ty =X ; { {id;mut;loc;ty} } +let mutable_var(X) := id = located(ID) ; ":" ; mut = mut ; ty =X ; { {id=id.value;mut;loc=id.loc;ty} } let separated_nonempty_list_opt(separator, X) := | x = X ; separator? ; { [ x ] } @@ -200,8 +202,10 @@ let expression := | "*" ; ~ = expression ; %prec UNARY | e1 = expression ; op =binOp ; e2 =expression ; { BinOp(op,e1,e2) } | ~ = delimited ("[", separated_list(",", expression), "]") ; - | ~ = ioption(module_loc) ; ~ =located(ID) ; ~ = midrule(l = brace_del_sep_list(",", id_colon(expression)); - {List.fold_left (fun l ((ly,y),z) -> (y,(ly,z))::l) [] l}) ; + | ~ = ioption(module_loc) ; ~ =located(ID) ; + ~ = midrule(l = brace_del_sep_list(",", id_colon(expression)); + { List.fold_left (fun accu (f,e) -> (f.value,mk_locatable f.loc e)::accu) [] l } + ) ; | ~ = located(UID) ; ~ = loption(parenthesized (separated_list(",", expression))) ; | ~ = ioption(module_loc) ; ~ = located(ID) ; ~ = parenthesized(separated_list (",", expression)) ; ) @@ -244,15 +248,15 @@ let iterable_or_range := | rl = INT ; "," ; rr = INT ; { let rl = Z.to_int rl in let rr = Z.to_int rr in - ArrayStatic (List.init (rr - rl) (fun i -> dummy_pos,Literal (LInt {l=Z.of_int (i + rl); size=32}))) + ArrayStatic (List.init (rr - rl) (fun i -> mk_locatable dummy_pos (Literal (LInt {l=Z.of_int (i + rl); size=32})) ) ) } -| e = expression ; {snd e} +| e = expression ; {e.value} let single_statement := | located ( | vardecl - | l = expression ; "=" ; e = expression ; + | path = expression ; "=" ; value = expression ; {Assign {path;value} } | CASE ; ~ = parenthesized(expression) ; ~ = brace_del_sep_list(",", case) ; | ~ = ioption(module_loc) ; ~ = located(ID) ; ~ = parenthesized(separated_list(",", expression)) ; | RETURN ; ~ = expression? ; @@ -265,13 +269,15 @@ let vardecl := VAR ; ~ = mut ; ~ = ID ; ~ = preceded(":", sailtype)? ; ~ = prece let brace_del_sep_list(sep,x) := delimited("{", separated_nonempty_list(sep, x), "}") -let located(x) == ~ = x ; { ($loc,x) } +let located(x) == ~ = x ; { mk_locatable $loc x } let case := separated_pair(pattern, ":", statement) let mut := boption(MUT) -let module_loc := ~ = located(ID); "::" ; <> | x = located(SELF) ; "::" ; { (fst x),Constants.sail_module_self} +let module_loc := + | ~ = located(ID); "::" ; <> + | located(SELF) ; "::" ; { mk_locatable dummy_pos Constants.sail_module_self } let parenthesized(e) == delimited("(",e,")") diff --git a/src/passes/ir/sailHir/astHir.ml b/src/passes/ir/sailHir/astHir.ml deleted file mode 100644 index 8dfd46c..0000000 --- a/src/passes/ir/sailHir/astHir.ml +++ /dev/null @@ -1,95 +0,0 @@ -(**************************************************************************) -(* *) -(* SAIL *) -(* *) -(* Frédéric Dabrowski, LMV, Orléans University *) -(* *) -(* Copyright (C) 2022 Frédéric Dabrowski *) -(* *) -(* This program is free software: you can redistribute it and/or modify *) -(* it under the terms of the GNU General Public License as published by *) -(* the Free Software Foundation, either version 3 of the License, or *) -(* (at your option) any later version. *) -(* *) -(* This program is distributed in the hope that it will be useful, *) -(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) -(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) -(* GNU General Public License for more details. *) -(* *) -(* You should have received a copy of the GNU General Public License *) -(* along with this program. If not, see . *) -(**************************************************************************) - -open Common.TypesCommon - -type ('info,'import) expression = {info: 'info ; exp: ('info,'import) _expression} and ('info,'import) _expression = - | Variable of string - | Deref of ('info,'import) expression - | StructRead of 'import * ('info,'import) expression * l_str - | ArrayRead of ('info,'import) expression * ('info,'import) expression - | Literal of literal - | UnOp of unOp * ('info,'import) expression - | BinOp of binOp * ('info,'import) expression * ('info,'import) expression - | Ref of bool * ('info,'import) expression - | ArrayStatic of ('info,'import) expression list - | StructAlloc of 'import * l_str * (loc * ('info,'import) expression) dict - | EnumAlloc of l_str * ('info,'import) expression list - | MethodCall of l_str * 'import * ('info,'import) expression list - - - -type ('info,'import,'exp) statement = {info: 'info; stmt: ('info,'import,'exp) _statement} and ('info,'import,'exp) _statement = - | DeclVar of bool * string * sailtype option * 'exp option - | Skip - | Assign of 'exp * 'exp - | Seq of ('info,'import,'exp) statement * ('info,'import,'exp) statement - | If of 'exp * ('info,'import,'exp) statement * ('info,'import,'exp) statement option - | Loop of ('info,'import,'exp) statement - | Break - | Case of 'exp * (string * string list * ('info,'import,'exp) statement) list - | Invoke of {ret_var:string option; import: 'import; id: l_str; args:'exp list} - | Return of 'exp option - (* - | DeclSignal of string - | Emit of string - | Await of string - | When of string * ('info,'import,'exp) statement - | Watching of string * ('info,'import,'exp) statement - | Par of ('info,'import,'exp) statement * ('info,'import,'exp) statement - *) - | Block of ('info,'import,'exp) statement - -let buildExp info (exp: (_,_) _expression) : (_,_) expression = {info;exp} -let buildStmt info stmt : (_,_,_) statement = {info;stmt} - - -module Syntax = struct - let skip = buildStmt dummy_pos Skip - - let (=) = fun l r -> buildStmt dummy_pos (Assign (l,r)) - - let var (loc,id,ty) = buildStmt loc (DeclVar (true,id,Some ty,None)) - - let _true = buildExp dummy_pos (Literal (LBool true)) - let _false = buildExp dummy_pos (Literal (LBool false)) - - - let (+) = fun l r -> buildExp dummy_pos (BinOp(Plus,l,r)) - let (%) = fun l r -> buildExp dummy_pos (BinOp(Rem,l,r)) - let (==) = fun l r -> buildExp dummy_pos (BinOp(Eq, l,r)) - - let (&&) = fun s1 s2 -> buildStmt dummy_pos (Seq (s1,s2)) - - - let (!@) = fun id -> buildExp dummy_pos (Variable id) - - let (!) = fun n -> buildExp dummy_pos (Literal (LInt {l=Z.of_int n; size=32})) - - let (!!) = fun b -> buildStmt dummy_pos (Block b) - - let _if cond _then _else = - let _else = match _else.stmt with Skip -> None | stmt -> Some {_else with stmt} in - buildStmt dummy_pos (If (cond,_then,_else)) - - -end \ No newline at end of file diff --git a/src/passes/ir/sailHir/hir.ml b/src/passes/ir/sailHir/hir.ml index 53e384a..63ad0db 100644 --- a/src/passes/ir/sailHir/hir.ml +++ b/src/passes/ir/sailHir/hir.ml @@ -2,11 +2,9 @@ open SailParser open Common open TypesCommon open Monad -open AstHir +open HirAst open HirUtils open M -type expression = HirUtils.expression -type statement = HirUtils.statement module Pass = Pass.MakeFunctionPass (V)( @@ -31,16 +29,16 @@ struct let open MonadSyntax(M.ECS) in let open MonadOperator(M.ECS) in - let rec aux (info,s : m_in) : m_out M.ECS.t = + let rec aux (stmt:m_in) : m_out M.ECS.t = - let buildSeq s1 s2 = {info; stmt = Seq (s1, s2)} in - let buildStmt stmt = {info;stmt} in + let buildSeq = buildSeq stmt.loc in + let buildStmt = buildStmt stmt.loc in let buildSeqStmt s1 s2 = buildSeq s1 @@ buildStmt s2 in - match s with + match stmt.value with | DeclVar (mut, id, t, e ) -> - M.ECS.set_var info id >>= fun () -> - let* t = match t with + M.ECS.set_var stmt.loc id >>= fun () -> + let* ty = match t with | None -> return None | Some t -> let* (ve,d) = M.ECS.get in @@ -50,69 +48,78 @@ struct in begin match e with | Some e -> let+ (e, s) = lower_expression e in - buildSeqStmt s (DeclVar (mut,id, t, Some e)) - | None -> return {info;stmt=DeclVar (mut,id, t, None)} + buildSeqStmt s (DeclVar {mut;id;ty;value=Some e}) + | None -> return (buildStmt (DeclVar {mut;id;ty;value=None})) end - | Skip -> return {info;stmt=Skip} - | Assign(e1, e2) -> - let* e1,s1 = lower_expression e1 in - let+ e2,s2 = lower_expression e2 in - buildSeq s1 @@ buildSeqStmt s2 @@ Assign (e1, e2) - - | Seq (c1, c2) -> let+ c1 = aux c1 and* c2 = aux c2 in {info;stmt=Seq (c1, c2)} - | If (e, c1, Some c2) -> - let+ e,s = lower_expression e and* c1 = aux c1 and* c2 = aux c2 in - buildSeqStmt s (If (e, c1, Some c2)) - - | If ( e, c1, None) -> - let+ (e, s) = lower_expression e and* c1 = aux c1 in - buildSeqStmt s (If (e, c1, None)) + + | Skip -> return (buildStmt Skip) + + | Assign a -> + let+ path,s1 = lower_expression a.path + and* value,s2 = lower_expression a.value in + buildSeq s1 @@ buildSeqStmt s2 @@ Assign {path; value} + + | Seq (c1, c2) -> + buildSeq <$> aux c1 <*> aux c2 + + | If (e, then_, else_) -> + let+ cond,s = lower_expression e + and* then_ = aux then_ + and* else_ = match else_ with None -> return None | Some else_ -> aux else_ >>| Option.some in + buildSeqStmt s (If {cond;then_;else_}) | While (e, c) -> - let+ e,s = lower_expression e and* c = aux c in - let c = buildStmt (If (e,c,Some (buildStmt Break))) in + let+ cond,s = lower_expression e + and* then_ = aux c in + let c = buildStmt (If {cond;then_;else_=Some (buildStmt Break)}) in buildSeqStmt s (Loop c) | Loop c -> let+ c = aux c in buildStmt (Loop c) - | For {var;iterable;body} -> + | For {var;iterable;body} -> (* fixme : temporary *) begin - match iterable with - | _,ArrayStatic el -> + match iterable.value with + | ArrayStatic el -> let open AstParser in let arr_id = "_for_arr_" ^ var in let i_id = "_for_i_" ^ var in let arr_length = List.length el in - - let tab_decl = info,DeclVar (false, arr_id, Some (dummy_pos,ArrayType ((dummy_pos,Int 32),arr_length)), Some iterable) in - let var_decl = info,DeclVar (true, var, Some (dummy_pos,Int 32), None) in - let i_decl = info,DeclVar (true, i_id, Some (dummy_pos,Int 32), Some (info,(Literal (LInt {l=Z.zero;size=32})))) in - - let tab = info,Variable arr_id in - let var = info,Variable var in - let i = info,Variable i_id in + let loc x = mk_locatable iterable.loc x in + + let tab_decl = loc @@ DeclVar (false, arr_id, Some (loc @@ ArrayType (loc (Int 32),arr_length)), Some iterable) in + let var_decl = loc @@ DeclVar (true, var, Some (loc @@ Int 32), None) in + let i_decl = loc @@ DeclVar (true, i_id, Some (loc @@ Int 32), Some (mk_locatable dummy_pos @@ (Literal (LInt {l=Z.zero;size=32})))) in + + let tab = loc (Variable arr_id) in + let var = loc (Variable var) in + let i =loc (Variable i_id) in - let cond = info,BinOp (Lt, i, (info,Literal (LInt {l=Z.of_int arr_length;size=32}))) in - let incr = info,Assign (i,(info,BinOp (Plus, i, (info, Literal (LInt {l=Z.one;size=32}))))) in - let init = info,Seq ((info,Seq (tab_decl,var_decl)), i_decl) in - let vari = info, Assign (var,(info,ArrayRead(tab,i))) in - - let body = info,Seq((info,Seq(vari,body)),incr) in - let _while = info,While (cond,body) in - let _for = info,Seq(init,_while) in + let cond = loc @@ BinOp (Lt, i, (mk_locatable dummy_pos @@ Literal (LInt {l=Z.of_int arr_length;size=32}))) in + let incr = loc @@ Assign {path=i;value= loc @@ BinOp (Plus, i, (loc @@ Literal (LInt {l=Z.one;size=32})))} in + let init = loc @@ Seq (loc @@ Seq (tab_decl,var_decl), i_decl) in + let vari = loc @@ Assign {path=var;value=loc @@ ArrayRead(tab,i)} in + + let body = loc @@ Seq(loc @@ Seq(vari,body),incr) in + let _while = loc @@ While (cond,body) in + let _for = loc @@ Seq(init,_while) in aux _for - | loc,_ -> M.ECS.throw Logging.(make_msg loc "for loop only allows a static array expression at the moment") + | _ -> M.ECS.throw Logging.(make_msg iterable.loc "for loop only allows a static array expression at the moment") end - | Break () -> return {info; stmt=Break} + | Break () -> return (buildStmt Break) + (* | Case(loc, e, cases) -> Case (loc, e, List.map (fun (p,c) -> (p, aux c)) cases) *) - | Case (e, _cases) -> let+ e,s = lower_expression e in - buildSeqStmt s (Case (e, [])) + | Case (e, cases) -> + let open MonadFunctions(M.ECS) in + let+ switch,s = lower_expression e + and* _cases = ListM.map (pairMap2 aux) cases in + buildSeqStmt s (Case {switch;cases=[]}) | Invoke (mod_loc, id, args) -> - let+ args,s = ListM.map lower_expression args in - buildSeqStmt s (Invoke {ret_var=None;import=mod_loc;id;args}) + let+ args,s1 = ListM.map lower_expression args in + let s2 = Invoke (mk_importable mod_loc {ret_var=None;id;args}) in + buildSeqStmt s1 s2 | Return e -> begin match e with @@ -129,7 +136,7 @@ struct in - M.E.(bind (M.run aux env body) (fun (r,venv) -> pure (r,venv,tenv))) |> M.E.recover ({info=dummy_pos;stmt=Skip},snd env,tenv) + M.E.(bind (M.run aux env body) (fun (r,venv) -> pure (r,venv,tenv))) |> M.E.recover (MonoidSeq.mempty,snd env,tenv) let lower_process (p:p_in process_defn) ((env,decls),tenv: HIREnv.t * _ ) sm : (p_out * HIREnv.D.t * _) M.E.t = @@ -138,10 +145,10 @@ struct let open UseMonad(E) in let params = List.to_seq p.p_interface.p_params |> Seq.map (fun (p:param) -> p.id,p.loc) |> FieldMap.of_seq in - let locals = List.to_seq p.p_body.locals |> Seq.map (fun ((l,id),_) -> id,l ) |> FieldMap.of_seq in + let locals = List.to_seq p.p_body.locals |> Seq.map (fun (id,_) -> id.value,id.loc ) |> FieldMap.of_seq in let read,write = p.p_interface.p_shared_vars in - let read = List.to_seq read |> Seq.map (fun (l,(id,_)) -> id,l) |> FieldMap.of_seq in - let write = List.to_seq write |> Seq.map (fun (l,(id,_)) -> id,l) |> FieldMap.of_seq in + let read = List.to_seq read |> Seq.map (fun r -> fst r.value,r.loc) |> FieldMap.of_seq in + let write = List.to_seq write |> Seq.map (fun w -> fst w.value,w.loc) |> FieldMap.of_seq in let union_no_dupl = FieldMap.union (fun _k _loca _locb -> None) in @@ -153,28 +160,33 @@ struct E.throw_if Logging.(make_msg dummy_pos @@ Fmt.str "process '%s' : name conflict between params,local decls or shared variables" p.p_name) has_name_conflict >>= fun () -> - let add_locals v e = ListM.fold_left (fun e ((l,id),_) -> HIREnv.declare_var id (l,()) e) (e,decls) v >>| fst in - let add_rw r e = ListM.fold_left (fun e (l,(id,_)) -> HIREnv.declare_var id (l,()) e) (e,decls) r >>| fst in + let add_locals v e = ListM.fold_left (fun e (id,_) -> HIREnv.declare_var id.value (id.loc,()) e) (e,decls) v >>| fst in + let add_rw r e = ListM.fold_left (fun e rw -> HIREnv.declare_var (fst rw.value) (rw.loc,()) e) (e,decls) r >>| fst in let* env = add_locals p.p_body.locals env >>= add_rw (fst p.p_interface.p_shared_vars) >>= add_rw (snd p.p_interface.p_shared_vars) in - let* init,decls,tenv = lower_method (p.p_body.init,()) ((env,decls),tenv) sm |> E.recover ({info=dummy_pos;stmt=Skip},decls,tenv) in + let* init,decls,tenv = lower_method (p.p_body.init,()) ((env,decls),tenv) sm |> E.recover (MonoidSeq.mempty,decls,tenv) in - let* (proc_init,_),decls = F.ListM.map (fun ((l,p): loc * _ proc_init) -> + let* (proc_init,_),decls = F.ListM.map (fun (p : _ proc_init locatable) -> let open UseMonad(M) in - let+ params = F.ListM.map lower_expression p.params in l,{p with params} + let+ params = F.ListM.map lower_expression p.value.params in + mk_locatable p.loc {p.value with params} ) p.p_body.proc_init (env,decls) |> M.ECS.run in let+ loop = let process_cond = function None -> return None | Some c -> let+ (cond,_),_ = lower_expression c (env,decls) |> M.ECS.run in Some cond in - let rec aux (l,s) = match s with - | Statement (s,cond) -> let+ s,_,_ = lower_method (s,()) ((env,decls),tenv) sm and* cond = process_cond cond in l,Statement (s,cond) - | Run (proc,cond) -> let+ cond = process_cond cond in l,Run (proc,cond) + let rec aux (stmt :(m_in, AstParser.expression) SailParser.AstParser.p_statement) = match stmt.value with + | Statement (s,cond) -> + let+ s,_,_ = lower_method (s,()) ((env,decls),tenv) sm + and* cond = process_cond cond in + mk_locatable stmt.loc (Statement (s,cond)) + + | Run (proc,cond) -> let+ cond = process_cond cond in mk_locatable stmt.loc @@ Run (proc,cond) | PGroup g -> let* cond = process_cond g.cond in let+ children = ListM.map aux g.children in - l,PGroup {g with cond ; children} + mk_locatable stmt.loc @@ PGroup {g with cond ; children} in aux p.p_body.loop in {p.p_body with init ; proc_init; loop},decls,tenv diff --git a/src/passes/ir/sailHir/hirAst.ml b/src/passes/ir/sailHir/hirAst.ml new file mode 100644 index 0000000..47379de --- /dev/null +++ b/src/passes/ir/sailHir/hirAst.ml @@ -0,0 +1,123 @@ +(**************************************************************************) +(* *) +(* SAIL *) +(* *) +(* Frédéric Dabrowski, LMV, Orléans University *) +(* *) +(* Copyright (C) 2022 Frédéric Dabrowski *) +(* *) +(* This program is free software: you can redistribute it and/or modify *) +(* it under the terms of the GNU General Public License as published by *) +(* the Free Software Foundation, either version 3 of the License, or *) +(* (at your option) any later version. *) +(* *) +(* This program is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU General Public License for more details. *) +(* *) +(* You should have received a copy of the GNU General Public License *) +(* along with this program. If not, see . *) +(**************************************************************************) + +open Common.TypesCommon + +type 'e struct_read = {field: l_str; strct: 'e} +type 'e struct_alloc = {name:l_str; fields: 'e locatable dict} +type 'e fcall = {ret_var : string option; id: l_str; args:'e list} + +(* based on https://icfp23.sigplan.org/details/ocaml-2023-papers/4/Modern-DSL-compiler-architecture-in-OCaml-our-experience-with-Catala *) + +type ('tag,'kind) generic_exp = ('tag, ('tag,'kind,'kind) relaxed_exp ) tagged_node +and ('tag,'deep_kind,'shallow_kind) relaxed_exp = + | Variable : string -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | Deref : ('tag,'deep_kind) generic_exp -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | Ref : bool * ('tag,'deep_kind) generic_exp -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | Literal : literal -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | UnOp : unOp * ('tag,'deep_kind) generic_exp -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | BinOp : {left: ('tag,'deep_kind) generic_exp; right:('tag,'deep_kind) generic_exp; op: binOp} -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | ArrayRead : {array: ('tag,'deep_kind) generic_exp ; idx: ('tag,'deep_kind) generic_exp} -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | ArrayStatic : ('tag,'deep_kind) generic_exp list -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + | MethodCall : (l_str option, ('tag,'deep_kind) generic_exp fcall) importable -> ('tag, 'deep_kind,[> `MethodCall]) relaxed_exp + + | StructRead : (l_str option, ('tag,'deep_kind) generic_exp struct_read) importable -> ('tag,'deep_kind, [> `UnresolvedImports]) relaxed_exp + + | StructRead2 : (l_str, ('tag,'deep_kind) generic_exp struct_read) importable -> ('tag, 'deep_kind,[> `ResolvedImports]) relaxed_exp + + | StructAlloc : (l_str option, ('tag,'deep_kind) generic_exp struct_alloc) importable -> ('tag, 'deep_kind,[> `UnresolvedImports]) relaxed_exp + + | StructAlloc2 : (l_str, ('tag,'deep_kind) generic_exp struct_alloc) importable -> ('tag, 'deep_kind, [> `ResolvedImports]) relaxed_exp + + | EnumAlloc : l_str * ('tag,'deep_kind) generic_exp list -> ('tag,'deep_kind,'shallow_kind) relaxed_exp + + +type ('stag,'etag, 'kind) generic_stmt = ('stag, ('stag,'etag,'kind,'kind) relaxed_stmt ) tagged_node and ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt = +| DeclVar : {mut:bool; id:string; ty: sailtype option; value: ('etag,'deep_kind) generic_exp option} -> ('stag,'etag,'deep_kind, [> `UntypedDeclVar]) relaxed_stmt + +| DeclVar2 : {mut:bool; id:string; ty: sailtype} -> ('stag,'etag,'deep_kind,[> `TypedDeclVar] ) relaxed_stmt + +| Invoke : (l_str option, ('etag,'deep_kind) generic_exp fcall) importable -> ('stag,'etag,'deep_kind,[> `UnresolvedImports]) relaxed_stmt + +| Invoke2 : (l_str, ('etag,'deep_kind) generic_exp fcall) importable -> ('stag,'etag,'deep_kind,[> `ResolvedImports]) relaxed_stmt + +| Skip : ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt +| Assign : {path: ('etag,'deep_kind) generic_exp ; value: ('etag,'deep_kind) generic_exp} -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| Seq : ('stag,'etag,'deep_kind) generic_stmt * ('stag,'etag,'deep_kind) generic_stmt -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| If : + { + cond: ('etag,'deep_kind) generic_exp; + then_:('stag,'etag,'deep_kind) generic_stmt ; + else_: ('stag,'etag,'deep_kind) generic_stmt option + } -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| Loop : ('stag,'etag,'deep_kind) generic_stmt -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| Break : ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| Case : {switch : ('etag,'deep_kind) generic_exp ; cases : (string * string list * ('stag,'etag,'deep_kind) generic_stmt) list} -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| Return : ('etag,'deep_kind) generic_exp option -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + +| Block : ('stag,'etag,'deep_kind) generic_stmt -> ('stag,'etag,'deep_kind,'shallow_kind) relaxed_stmt + + +let buildExp tag node : _ generic_exp = {tag;node} +let buildStmt tag node : _ generic_stmt = {tag;node} +let buildSeq tag s1 s2 = buildStmt tag (Seq (s1, s2)) +let buildSeqStmt tag s1 s2 = buildSeq tag s1 @@ buildStmt tag s2 + + +module Syntax = struct + + let skip : unit -> (loc,'a,'b) generic_stmt = fun () -> buildStmt dummy_pos Skip + + let (=) path value = buildStmt dummy_pos (Assign {path;value}) + + let var loc id ty value = buildStmt loc (DeclVar {mut=true;id;ty=Some ty;value}) + + let true_ : unit -> (_,_) generic_exp = fun () -> buildExp dummy_pos (Literal (LBool true)) + let false_ : unit -> (_,_) generic_exp = fun () -> buildExp dummy_pos (Literal (LBool false)) + + let (+) = fun left right -> buildExp dummy_pos (BinOp {op=Plus;left;right}) + let (%) = fun left right -> buildExp dummy_pos (BinOp {op=Rem;left;right}) + let (==) = fun left right -> buildExp dummy_pos (BinOp {op=Eq;left;right}) + + let (&&) = fun s1 s2 -> buildStmt dummy_pos (Seq (s1,s2)) + let (!@) = fun id -> buildExp dummy_pos (Variable id) + let (!) = fun n -> buildExp dummy_pos (Literal (LInt {l=Z.of_int n; size=32})) + let (!!) = fun b -> buildStmt dummy_pos (Block b) + + let if_ cond then_ else_ = + let else_ = match else_.node with Skip -> None | node -> Some {else_ with node} in + buildStmt dummy_pos (If {cond;then_;else_}) +end \ No newline at end of file diff --git a/src/passes/ir/sailHir/hirMonad.ml b/src/passes/ir/sailHir/hirMonad.ml index 53ac723..714f3b3 100644 --- a/src/passes/ir/sailHir/hirMonad.ml +++ b/src/passes/ir/sailHir/hirMonad.ml @@ -5,7 +5,10 @@ module Make(MonoidSeq : Monad.Monoid) = struct type t = unit let string_of_var _ = "" let param_to_var _ = () - end + end + + module MonoidSeq = MonoidSeq + module HIREnv = SailModule.SailEnv(V) diff --git a/src/passes/ir/sailHir/hirUtils.ml b/src/passes/ir/sailHir/hirUtils.ml index 4ce2e08..a1589c9 100644 --- a/src/passes/ir/sailHir/hirUtils.ml +++ b/src/passes/ir/sailHir/hirUtils.ml @@ -2,102 +2,111 @@ open Common open TypesCommon open Monad open SailParser +module Ast = HirAst -type expression = (loc,l_str option) AstHir.expression -type statement = (loc,l_str option,expression) AstHir.statement +type expression = (loc,[`UnresolvedImports | `MethodCall| `UntypedDeclVar]) Ast.generic_exp +type statement = (loc,loc,[`UnresolvedImports | `MethodCall | `UntypedDeclVar]) Ast.generic_stmt module M = HirMonad.Make( struct type t = statement - let mempty : t = {info=dummy_pos; stmt=Skip} - let mconcat : t -> t -> t = fun x y -> {info=dummy_pos; stmt=Seq (x,y)} + let mempty = Ast.buildStmt dummy_pos Skip + let mconcat = fun x y -> Ast.buildStmt dummy_pos (Seq (x,y)) end ) open M module D = SailModule.DeclEnv -let lower_expression (e : AstParser.expression) : expression M.t = +let lower_expression (exp : AstParser.expression) : expression M.t = let open UseMonad(M) in - let rec aux (info,e : AstParser.expression) : expression M.t = - let open AstHir in - match e with - | Variable id -> - let* v = (M.ECS.find_var id |> M.lift) in - M.throw_if_none Logging.(make_msg info @@ Fmt.str "undeclared variable '%s'" id) v >>| fun _ -> - {info; exp=Variable id} + let rec aux (exp : AstParser.expression) : expression M.t = + let open Ast in + let buildExp = buildExp exp.loc in + match exp.value with + | Variable id -> + let* v = (M.ECS.find_var id |> M.lift) in + M.throw_if_none Logging.(make_msg exp.loc @@ Fmt.str "undeclared variable '%s'" id) v >>| fun _ -> + buildExp (Variable id) - | Deref e -> - let+ e = aux e in {info;exp=Deref e} + | Deref e -> + let+ e = aux e in buildExp (Deref e) - | StructRead (e, id) -> - let+ e = aux e in {info; exp=StructRead (None, e, id)} + | StructRead (e, field) -> + let+ strct = aux e in buildExp (StructRead {import=None;value={field;strct}}) - | ArrayRead (e1, e2) -> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=ArrayRead(e1,e2)} - | Literal l -> return {info; exp=Literal l} + | ArrayRead (array, idx) -> + let+ array = aux array + and* idx = aux idx in + buildExp (ArrayRead {array;idx}) - | UnOp (op, e) -> - let+ e = aux e in {info;exp=UnOp (op, e)} + | Literal l -> return (buildExp (Literal l)) - | BinOp(op,e1,e2)-> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=BinOp (op, e1, e2)} + | UnOp (op, e) -> + let+ e = aux e in buildExp (UnOp (op, e)) - | Ref (b, e) -> - let+ e = aux e in {info;exp=Ref(b, e)} + | BinOp(op,left,right)-> + let+ left = aux left + and* right = aux right in + buildExp (BinOp {op;left;right}) - | ArrayStatic el -> - let+ el = ListM.map aux el in {info;exp=ArrayStatic el} + | Ref (b, e) -> + let+ e = aux e in + buildExp (Ref (b,e)) - | StructAlloc (origin,id, m) -> - let m' = List.sort_uniq (fun (id1,_) (id2,_) -> String.compare id1 id2) m in - let* () = M.throw_if Logging.(make_msg info "duplicate fields") List.(length m <> length m') in - let+ m' = ListM.map (aux |> pairMap2 |> pairMap2) m' in - {info; exp=StructAlloc (origin, id, m')} + | ArrayStatic el -> + let+ el = ListM.map aux el in + buildExp (ArrayStatic el) - | EnumAlloc (id, el) -> - let+ el = ListM.map aux el in {info;exp=EnumAlloc (id, el)} + | StructAlloc (import,name,fields) -> + let fields' = List.sort_uniq (fun (id1,_) (id2,_) -> String.compare id1 id2) fields in + let* () = M.throw_if Logging.(make_msg exp.loc "duplicate fields") List.(length fields <> length fields') in + let+ fields = ListM.map (pairMap2 (fun e -> let+ value = aux e.value in {e with value})) fields' in + buildExp (StructAlloc (mk_importable import {name;fields})) - - | MethodCall (import, id, el) -> let+ el = ListM.map aux el in {info ; exp=MethodCall(id, import, el)} - in aux e + | EnumAlloc (id, el) -> + let+ el = ListM.map aux el in + buildExp (EnumAlloc (id, el)) + + + | MethodCall (import, id, args) -> + let+ args = ListM.map aux args in + buildExp (MethodCall (mk_importable import {id;args;ret_var=None})) + + in aux exp open UseMonad(M.E) -let find_symbol_source ?(filt = [E (); S (); T ()] ) (loc,id: l_str) (import : l_str option) (env : D.t) : (l_str * D.decls) M.E.t = +let find_symbol_source ?(filt = [E (); S (); T ()] ) (id: l_str) (import : l_str option) (env : D.t) : (l_str * D.decls) M.E.t = match import with - | Some (iloc,name) -> - if name = Constants.sail_module_self || name = D.get_name env then + | Some iname -> + if iname.value = Constants.sail_module_self || iname.value = D.get_name env then let+ decl = - D.find_decl id (Self (Filter filt)) env - |> M.E.throw_if_none Logging.(make_msg loc @@ "no declaration named '" ^ id ^ "' in current module ") + D.find_decl id.value (Self (Filter filt)) env + |> M.E.throw_if_none Logging.(make_msg id.loc @@ "no declaration named '" ^ id.value ^ "' in current module ") in - (iloc,D.get_name env),decl + {iname with value=D.get_name env},decl else let+ t = - M.E.throw_if_none Logging.(make_msg iloc ~hint:(Some (None,Fmt.str "try importing the module with 'import %s'" name)) @@ "unknown module " ^ name) - (List.find_opt (fun {mname;_} -> mname = name) (D.get_imports env)) >>= fun _ -> - M.E.throw_if_none Logging.(make_msg loc @@ "declaration " ^ id ^ " not found in module " ^ name) - (D.find_decl id (Specific (name, Filter filt)) env) + M.E.throw_if_none Logging.(make_msg iname.loc ~hint:(Some (None,Fmt.str "try importing the module with 'import %s'" iname.value)) @@ "unknown module " ^ iname.value) + (List.find_opt (fun {mname;_} -> mname = iname.value) (D.get_imports env)) >>= fun _ -> + M.E.throw_if_none Logging.(make_msg id.loc @@ "declaration " ^ id.value ^ " not found in module " ^ iname.value) + (D.find_decl id.value (Specific (iname.value, Filter filt)) env) in - (iloc,name),t + iname,t | None -> (* find it ourselves *) begin - let decl = D.find_decl id (All (Filter filt)) env in + let decl = D.find_decl id.value (All (Filter filt)) env in match decl with | [i,m] -> (* Logs.debug (fun m -> m "'%s' is from %s" id i.mname); *) - return ((dummy_pos,i.mname),m) + return (mk_locatable dummy_pos i.mname,m) - | [] -> M.E.throw Logging.(make_msg loc @@ "unknown declaration " ^ id) + | [] -> M.E.throw Logging.(make_msg id.loc @@ "unknown declaration " ^ id.value) | _ as choice -> M.E.throw - @@ Logging.make_msg loc ~hint:(Some (None,"specify one with '::' annotation")) - @@ Fmt.str "multiple definitions for declaration %s : \n\t%s" id + @@ Logging.make_msg id.loc ~hint:(Some (None,"specify one with '::' annotation")) + @@ Fmt.str "multiple definitions for declaration %s : \n\t%s" id .value (List.map (fun (i,def) -> match def with T def -> Fmt.str "from %s : %s" i.mname (string_of_sailtype (def.ty)) | _ -> "") choice |> String.concat "\n\t") end @@ -106,10 +115,10 @@ let follow_type ty env : (sailtype * D.t) M.E.t = (* Logs.debug (fun m -> m "following type '%s'" (string_of_sailtype (Some ty))); *) - let rec aux (l_ty,ty') path : (sailtype * ty_defn list) M.E.t = + let rec aux ty path : (sailtype * ty_defn list) M.E.t = (* Logs.debug (fun m -> m "path: %s" (List.map (fun ({name;_}:ty_defn) -> name)path |> String.concat " ")); *) let+ (t,path : sailtype_ * ty_defn list) = - match ty' with + match ty.value with | ArrayType (t,n) -> let+ t,path = aux t path in ArrayType (t,n),path | Box t -> let+ t,path = aux t path in Box t,path @@ -118,11 +127,11 @@ let follow_type ty env : (sailtype * D.t) M.E.t = (* Logs.debug (fun m -> m "'%s' resolves to '%s'" (string_of_sailtype (Some ty)) (string_of_sailtype (Some ty'))); *) return (t,path) | CompoundType {origin;name=id;generic_instances;_} -> (* compound type, find location and definition *) - let* (l,origin),def = find_symbol_source id origin env in - let default = fun ty -> CompoundType {origin=Some (l,origin);name=id; generic_instances;decl_ty=Some ty} in + let* origin,def = find_symbol_source id origin env in + let default = fun ty -> CompoundType {origin=Some origin;name=id; generic_instances;decl_ty=Some ty} in begin match def with - | T def when origin=current -> + | T def when origin.value = current -> begin match def.ty with | Some ty -> ( @@ -133,7 +142,7 @@ let follow_type ty env : (sailtype * D.t) M.E.t = ) (List.find_opt (fun (d:ty_defn) -> d.name = def.name) path) - >>= fun () -> let+ ((_,t),p) = aux ty (def::path) in t,p + >>= fun () -> let+ (t,p) = aux ty (def::path) in t.value,p ) | None -> (* abstract type *) (* Logs.debug (fun m -> m "'%s' resolves to abstract type '%s' " (string_of_sailtype (Some ty)) def.name); *) @@ -141,7 +150,7 @@ let follow_type ty env : (sailtype * D.t) M.E.t = end | _ -> return (default @@ unit_decl_of_decl def,path) (* must point to an enum or struct, nothing to resolve *) end - in (l_ty,t),path + in {ty with value=t},path in let+ res,p = aux ty [] in (* p only contains type_def from the current module *) @@ -161,12 +170,12 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit M.E.t = (List.mem id checked) in let checked = id::checked in ListM.iter ( - fun (_,(l,t,_)) -> match snd t with + fun (_,{value=(ty,_);_}) -> match ty.value with - | CompoundType {name=_,name;origin=Some (_,origin); decl_ty = Some S ();_} -> + | CompoundType {name;origin=Some origin; decl_ty = Some S ();_} -> begin - match D.find_decl name (Specific (origin,(Filter [S ()]))) env with - | Some (S (_,d)) -> aux name l d checked + match D.find_decl name.value (Specific (origin.value,(Filter [S ()]))) env with + | Some (S (_,d)) -> aux name.value l d checked | _ -> failwith "invariant : all compound types must have a correct origin and type at this step" end | CompoundType {origin=None;decl_ty=None;_} -> M.E.throw Logging.(make_msg l "follow type not called") @@ -175,62 +184,87 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit M.E.t = in aux name l proto [] -let rename_var_exp (f: string -> string) (e: _ AstHir.expression) = - let open AstHir in - let rec aux (e : _ expression) = - let buildExp = buildExp e.info in - match e.exp with + +let rename_var_exp (f: string -> string) (e: expression) = + let open Ast in + let rec aux (e : expression) = + let buildExp = buildExp (e.tag : loc) in + match e.node with | Variable id -> buildExp @@ Variable (f id) + | Deref e -> let e = aux e in buildExp @@ Deref e - | StructRead (mod_loc,e, id) -> let e = aux e in buildExp @@ StructRead(mod_loc,e,id) - | ArrayRead (e1, e2) -> - let e1 = aux e1 in - let e2 = aux e2 in - buildExp @@ ArrayRead (e1,e2) + + | StructRead s -> let strct = aux s.value.strct in buildExp @@ StructRead {s with value={s.value with strct}} + + | ArrayRead a -> + let idx = aux a.idx in + let array = aux a.array in + buildExp @@ ArrayRead {idx;array} + | Literal _ as l -> buildExp l + | UnOp (op, e) -> let e = aux e in buildExp @@ UnOp (op,e) - | BinOp(op,e1,e2)-> - let e1 = aux e1 in - let e2 = aux e2 in - buildExp @@ BinOp(op,e1,e2) + + | BinOp bop -> + buildExp @@ BinOp {bop with left=aux bop.left; right=aux bop.right} + | Ref (b, e) -> let e = aux e in buildExp @@ Ref(b,e) + | ArrayStatic el -> let el = List.map aux el in buildExp @@ ArrayStatic el - | StructAlloc (origin,id, m) -> let m = List.map (fun (n,(l,e)) -> n,(l,aux e)) m in buildExp @@ StructAlloc (origin,id,m) + + | StructAlloc s -> let fields = List.map (fun (n,(e: _ locatable)) -> n,mk_locatable e.loc (aux e.value)) s.value.fields in + buildExp @@ StructAlloc {s with value={s.value with fields}} + | EnumAlloc (id, el) -> let el = List.map aux el in buildExp @@ EnumAlloc (id,el) - | MethodCall (mod_loc, id, el) -> let el = List.map aux el in buildExp @@ MethodCall (mod_loc,id,el) + + | MethodCall m -> let args = List.map aux m.value.args in buildExp @@ MethodCall {m with value={m.value with args}} in aux e let rename_var_stmt (f:string -> string) s = - let open AstHir in - let rec aux (s : _ statement) = - let buildStmt = buildStmt s.info in - match s.stmt with - | DeclVar (mut, id, opt_t,opt_exp) -> - let e = MonadOption.M.fmap (rename_var_exp f) opt_exp in - buildStmt @@ DeclVar (mut,f id,opt_t,e) - | Assign(e1, e2) -> - let e1 = rename_var_exp f e1 - and e2 = rename_var_exp f e2 in - buildStmt @@ Assign(e1, e2) + let open Ast in + let rec aux (s : statement) = + let buildStmt = buildStmt s.tag in + match s.node with + | DeclVar v -> + let value = MonadOption.M.fmap (rename_var_exp f) v.value in + let id = f v.id in + buildStmt @@ DeclVar {v with value; id} + + | Assign a -> + let path = rename_var_exp f a.path + and value = rename_var_exp f a.value in + buildStmt @@ Assign {path;value} + | Seq(c1, c2) -> let c1 = aux c1 in let c2 = aux c2 in buildStmt @@ Seq(c1, c2) - | If(cond_exp, then_s, else_s) -> - let cond_exp = rename_var_exp f cond_exp in - let then_s = aux then_s in - let else_s = MonadOption.M.fmap aux else_s in - buildStmt (If(cond_exp, then_s, else_s)) + + | If if_ -> + let cond = rename_var_exp f if_.cond in + let then_ = aux if_.then_ in + let else_ = MonadOption.M.fmap aux if_.else_ in + buildStmt (If {cond;then_;else_}) + | Loop c -> let c = aux c in buildStmt (Loop c) + | Break -> buildStmt Break - | Case(e, _cases) -> let e = rename_var_exp f e in buildStmt (Case (e, [])) + + | Case c -> + let switch = rename_var_exp f c.switch in + let cases = List.map (fun (s,s2,s3) -> s,s2,aux s3) c.cases in + buildStmt (Case {switch;cases}) + | Invoke i -> - let args = List.map (rename_var_exp f) i.args in - let ret_var = MonadOption.M.fmap f i.ret_var in - buildStmt @@ Invoke {i with ret_var;args} + let args = List.map (rename_var_exp f) i.value.args in + let ret_var = MonadOption.M.fmap f i.value.ret_var in + buildStmt @@ Invoke {i with value = {i.value with ret_var;args}} + | Return e -> let e = MonadOption.M.fmap (rename_var_exp f) e in buildStmt @@ Return e + | Block c -> let c = aux c in buildStmt (Block c) + | Skip -> buildStmt Skip in aux s @@ -275,11 +309,11 @@ let resolve_names (sm : ('a,'b) SailModule.methods_processes SailModule.t) = let* () = SEnv.iter ( fun (id,(l,{fields; generics})) -> let* fields = ListM.map ( - fun (name,(l,t,n)) -> + fun (name,({value=(t,i);_} as f)) -> let* env = ES.get_env in let* t,env = (follow_type t env) |> ES.S.lift in let+ () = ES.set_env env in - name,(l,t,n) + name,{f with value=t,i} ) fields in let proto = l,{fields;generics} in @@ -304,7 +338,7 @@ let resolve_names (sm : ('a,'b) SailModule.methods_processes SailModule.t) = ) m_proto.params in let m = {m with m_proto={m_proto with params; rtype}} in let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in - let+ () = ES.update_env (update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method)) + let+ () = ES.update_env (update_decl m_proto.name (mk_locatable m_proto.pos true_name, defn_to_proto (Method m)) (Self Method)) in m ) sm.body.methods in diff --git a/src/passes/ir/sailHir/pp_hir.ml b/src/passes/ir/sailHir/pp_hir.ml index 05b04f3..79c3d1d 100644 --- a/src/passes/ir/sailHir/pp_hir.ml +++ b/src/passes/ir/sailHir/pp_hir.ml @@ -1,57 +1,59 @@ open Common open PpCommon open Format -open AstHir +open HirAst open Hir - +open HirUtils +open TypesCommon let rec ppPrintExpression (pf : Format.formatter) (e : expression) : unit = let open Format in - match e.exp with + match e.node with | Variable s -> fprintf pf "%s" s | Deref e -> fprintf pf "*%a" ppPrintExpression e - | StructRead (_,e, (_,s)) -> fprintf pf "%a.%s" ppPrintExpression e s - | ArrayRead (e1, e2) -> fprintf pf "%a[%a]" ppPrintExpression e1 ppPrintExpression e2 + | StructRead st -> fprintf pf "%a.%s" ppPrintExpression st.value.strct st.value.field.value + | ArrayRead ar -> fprintf pf "%a[%a]" ppPrintExpression ar.array ppPrintExpression ar.idx | Literal (l) -> fprintf pf "%a" PpCommon.pp_literal l | UnOp (o, e) -> fprintf pf "%a %a" pp_unop o ppPrintExpression e - | BinOp ( o, e1, e2) -> fprintf pf "%a %a %a" ppPrintExpression e1 pp_binop o ppPrintExpression e2 + | BinOp bop -> fprintf pf "%a %a %a" ppPrintExpression bop.left pp_binop bop.op ppPrintExpression bop.right | Ref (true,e) -> fprintf pf "&mut %a" ppPrintExpression e | Ref (false,e) -> fprintf pf "&%a" ppPrintExpression e | ArrayStatic el -> fprintf pf "[%a]" (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - |StructAlloc (_,id, m) -> - let pp_field pf (x, (_,y)) = fprintf pf "%s:%a" x ppPrintExpression y in - fprintf pf "%s{%a}" (snd id) + | StructAlloc st -> + let pp_field pf (x, (y: _ locatable)) = fprintf pf "%s:%a" x ppPrintExpression y.value in + fprintf pf "%s{%a}" st.value.name.value (pp_print_list ~pp_sep:pp_comma pp_field) - m + st.value.fields | EnumAlloc (id,el) -> - fprintf pf "[%s(%a)]" (snd id) + fprintf pf "[%s(%a)]" id.value (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - | MethodCall ((_,id),mod_loc,el) -> + | MethodCall m -> fprintf pf "%a%s(%a)" - (pp_print_option (fun fmt (_,ml) -> fprintf fmt "%s::" ml)) mod_loc - id - (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el + (pp_print_option (fun fmt ml -> fprintf fmt "%s::" ml.value)) m.import + m.value.id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) m.value.args + +let rec ppPrintStatement (pf : Format.formatter) (s : statement) : unit = match s.node with +| DeclVar d -> fprintf pf "\nvar %s%a%a;" d.id + (pp_print_option (fun fmt -> fprintf fmt " : %a" pp_type)) d.ty + (pp_print_option (fun fmt -> fprintf fmt " = %a" ppPrintExpression)) d.value -let rec ppPrintStatement (pf : Format.formatter) (s : statement) : unit = match s.stmt with -| DeclVar (_mut, id, opt_t,opt_exp) -> fprintf pf "\nvar %s%a%a;" id - (pp_print_option (fun fmt -> fprintf fmt " : %a" pp_type)) opt_t - (pp_print_option (fun fmt -> fprintf fmt " = %a" ppPrintExpression)) opt_exp -| Assign(e1, e2) -> fprintf pf "\n%a = %a;" ppPrintExpression e1 ppPrintExpression e2 +| Assign a -> fprintf pf "\n%a = %a;" ppPrintExpression a.path ppPrintExpression a.value | Seq(c1, c2) -> fprintf pf "%a%a" ppPrintStatement c1 ppPrintStatement c2 -| If(cond_exp, then_s,else_s) -> fprintf pf "\nif (%a) {\n%a\n}\n%a" - ppPrintExpression cond_exp - ppPrintStatement then_s - (pp_print_option (fun pf -> fprintf pf "else {%a\n}" ppPrintStatement)) else_s +| If if_ -> fprintf pf "\nif (%a) {\n%a\n}\n%a" + ppPrintExpression if_.cond + ppPrintStatement if_.then_ + (pp_print_option (fun pf -> fprintf pf "else {%a\n}" ppPrintStatement)) if_.else_ | Loop c -> fprintf pf "\nloop {%a\n}" ppPrintStatement c | Break -> fprintf pf "break;" -| Case(_e, _cases) -> () +| Case _ -> () | Invoke i -> fprintf pf "\n%a%a%s(%a);" - (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) i.ret_var - (pp_print_option (fun fmt (_,ml) -> fprintf fmt "%s::" ml)) i.import - (snd i.id) - (pp_print_list ~pp_sep:pp_comma ppPrintExpression) i.args + (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) i.value.ret_var + (pp_print_option (fun fmt ml -> fprintf fmt "%s::" ml.value)) i.import + i.value.id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) i.value.args | Return e -> fprintf pf "\nreturn %a;" (pp_print_option ppPrintExpression) e | Block c -> fprintf pf "\n{\n@[ %a @]\n}" ppPrintStatement c | Skip -> () diff --git a/src/passes/ir/sailMir/mir.ml b/src/passes/ir/sailMir/mir.ml index 23f0387..fd2aa1f 100644 --- a/src/passes/ir/sailMir/mir.ml +++ b/src/passes/ir/sailMir/mir.ml @@ -1,4 +1,4 @@ -open AstMir +open MirAst open Common open TypesCommon open Monad @@ -16,45 +16,49 @@ module Pass = MakeFunctionPass(V)( struct let name = "MIR" - type m_in = Thir.statement + type m_in = ThirUtils.statement type m_out = mir_function type p_in = (HirUtils.statement,HirUtils.expression) AstParser.process_body type p_out = p_in - let rec lexpr (e : Thir.expression) : expression M.t = - let open AstHir in - let lt = e.info in - match e.exp with + let rec lexpr (exp : ThirUtils.expression) : expression M.t = + let open HirAst in + match exp.node with | Variable name -> let* id = find_scoped_var name in - let+ () = M.update_var (fst lt) id assign_var in buildExp lt (Variable id) + let+ () = M.update_var exp.tag.loc id assign_var in buildExp exp.tag (Variable id) | Deref e -> rexpr e - | ArrayRead (e1, e2) -> let+ e1' = lexpr e1 and* e2' = rexpr e2 in buildExp lt (ArrayRead(e1',e2')) - | StructRead (origin,e,field) -> let+ e = lexpr e in buildExp lt (StructRead (origin,e,field)) - | Ref _ -> M.error Logging.(make_msg (fst lt) "todo") - | _ -> M.error Logging.(make_msg (fst lt) @@ "thir didn't lower correctly this expression") - and rexpr (e : Thir.expression) : expression M.t = - let lt = e.info in - let open AstHir in - match e.exp with + | ArrayRead a -> let+ array = lexpr a.array and* idx = rexpr a.idx in + buildExp exp.tag (ArrayRead {array;idx}) + + | StructRead2 s -> let+ strct = lexpr s.value.strct in buildExp exp.tag (StructRead2 {s with value={s.value with strct}}) + + | Ref _ -> M.error Logging.(make_msg exp.tag.loc "todo") + + | Literal _ | UnOp _ | BinOp _ | ArrayStatic _ | StructAlloc2 _ | EnumAlloc _ -> + M.error Logging.(make_msg exp.tag.loc @@ "compiler error : not a lexpr") + + and rexpr (exp : ThirUtils.expression) : expression M.t = + let open HirAst in + match exp.node with | Variable name -> - let+ id = find_scoped_var name in buildExp lt (Variable id) - | Literal l -> buildExp lt (Literal l) |> M.pure + let+ id = find_scoped_var name in buildExp exp.tag (Variable id) + | Literal l -> buildExp exp.tag (Literal l) |> M.pure | Deref e -> lexpr e - | ArrayRead (array_exp,idx) -> let+ arr = rexpr array_exp and* idx' = rexpr idx in buildExp lt (ArrayRead(arr,idx')) - | UnOp (o, e) -> let+ e' = rexpr e in buildExp lt (UnOp (o, e')) - | BinOp (o ,e1, e2) -> let+ e1' = rexpr e1 and* e2' = rexpr e2 in buildExp lt (BinOp(o, e1', e2')) - | Ref (b, e) -> let+ e' = rexpr e in buildExp lt (Ref(b, e')) - | ArrayStatic el -> let+ el' = ListM.map rexpr el in buildExp lt (ArrayStatic el') - | StructRead (origin,struct_exp,field) -> - let+ exp = rexpr struct_exp in - buildExp lt (StructRead (origin,exp,field)) + | ArrayRead a -> let+ array = rexpr a.array and* idx = rexpr a.idx in buildExp exp.tag (ArrayRead {array;idx}) + | UnOp (o, e) -> let+ e' = rexpr e in buildExp exp.tag (UnOp (o, e')) + | BinOp bop -> let+ left = rexpr bop.left and* right = rexpr bop.right in buildExp exp.tag (BinOp {bop with left;right}) + | Ref (b, e) -> let+ e' = rexpr e in buildExp exp.tag (Ref(b, e')) + | ArrayStatic el -> let+ el' = ListM.map rexpr el in buildExp exp.tag (ArrayStatic el') + | StructRead2 s -> + let+ strct = rexpr s.value.strct in + buildExp exp.tag (StructRead2 {s with value={s.value with strct}}) - | StructAlloc (origin,id, fields) -> - let+ fields = ListM.map (rexpr |> pairMap2 |> pairMap2) fields in - buildExp lt (StructAlloc(origin,id,fields)) - | MethodCall _ - | _ -> M.error @@ Logging.(make_msg (fst lt) @@ "thir didn't lower correctly this expression") + | StructAlloc2 s -> + let+ fields = ListM.map (fun (f : string * (ThirUtils.exp_tag, [ `ResolvedImports | `TypedDeclVar ]) + generic_exp locatable) -> pairMap2 (fun f -> let+ value = rexpr f.value in {f with value}) f) s.value.fields in + buildExp exp.tag (StructAlloc2 {s with value={s.value with fields}}) + | EnumAlloc _ -> M.error @@ Logging.(make_msg exp.tag.loc @@ "compiler error : not a rexpr") @@ -62,32 +66,20 @@ struct let lower_method (body,_ : m_in * method_sig) (env,tenv) (_sm: (m_in,p_in) SailModule.methods_processes SailModule.t) : (m_out * SailModule.DeclEnv.t * _) M.E.t = - let rec aux (s : Thir.statement) : m_out M.t = + let rec aux (s : ThirUtils.statement) : m_out M.t = let open UseMonad(M) in - let loc = s.info in - match s.stmt with - | DeclVar(mut, id, Some ty, None) -> + let loc = s.tag in + match s.node with + | DeclVar2 d -> let* bb = emptyBasicBlock loc in - let* id = M.fresh_scoped_var >>| get_scoped_var id in - let+ () = M.declare_var loc id {ty;mut;id;loc} in - [{location=loc; mut; id; varType=ty}],bb - - | DeclVar(mut, id, Some ty, Some e) -> - let* id_ty = M.get_type_id ty in - let* expression = rexpr e in - let* id = M.fresh_scoped_var >>| get_scoped_var id in - let* () = M.declare_var loc id {ty;mut;id;loc} in - let target = AstHir.buildExp (loc,id_ty) (Variable id) in - let+ bn = assignBasicBlock loc {location=loc; target; expression } in - [{location=loc; mut; id=id; varType=ty}],bn - (* ++ other statements *) - - | DeclVar (_,name,None,_) -> failwith @@ "thir broken : variable declaration should have a type : " ^name + let* id = M.fresh_scoped_var >>| get_scoped_var d.id in + let+ () = M.declare_var loc id {ty=d.ty;mut=d.mut;id;loc} in + [{location=loc; mut=d.mut; id; varType=d.ty}],bb | Skip -> let+ bb = emptyBasicBlock loc in ([], bb) - | Assign (e1, e2) -> - let* expression = rexpr e2 and* target = lexpr e1 in + | Assign a -> + let* expression = rexpr a.value and* target = lexpr a.path in let+ bb = assignBasicBlock loc {location=loc; target; expression} in [],bb | Seq (s1, s2) -> @@ -97,16 +89,16 @@ struct (* let* () = M.set_env env in *) let+ bb = buildSeq cfg1 cfg2 in d1@d2,bb - | If (e, s, None) -> - let* e' = rexpr e in - let* d, cfg = aux s in - let+ ite = buildIfThen loc e' cfg in + | If ({else_=None;_} as if_) -> + let* cond = rexpr if_.cond in + let* d, cfg = aux if_.then_ in + let+ ite = buildIfThen loc cond cfg in (d,ite) - | If (e, s1, Some s2) -> - let* e' = rexpr e in - let* d1,cfg1 = aux s1 and* d2,cfg2 = aux s2 in - let+ ite = buildIfThenElse loc e' cfg1 cfg2 in + | If ({else_=Some else_;_} as if_) -> + let* cond = rexpr if_.cond in + let* d1,cfg1 = aux if_.then_ and* d2,cfg2 = aux else_ in + let+ ite = buildIfThenElse loc cond cfg1 cfg2 in (d1@d2, ite) | Loop s -> @@ -120,12 +112,12 @@ struct let+ cfg = singleBlock bb in ([],cfg) - | Invoke i -> - let* ((_,realname),_) = M.throw_if_none Logging.(make_msg loc @@ Fmt.str "Compiler Error : function '%s' must exist" (snd i.id)) - (SailModule.DeclEnv.find_decl (snd i.id) (Specific (snd i.import,Method)) (snd env)) + | Invoke2 i -> + let* (realname,_) = M.throw_if_none Logging.(make_msg loc @@ Fmt.str "Compiler Error : function '%s' must exist" i.value.id.value) + (SailModule.DeclEnv.find_decl i.value.id.value (Specific (i.import.value,Method)) (snd env)) in - let* args = ListM.map rexpr i.args in - let+ invoke = buildInvoke loc i.import (fst i.id,realname) i.ret_var args in + let* args = ListM.map rexpr i.value.args in + let+ invoke = buildInvoke loc i.import (mk_locatable i.value.id.loc realname.value) i.value.ret_var args in ([], invoke) | Return e -> diff --git a/src/passes/ir/sailMir/astMir.ml b/src/passes/ir/sailMir/mirAst.ml similarity index 96% rename from src/passes/ir/sailMir/astMir.ml rename to src/passes/ir/sailMir/mirAst.ml index 22f2b24..bc482d4 100644 --- a/src/passes/ir/sailMir/astMir.ml +++ b/src/passes/ir/sailMir/mirAst.ml @@ -31,8 +31,8 @@ type statement = Assign of lvalue * rvalue | Drop of drop_kind * lvalue *) -type expression = Thir.expression -type statement = Thir.statement +type expression = ThirUtils.expression +type statement = ThirUtils.statement type declaration = {location : loc; mut : bool; id : string; varType : sailtype} type assignment = {location : loc; target : expression; expression : expression} diff --git a/src/passes/ir/sailMir/mirMonad.ml b/src/passes/ir/sailMir/mirMonad.ml index b3af329..8c52509 100644 --- a/src/passes/ir/sailMir/mirMonad.ml +++ b/src/passes/ir/sailMir/mirMonad.ml @@ -1,4 +1,4 @@ -open AstMir +open MirAst open Common module M = struct diff --git a/src/passes/ir/sailMir/mirUtils.ml b/src/passes/ir/sailMir/mirUtils.ml index f6f755d..f209e81 100644 --- a/src/passes/ir/sailMir/mirUtils.ml +++ b/src/passes/ir/sailMir/mirUtils.ml @@ -1,9 +1,10 @@ -open AstMir +open MirAst open Common open TypesCommon open Monad open MirMonad open UseMonad(M) + let assign_var (var_l,v:VE.variable) = (var_l,v) |> M.E.pure @@ -179,7 +180,7 @@ let buildInvoke (l : loc) (origin:l_str) (id : l_str) (target : string option) ( forward_info = env; backward_info = (); location=l; - terminator = Some (Invoke {id = (snd id); origin; target; params = el; next = returnLbl}) + terminator = Some (Invoke {id=id.value; origin; target; params = el; next = returnLbl}) } in let returnBlock = {assignments = []; predecessors = LabelSet.singleton invokeLbl ; forward_info = env; backward_info = () ; location = dummy_pos; terminator = None} in { @@ -220,7 +221,7 @@ let find_scoped_var name : string M.t = let seqOfList (l : statement list) : statement = - List.fold_left (fun s l : statement -> {info=dummy_pos; stmt=Seq (s, l)}) {info=dummy_pos;stmt=Skip} l + List.fold_left (fun s l : statement -> {tag=dummy_pos; node=Seq (s, l)}) {tag=dummy_pos;node=Skip} l diff --git a/src/passes/ir/sailMir/pp_mir.ml b/src/passes/ir/sailMir/pp_mir.ml index e1f7509..b1732ae 100644 --- a/src/passes/ir/sailMir/pp_mir.ml +++ b/src/passes/ir/sailMir/pp_mir.ml @@ -1,30 +1,30 @@ open Common open PpCommon open Format -open AstMir +open MirAst +open TypesCommon -let rec ppPrintExpression (pf : Format.formatter) (e : AstMir.expression) : unit = - match e.exp with +let rec ppPrintExpression (pf : Format.formatter) (e : MirAst.expression) : unit = + match e.node with | Variable s -> fprintf pf "%s" s | Deref e -> fprintf pf "*%a" ppPrintExpression e - | StructRead (_,e, (_,s)) -> fprintf pf "%a.%s" ppPrintExpression e s - | ArrayRead (e1, e2) -> fprintf pf "%a[%a]" ppPrintExpression e1 ppPrintExpression e2 + | StructRead2 s -> fprintf pf "%a.%s" ppPrintExpression s.value.strct s.value.field.value + | ArrayRead a -> fprintf pf "%a[%a]" ppPrintExpression a.array ppPrintExpression a.idx | Literal (l) -> fprintf pf "%a" PpCommon.pp_literal l | UnOp (o, e) -> fprintf pf "%a %a" pp_unop o ppPrintExpression e - | BinOp ( o, e1, e2) -> fprintf pf "%a %a %a" ppPrintExpression e1 pp_binop o ppPrintExpression e2 + | BinOp bop -> fprintf pf "%a %a %a" ppPrintExpression bop.left pp_binop bop.op ppPrintExpression bop.right | Ref (true,e) -> fprintf pf "&mut %a" ppPrintExpression e | Ref (false,e) -> fprintf pf "&%a" ppPrintExpression e | ArrayStatic el -> Format.fprintf pf "[%a]" (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - |StructAlloc (_,id, m) -> - let pp_field pf (x, (_ , y)) = Format.fprintf pf "%s:%a" x ppPrintExpression y in - Format.fprintf pf "%s{%a}" (snd id) - (Format.pp_print_list ~pp_sep:pp_comma pp_field) m + |StructAlloc2 s -> + let pp_field pf (x, (y: 'a locatable)) = Format.fprintf pf "%s:%a" x ppPrintExpression y.value in + Format.fprintf pf "%s{%a}" s.value.name.value + (Format.pp_print_list ~pp_sep:pp_comma pp_field) s.value.fields | EnumAlloc (id,el) -> - Format.fprintf pf "[%s(%a)]" (snd id) + Format.fprintf pf "[%s(%a)]" id.value (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - | MethodCall _ -> () let ppPrintPredecessors (pf : Format.formatter) (preds : LabelSet.t ) : unit = if LabelSet.is_empty preds then fprintf pf "// no precedessors" @@ -40,8 +40,8 @@ let ppPrintAssignement (pf : Format.formatter) (a : assignment) : unit = let ppPrintTerminator (pf : Format.formatter) (t : terminator) : unit = match t with | Goto lbl -> fprintf pf "\t\tgoto %d;" lbl - | Invoke {id; params;next;origin=(_,mname);target} -> fprintf pf "\t\t%a%s(%a) -> [return: bb%d]" - (Format.pp_print_option (fun fmt id -> fprintf fmt "%s = %s::" id mname)) target + | Invoke {id; params;next;origin;target} -> fprintf pf "\t\t%a%s(%a) -> [return: bb%d]" + (Format.pp_print_option (fun fmt id -> fprintf fmt "%s = %s::" id origin.value)) target id (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) params next diff --git a/src/passes/ir/sailThir/thir.ml b/src/passes/ir/sailThir/thir.ml index fa93a4c..36a7a3a 100644 --- a/src/passes/ir/sailThir/thir.ml +++ b/src/passes/ir/sailThir/thir.ml @@ -3,191 +3,193 @@ open TypesCommon open Logging open Monad open IrHir -open AstHir open SailParser +open HirAst open ThirUtils open M open UseMonad(M) module SM = SailModule - -type expression = ThirUtils.expression -type statement = ThirUtils.statement - module Pass = Pass.MakeFunctionPass(V)( struct let name = "THIR" type m_in = HirUtils.statement - type m_out = statement + type m_out = ThirUtils.statement type p_in = (m_in,HirUtils.expression) AstParser.process_body type p_out = p_in - let rec lower_lexp (e : Hir.expression) : expression M.t = - let rec aux (e:Hir.expression) : expression M.t = - let loc = e.info in match e.exp with + + let rec lower_lexp (exp : HirUtils.expression) : expression M.t = + let rec aux (exp:HirUtils.expression) : expression M.t = + + match exp.node with | Variable id -> - let* _,t = M.get_var id >>= M.throw_if_none (make_msg loc @@ Printf.sprintf "unknown variable %s" id) in + let* _,t = M.get_var id >>= M.throw_if_none (make_msg exp.tag @@ Printf.sprintf "unknown variable %s" id) in let* venv,tenv = M.get_env in let t,tenv = t tenv in let+ () = M.set_env (venv,tenv) in - buildExp (loc,t) @@ Variable id + buildExp exp.tag t @@ Variable id | Deref e -> let* e = lower_rexp e in (* return @@ Deref((l,extract_exp_loc_ty e |> snd), e) *) begin - match e.exp with - | Ref (_,r) -> return @@ buildExp r.info @@ Deref e + match e.node with + | Ref (_,r) -> return @@ buildExp r.tag.loc r.tag.ty @@ Deref e | _ -> return e end - | ArrayRead (array_exp,idx) -> - let* array_exp = aux array_exp and* idx_exp = lower_rexp idx in - let* array_ty = M.get_type_from_id (array_exp.info) - and* idx_ty = M.get_type_from_id (idx_exp.info) in + + | ArrayRead ar -> + let* array = aux ar.array and* idx = lower_rexp ar.idx in + let* array_ty = M.get_type_from_id (mk_locatable array.tag.loc array.tag.ty) and* idx_ty = M.get_type_from_id (mk_locatable idx.tag.loc idx.tag.ty) in begin - match array_ty with - | l,ArrayType (t,sz) -> + match array_ty.value with + | ArrayType (t,sz) -> let* t = M.get_type_id t in - let* _ = matchArgParam l idx_ty (dummy_pos,Int 32) |> M.ESC.lift |> M.lift in + let* _ = matchArgParam array_ty.loc idx_ty (mk_locatable dummy_pos @@ Int 32) |> M.ESC.lift |> M.lift in begin (* can do a simple oob check if the type is an int literal *) - match idx.exp with + match idx.node with | Literal (LInt n) -> - M.throw_if (make_msg (fst idx_exp.info) @@ Printf.sprintf "index out of bounds : must be between 0 and %i (got %s)" + M.throw_if (make_msg idx.tag.loc @@ Printf.sprintf "index out of bounds : must be between 0 and %i (got %s)" (sz - 1) Z.(to_string n.l) ) Z.( n.l < ~$0 || n.l > ~$sz) | _ -> return () - end >>| fun () -> buildExp (loc,t) @@ ArrayRead (array_exp,idx_exp) - | _ -> M.throw (make_msg loc "not an array !") + end >>| fun () -> buildExp exp.tag t @@ ArrayRead {array;idx} + | _ -> M.throw (make_msg exp.tag "not an array !") end - | StructRead (origin,e,(fl,field)) -> - let* e = lower_lexp e in - let* ty_e = M.get_type_from_id e.info in - let+ origin,t = + | StructRead s -> + let* strct = lower_lexp s.value.strct in + let* ty = M.get_type_from_id (mk_locatable strct.tag.loc strct.tag.ty) in + let+ import,t = begin - match ty_e with - | _,CompoundType {name=l,name;decl_ty=Some S ();_} -> - let* origin,(_,strct) = find_struct_source (l,name) origin |> M.ESC.lift |> M.lift in - let* _,t,_ = List.assoc_opt field strct.fields - |> M.throw_if_none (make_msg fl @@ Fmt.str "field '%s' is not part of structure '%s'" field name) + match ty.value with + | CompoundType {name;decl_ty=Some S ();_} -> + let* origin,(_,strct) = find_struct_source name s.import |> M.ESC.lift |> M.lift in + let* f = List.assoc_opt s.value.field.value strct.fields + |> M.throw_if_none (make_msg s.value.field.loc @@ Fmt.str "field '%s' is not part of structure '%s'" s.value.field.value name.value) in - let+ t_id = M.get_type_id t in + let+ t_id = M.get_type_id (fst f.value) in origin,t_id - | l,t -> - let* str = string_of_sailtype_thir (Some (l,t)) |> M.ESC.lift |> M.lift in - M.throw (make_msg l @@ Fmt.str "expected a structure but got type '%s'" str) + | t -> + let* str = string_of_sailtype_thir (Some (mk_locatable ty.loc t)) |> M.ESC.lift |> M.lift in + M.throw (make_msg ty.loc @@ Fmt.str "expected a structure but got type '%s'" str) end in - let x : expression = buildExp (loc,t) (StructRead (origin,e,(fl,field))) in - x - - | _ -> M.throw (make_msg loc "not a lvalue !") - - in aux e - and lower_rexp (e : Hir.expression) : expression M.t = - let rec aux (e:Hir.expression) : expression M.t = - let loc = e.info in match e.exp with + buildExp ty.loc t @@ StructRead2 (mk_importable import {field=s.value.field;strct}) + + | BinOp _ | Literal _ | UnOp _ | Ref _ | ArrayStatic _ | StructAlloc _ | EnumAlloc _ | MethodCall _ -> + M.throw (make_msg exp.tag "not a lvalue !") + + in aux exp + and lower_rexp (exp : HirUtils.expression) : expression M.t = + let rec aux (exp: HirUtils.expression) : expression M.t = + match exp.node with | Variable id -> - let* _,t = M.get_var id >>= M.throw_if_none (make_msg loc @@ Printf.sprintf "unknown variable %s" id) in + let* _,t = M.get_var id >>= M.throw_if_none (make_msg exp.tag @@ Printf.sprintf "unknown variable %s" id) in let* venv,tenv = M.get_env in let t,tenv = t tenv in let+ () = M.set_env (venv,tenv) in - buildExp (loc,t) @@ Variable id + buildExp exp.tag t @@ Variable id | Literal li -> let* () = match li with | LInt t -> - let* () = M.throw_if Logging.(make_msg loc "signed integers use a minimum of 2 bits") (t.size < 2) in + let* () = M.throw_if Logging.(make_msg exp.tag "signed integers use a minimum of 2 bits") (t.size < 2) in let max_int = Z.( ~$2 ** t.size - ~$1) in let min_int = Z.( ~-max_int + ~$1) in M.throw_if ( - make_msg loc @@ Fmt.str "type suffix can't contain int literal : i%i is between %s and %s but literal is %s" + make_msg exp.tag @@ Fmt.str "type suffix can't contain int literal : i%i is between %s and %s but literal is %s" t.size (Z.to_string min_int) (Z.to_string max_int) (Z.to_string t.l) ) Z.(lt t.l min_int || gt t.l max_int) | _ -> return () in let+ t = M.get_type_id (sailtype_of_literal li) in - buildExp (loc,t) @@ Literal li + buildExp exp.tag t @@ Literal li - | UnOp (op,e) -> let+ e = aux e in buildExp e.info @@ UnOp (op,e) + | UnOp (op,e) -> let+ e = aux e in buildExp exp.tag e.tag.ty @@ UnOp (op,e) - | BinOp (op,le,re) -> - let* le = aux le in - let* re = aux re in - let+ t = check_binop op le.info re.info |> M.recover (snd le.info) in - buildExp (loc,t) @@ BinOp (op,le,re) + | BinOp bop -> + let* left = aux bop.left in + let* right = aux bop.right in + let+ t = check_binop bop.op (mk_locatable left.tag.loc left.tag.ty) (mk_locatable right.tag.loc right.tag.ty) |> M.recover left.tag.ty in + buildExp exp.tag t @@ BinOp {op=bop.op;left;right} | Ref (mut,e) -> let* e = lower_lexp e in - let* e_t = M.get_type_from_id e.info in - let+ t = M.get_type_id (dummy_pos,RefType (e_t,mut)) in - buildExp (loc,t) @@ Ref(mut, e) + let* e_t = M.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) in + let+ t = M.get_type_id (mk_locatable dummy_pos @@ RefType (e_t,mut)) in + buildExp exp.tag t @@ Ref(mut, e) | ArrayStatic el -> let* first_t = aux (List.hd el) in - let* first_t = M.get_type_from_id first_t.info in + let* first_t = M.get_type_from_id (mk_locatable first_t.tag.loc first_t.tag.ty) in let* el = ListM.map (fun e -> let* e = aux e in - let+ e_t = M.get_type_from_id e.info in - matchArgParam (fst e.info) e_t first_t |> M.ESC.lift |> M.lift >>| fun _ -> e + let+ e_t = M.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) in + matchArgParam e.tag.loc e_t first_t |> M.ESC.lift |> M.lift >>| fun _ -> e ) el in let* el = ListM.sequence el in - let t = dummy_pos,ArrayType (first_t, List.length el) in - let+ t_id = M.get_type_id t in - buildExp (loc,t_id) (ArrayStatic el) - - | MethodCall (lid,source,args) -> - let* (args: expression list) = ListM.map lower_rexp args in - let* mod_loc,(_realname,m) = find_function_source e.info None lid source args |> M.ESC.lift |> M.lift in - let* ret = M.throw_if_none (make_msg e.info "methods in expressions should return a value") m.ret in - let* ret_t = M.get_type_id ret in - let* x = M.fresh_fvar in - M.write {info=loc; stmt=DeclVar (false, x, Some ret, None)} >>= fun () -> - M.write {info=loc; stmt=Invoke {args;id=lid; ret_var = Some x;import=mod_loc}} >>| fun () -> - buildExp (loc,ret_t) (Variable x) - - | ArrayRead _ -> lower_lexp e (* todo : some checking *) - | Deref _ -> lower_lexp e (* todo : some checking *) - | StructRead _ -> lower_lexp e (* todo : some checking *) - | StructAlloc (origin,name,fields) -> - let* origin,(_l,strct) = find_struct_source name origin |> M.ESC.lift |> M.lift in + let t = exp.tag,ArrayType (first_t, List.length el) in + let+ t_id = M.get_type_id (mk_locatable (fst t) (snd t)) in + buildExp exp.tag t_id (ArrayStatic el) + + | MethodCall mc -> + let* args = ListM.map lower_rexp mc.value.args in + let* import,(_realname,m) = find_function_source exp.tag None mc.value.id mc.import args |> M.ESC.lift |> M.lift in + let* ty = M.throw_if_none (make_msg exp.tag "methods in expressions should return a value") m.ret in + let* ty_t = M.get_type_id ty in + let* ret_var = M.fresh_fvar in + M.write {tag=exp.tag; node=DeclVar2 {mut=false;id=ret_var;ty}} >>= fun () -> + let x = {args;id=mc.value.id; ret_var = Some ret_var} in + M.write {tag=exp.tag; node=Invoke2 (mk_importable import x)} >>| fun () -> + buildExp exp.tag ty_t (Variable ret_var) + + | ArrayRead _ + | Deref _ + | StructRead _ -> lower_lexp exp (* todo : some checking *) + + | StructAlloc s -> + let* import,(_l,strct) = find_struct_source s.value.name s.import |> M.ESC.lift |> M.lift in let struct_fields = List.to_seq strct.fields in let fields = FieldMap.( merge ( fun n f1 f2 -> match f1,f2 with - | Some _, Some (l,e) -> Some (let+ e = lower_rexp e in n,(l,e)) + | Some _, Some e -> Some (let+ e = lower_rexp e.value in n,e) | None,None -> None - | None, Some (l,_) -> Some (M.throw @@ make_msg l @@ Fmt.str "no field '%s' in struct '%s'" n (snd name)) - | Some _, None -> Some (M.throw @@ make_msg loc @@ Fmt.str "missing field '%s' from struct '%s'" n (snd name)) + | None, Some l -> Some (M.throw @@ make_msg l.loc @@ Fmt.str "no field '%s' in struct '%s'" n s.value.name.value) + | Some _, None -> Some (M.throw @@ make_msg s.value.name.loc @@ Fmt.str "missing field '%s' from struct '%s'" n s.value.name.value) ) (struct_fields |> of_seq) - (fields |> List.to_seq |> of_seq) + (s.value.fields |> List.to_seq |> of_seq) |> to_seq ) in - let* () = M.throw_if (make_msg (fst name) "missing fields ") Seq.(length fields < Seq.length struct_fields) in + let* () = M.throw_if (make_msg s.value.name.loc "missing fields ") Seq.(length fields < Seq.length struct_fields) in - let* fields = SeqM.sequence (Seq.map snd fields) in + let* fields: (string * expression) Seq.t = SeqM.sequence (Seq.map snd fields) in - let* () = SeqM.iter2 (fun (_name1,(l,(e:expression))) (_name2,(_,t,_)) -> - let* e_t = M.get_type_from_id e.info in - matchArgParam l e_t t |> M.ESC.lift |> M.lift >>| fun _ -> () - ) - fields - struct_fields + let* () = + SeqM.iter2 (fun (_name1,e) (_name2,t) -> + let* e_t = M.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) in + matchArgParam e.tag.loc e_t (fst t.value) |> M.ESC.lift |> M.lift >>| fun _ -> () + ) + fields + struct_fields in - let ty = dummy_pos,CompoundType {origin= Some origin;decl_ty=Some (S ()); name; generic_instances=[]} in - let+ ty = M.get_type_id ty in - (buildExp (loc,ty) (StructAlloc (origin,name, List.of_seq fields) )) + let _fields = List.of_seq fields in + let l,ty = dummy_pos,CompoundType {origin= Some import;decl_ty=Some (S ()); name=s.value.name; generic_instances=[]} in + let+ ty = M.get_type_id (mk_locatable l ty) in + buildExp exp.tag ty @@ StructAlloc2 (mk_importable import {name=s.value.name;fields=[] (* FIXMEEEEEEEEEEEEE *)} ) - | EnumAlloc _ -> M.throw (make_msg loc "todo enum alloc ") - in aux e + | EnumAlloc _ -> M.throw (make_msg exp.tag "todo enum alloc ") + in aux exp let lower_method (body,proto : _ * method_sig) (env,tenv:THIREnv.t * _) _ : (m_out * THIREnv.D.t * _) M.E.t = @@ -196,58 +198,60 @@ struct let log_and_skip e = M.ESC.log e >>| fun () -> buildStmt e.where Skip in - let rec aux s : m_out M.ESC.t = - let loc = s.info in + let rec aux (s: m_in) : m_out M.ESC.t = + let loc = s.tag in let buildStmt = buildStmt loc in - let buildSeq s1 s2 = {info=loc; stmt = Seq (s1, s2)} in + let buildSeq s1 s2 = {tag=loc; node = Seq (s1, s2)} in let buildSeqStmt s1 s2 = buildSeq s1 @@ buildStmt s2 in - match s.stmt with - | DeclVar (mut, id, opt_t, (opt_exp : Hir.expression option)) -> - let* ((ty,opt_e,s):sailtype * 'b * 'c) = + match s.node with + | DeclVar d -> + let* ty,opt_e,s = begin - match opt_t,opt_exp with + match d.ty,d.value with | Some t, Some e -> let* e,s = lower_rexp e in - let* e_t = M.ES.get_type_from_id e.info |> M.ESC.lift in - matchArgParam (fst e.info) e_t t |> M.ESC.lift >>| fun _ -> t,Some e,s + let* e_t = M.ES.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) |> M.ESC.lift in + matchArgParam e.tag.loc e_t t |> M.ESC.lift >>| fun _ -> t,Some e,s | None,Some e -> let* e,s = lower_rexp e in - let+ e_t = M.ES.get_type_from_id e.info |> M.ESC.lift in + let+ e_t = M.ES.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) |> M.ESC.lift in e_t,Some e,s | Some t,None -> return (t,None,buildStmt Skip) | None,None -> M.ESC.throw (make_msg loc "can't infere type with no expression") end in let* ty_id = M.ES.get_type_id ty |> M.ESC.lift in - let decl_var = THIREnv.declare_var id (loc,fun e -> ty_id,e) in - M.ESC.update_env (fun (st,t) -> M.E.(bind (decl_var st) (fun st -> pure (st,t)))) - >>| fun () -> (buildSeqStmt s @@ DeclVar (mut,id,Some ty,opt_e)) + let decl_var = THIREnv.declare_var d.id (loc,fun e -> ty_id,e) in + M.ESC.update_env (fun (st,t) -> M.E.(bind (decl_var st) (fun st -> pure (st,t)))) >>| fun () -> + let s1 = buildSeqStmt s @@ DeclVar2 {mut=d.mut;ty;id=d.id;} in + let s2 = match opt_e with None -> Skip | Some value -> Assign {path=buildExp loc ty_id (Variable d.id); value} in + buildSeqStmt s1 s2 - | Assign(e1, e2) -> - let* e1,s1 = lower_lexp e1 - and* e2,s2 = lower_rexp e2 in - let* e1_t = M.ES.get_type_from_id e1.info |> M.ESC.lift - and* e2_t = M.ES.get_type_from_id e2.info |> M.ESC.lift in - matchArgParam (fst e2.info) e2_t e1_t |> M.ESC.lift >>| - fun _ -> buildSeq s1 @@ buildSeqStmt s2 @@ Assign(e1, e2) - - | Seq(c1, c2) -> + | Assign a -> + let* value,s1 = lower_rexp a.value + and* path,s2 = lower_lexp a.path in + let* value_t = M.ES.get_type_from_id (mk_locatable value.tag.loc value.tag.ty) |> M.ESC.lift + and* path_t = M.ES.get_type_from_id (mk_locatable path.tag.loc path.tag.ty) |> M.ESC.lift in + matchArgParam path.tag.loc path_t value_t |> M.ESC.lift >>| + fun _ -> buildSeq s1 @@ buildSeqStmt s2 @@ Assign {path;value} + + | Seq (c1, c2) -> let* c1 = aux c1 in let+ c2 = aux c2 in - buildStmt (Seq(c1, c2)) + buildStmt (Seq (c1, c2)) - | If(cond_exp, then_s, else_s) -> - let* cond_exp,s = lower_rexp cond_exp in - let* cond_t = M.ES.get_type_from_id cond_exp.info |> M.ESC.lift in - let* _ = matchArgParam (fst cond_exp.info) cond_t (dummy_pos,Bool) |> M.ESC.lift in - let* res = aux then_s in + | If if_ -> + let* cond,s = lower_rexp if_.cond in + let* cond_t = M.ES.get_type_from_id (mk_locatable cond.tag.loc cond.tag.ty) |> M.ESC.lift in + let* _ = matchArgParam cond.tag.loc cond_t (mk_locatable dummy_pos Bool) |> M.ESC.lift in + let* then_ = aux if_.then_ in begin - match else_s with - | None -> return @@ buildSeqStmt s (If(cond_exp, res, None)) - | Some else_ -> let+ else_ = aux else_ in buildSeqStmt s (If(cond_exp, res, Some else_)) + match if_.else_ with + | None -> return @@ buildSeqStmt s (If {cond;then_;else_=None}) + | Some else_ -> let+ else_ = aux else_ in buildSeqStmt s (If {cond;then_;else_=Some else_}) end | Loop c -> @@ -256,29 +260,29 @@ struct | Break -> return (buildStmt Break) - | Case(e, _cases) -> - let+ e,s = lower_rexp e in - buildSeqStmt s (Case (e, [])) + | Case c -> + let+ switch,s = lower_rexp c.switch in + buildSeqStmt s (Case {switch;cases=[]}) | Invoke i -> (* todo: handle var *) - let* args,s = MF.ListM.map lower_rexp i.args in - let* import,_ = find_function_source s.info i.ret_var i.id i.import args |> M.ESC.lift in - buildSeqStmt s (Invoke { i with import ; args} ) |> return + let* args,s = MF.ListM.map lower_rexp i.value.args in + let* import,_ = find_function_source s.tag i.value.ret_var i.value.id i.import args |> M.ESC.lift in + buildSeqStmt s (Invoke2 (mk_importable import {args;ret_var=i.value.ret_var;id=i.value.id} )) |> return - | Return None as r -> - if proto.rtype = None then return (buildStmt r) else + | Return None -> + if proto.rtype = None then return (buildStmt (Return None)) else log_and_skip (make_msg loc @@ Printf.sprintf "void return but %s returns %s" proto.name (string_of_sailtype proto.rtype)) | Return (Some e) -> let* e,s = lower_rexp e in - let* t = M.ES.get_type_from_id e.info |> M.ESC.lift in + let* t = M.ES.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) |> M.ESC.lift in begin match proto.rtype with | None -> log_and_skip (make_msg loc @@ Printf.sprintf "returns %s but %s doesn't return anything" (string_of_sailtype (Some t)) proto.name) | Some r -> - matchArgParam (fst e.info) t r |> M.ESC.lift >>| fun _ -> + matchArgParam e.tag.loc t r |> M.ESC.lift >>| fun _ -> buildSeqStmt s (Return (Some e)) end diff --git a/src/passes/ir/sailThir/thirUtils.ml b/src/passes/ir/sailThir/thirUtils.ml index 09e4fb4..354b97d 100644 --- a/src/passes/ir/sailThir/thirUtils.ml +++ b/src/passes/ir/sailThir/thirUtils.ml @@ -4,13 +4,17 @@ open Monad open IrHir module D = SailModule.Declarations -type expression = (loc * string, l_str) AstHir.expression (* string is the key for the type map *) -type statement = (loc,l_str,expression) AstHir.statement +type exp_tag = {loc:loc; ty:string} +type expression = (exp_tag,[`ResolvedImports | `TypedDeclVar]) HirAst.generic_exp +type statement = (loc,exp_tag,[`ResolvedImports | `TypedDeclVar]) HirAst.generic_stmt + +let buildExp loc ty = HirAst.buildExp {loc;ty} + module M = ThirMonad.Make(struct type t = statement - let mempty : t = {info=dummy_pos; stmt=Skip} - let mconcat : t -> t -> t = fun x y -> {info=dummy_pos; stmt=Seq (x,y)} + let mempty : t = {tag=dummy_pos; node=Skip} + let mconcat : t -> t -> t = fun x y -> {tag=dummy_pos; node=Seq (x,y)} end ) @@ -19,27 +23,27 @@ open UseMonad(M.ES) -let rec resolve_alias (l,ty : sailtype) : (sailtype,string) Either.t M.ES.t = - match ty with - | CompoundType {origin;name=(_,name);decl_ty=Some (T ());_} -> - let* (_,mname) = M.ES.throw_if_none Logging.(make_msg l @@ "unknown type '" ^ name ^ "' , all types must have an origin (problem with HIR)") origin in - let* ty_t = M.ES.get_decl name (Specific (mname,Type)) - >>= M.ES.throw_if_none Logging.(make_msg l @@ Fmt.str "declaration '%s' requires importing module '%s'" name mname) in +let rec resolve_alias (ty : sailtype) : (sailtype,string) Either.t M.ES.t = + match ty.value with + | CompoundType {origin;name;decl_ty=Some (T ());_} -> + let* mname = M.ES.throw_if_none Logging.(make_msg name.loc @@ "unknown type '" ^ name.value ^ "' , all types must have an origin (problem with HIR)") origin in + let* ty_t = M.ES.get_decl name.value (Specific (mname.value,Type)) + >>= M.ES.throw_if_none Logging.(make_msg name.loc @@ Fmt.str "declaration '%s' requires importing module '%s'" name.value mname.value) in begin match ty_t.ty with - | Some (_,CompoundType _ as ct) -> resolve_alias ct + | Some ({value=CompoundType _;_} as ct) -> resolve_alias ct | Some t -> return (Either.left t) - | None -> return (Either.right name) (* abstract type, only look at name *) + | None -> return (Either.right name.value) (* abstract type, only look at name *) end - | _ -> return (Either.left (l,ty)) + | _ -> return (Either.left ty) let string_of_sailtype_thir (t : sailtype option) : string M.ES.t = let+ res = match t with - | Some (_,CompoundType {origin; name=(loc,x); _}) -> - let* (_,mname) = M.ES.throw_if_none Logging.(make_msg loc "no origin in THIR (problem with HIR)") origin in - let+ decl = M.ES.(get_decl x (Specific (mname,Filter [E (); S (); T()])) - >>= throw_if_none Logging.(make_msg loc "decl is null (problem with HIR)")) in + | Some {value=CompoundType {origin; name; _};_} -> + let* mname = M.ES.throw_if_none Logging.(make_msg name.loc "no origin in THIR (problem with HIR)") origin in + let+ decl = M.ES.(get_decl name.value (Specific (mname.value,Filter [E (); S (); T()])) + >>= throw_if_none Logging.(make_msg name.loc "decl is null (problem with HIR)")) in begin match decl with | T ty_def -> @@ -48,7 +52,7 @@ let string_of_sailtype_thir (t : sailtype option) : string M.ES.t = | Some t -> Fmt.str " (= %s)" @@ string_of_sailtype (Some t) | None -> "(abstract)" end - | S (_,s) -> Fmt.str " (= struct <%s>)" (List.map (fun (n,(_,t,_)) -> Fmt.str "%s:%s" n @@ string_of_sailtype (Some t) ) s.fields |> String.concat ", ") + | S (_,s) -> Fmt.str " (= struct <%s>)" (List.map (fun (n,f) -> Fmt.str "%s:%s" n @@ string_of_sailtype (Some (fst f.value)) ) s.fields |> String.concat ", ") | _ -> failwith "can't happen" end | _ -> return "" @@ -59,36 +63,36 @@ let matchArgParam (loc : loc) (arg: sailtype) (m_param : sailtype) : sailtype let rec aux (a:sailtype) (m:sailtype) : sailtype M.ES.t = let* lt = resolve_alias a in let* rt = resolve_alias m in - + let mk_locatable = fun x -> mk_locatable loc x |> return in match lt,rt with - | Left (loc_l,l), Left (_,r) -> + | Left l, Left r -> begin - match l,r with - | Bool, Bool -> return (loc_l,Bool) - | (Int i1), (Int i2) when i1 = i2 -> return (loc_l,Int i1) - | Float, Float -> return (loc_l,Float) - | Char, Char -> return (loc_l,Char) - | String, String -> return (loc_l,String) + match l.value,r.value with + | Bool, Bool -> mk_locatable Bool + | (Int i1), (Int i2) when i1 = i2 -> mk_locatable (Int i1) + | Float, Float -> mk_locatable Float + | Char, Char -> mk_locatable Char + | String, String -> mk_locatable String | ArrayType (at,s), ArrayType (mt,s') -> if s = s' then - let+ t = aux at mt in loc_l,ArrayType (t,s) + let* t = aux at mt in mk_locatable (ArrayType (t,s)) else - M.ES.throw Logging.(make_msg loc_l (Printf.sprintf "array length mismatch : wants %i but %i provided" s' s)) + M.ES.throw Logging.(make_msg l.loc (Printf.sprintf "array length mismatch : wants %i but %i provided" s' s)) - | Box _at, Box _mt -> M.ES.throw Logging.(make_msg loc_l "todo box") + | Box _at, Box _mt -> M.ES.throw Logging.(make_msg l.loc "todo box") | RefType (at,am), RefType (mt,mm) -> - if am <> mm then M.ES.throw Logging.(make_msg loc_l "different mutability") + if am <> mm then M.ES.throw Logging.(make_msg l.loc "different mutability") else aux at mt | at, GenericType _ - | GenericType _, at -> return (loc_l,at) + | GenericType _, at -> mk_locatable at - | CompoundType c1, CompoundType c2 when snd c1.name = snd c2.name -> + | CompoundType c1, CompoundType c2 when c1.name.value = c2.name.value -> return arg | _ -> let* param = string_of_sailtype_thir (Some m_param) and* arg = string_of_sailtype_thir (Some arg) in - M.ES.throw Logging.(make_msg loc_l @@ Printf.sprintf "wants %s but %s provided" param arg) + M.ES.throw Logging.(make_msg l.loc @@ Printf.sprintf "wants %s but %s provided" param arg) end | Right name, Right name' -> @@ -105,15 +109,18 @@ let matchArgParam (loc : loc) (arg: sailtype) (m_param : sailtype) : sailtype let check_binop op l r : string M.ES.t = let* l_t = M.ES.get_type_from_id l and* r_t = M.ES.get_type_from_id r in + + let mk_locatable = fun x -> mk_locatable l_t.loc x in + match op with | Lt | Le | Gt | Ge | Eq | NEq -> - let* _ = matchArgParam (fst l_t) r_t l_t in M.ES.get_type_id (fst l_t,Bool) + let* _ = matchArgParam l_t.loc r_t l_t in M.ES.get_type_id (mk_locatable Bool) | And | Or -> - let* _ = matchArgParam (fst l_t) l_t (fst l_t,Bool) - and* _ = matchArgParam (fst l_t) r_t (fst l_t,Bool) - in M.ES.get_type_id (fst l_t,Bool) + let* _ = matchArgParam l_t.loc l_t (mk_locatable Bool) + and* _ = matchArgParam l_t.loc r_t (mk_locatable Bool) + in M.ES.get_type_id (mk_locatable Bool) | Plus | Mul | Div | Minus | Rem -> - let+ _ = matchArgParam (fst l_t) r_t l_t in snd l + let+ _ = matchArgParam l_t.loc r_t l_t in l.value let check_call (name:string) (f : method_proto) (args: expression list) loc : unit M.ES.t = @@ -126,8 +133,8 @@ let check_call (name:string) (f : method_proto) (args: expression list) loc : un ListM.iter2 ( fun (ca:expression) ({ty=a;_}:param) -> - let* rty = M.ES.get_type_from_id ca.info in - let+ _ = matchArgParam (fst ca.info) rty a in () + let* rty = M.ES.get_type_from_id (mk_locatable ca.tag.loc ca.tag.ty) in + let+ _ = matchArgParam ca.tag.loc rty a in () ) args f.args @@ -139,7 +146,7 @@ let find_function_source (fun_loc:loc) (_var: string option) (name : l_str) (imp let* mname,def = HirUtils.find_symbol_source ~filt:[M ()] name import env |> M.ES.lift in match def with | M decl -> - let+ _ = check_call (snd name) (snd decl) el fun_loc in mname,decl + let+ _ = check_call name.value (snd decl) el fun_loc in mname,decl (* let _x = fun_loc and _y = el in return (mname,decl) *) | _ -> failwith "non method returned" (* cannot happen because we only requested methods *) @@ -206,11 +213,11 @@ let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct let* () = SEnv.iter ( fun (id,(l,{fields; generics})) -> let* fields = ListM.map ( - fun (name,(l,t,n)) -> + fun (name, ({value=t,n;_} as tn)) -> let* env = ES.get_env in let* t,decls = (follow_type t env.decls) |> ES.S.lift in let+ () = ES.set_env {env with decls} in - name,(l,t,n) + name,mk_locatable tn.loc (t,n) ) fields in let proto = l,{fields;generics} in @@ -237,7 +244,7 @@ let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in let+ () = ES.update_env (fun e -> - let decls = update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method) e.decls in + let decls = update_decl m_proto.name (mk_locatable m_proto.pos true_name, defn_to_proto (Method m)) (Self Method) e.decls in {e with decls} ) in m diff --git a/src/passes/misc/cfg_analysis.ml b/src/passes/misc/cfg_analysis.ml index 8e42a2f..4dc462c 100644 --- a/src/passes/misc/cfg_analysis.ml +++ b/src/passes/misc/cfg_analysis.ml @@ -4,7 +4,7 @@ open Pass open IrMir open SailParser open IrHir -open AstMir +open MirAst open Monad @@ -66,7 +66,7 @@ let check_returns (proto : method_sig) (decls,cfg : mir_function) : mir_functio module Pass = Make(struct let name = "analysis on mir" - type in_body = (AstMir.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes + type in_body = (MirAst.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes type out_body = in_body let transform (sm: in_body SailModule.t) : out_body SailModule.t E.t = diff --git a/src/passes/monomorphization/monomorphization.ml b/src/passes/monomorphization/monomorphization.ml index 2a63bf0..ac6701e 100644 --- a/src/passes/monomorphization/monomorphization.ml +++ b/src/passes/monomorphization/monomorphization.ml @@ -3,7 +3,7 @@ open Monad open TypesCommon module E = Common.Logging open Monad.MonadSyntax (E.Logger) -open IrMir.AstMir +open IrMir.MirAst open MonomorphizationMonad module M = MonoMonad open MonomorphizationUtils @@ -19,45 +19,45 @@ module Pass = Pass.Make (struct let mono_fun (f : sailor_function) (sm : in_body SailModule.t) : unit M.t = - let mono_exp (e : expression) (decls :declaration list) : sailtype M.t = - let rec aux (e : expression) : sailtype M.t = - match e.exp with + let mono_exp (exp : expression) (decls :declaration list) : sailtype M.t = + let rec aux (exp : expression) : sailtype M.t = + match exp.node with | Variable s -> M.get_var s <&> (function | Some v -> Some (snd v).ty (* var is a function param *) | None -> Option.bind (List.find_opt (fun v -> v.id = s) decls) (fun decl -> Some decl.varType) (* var is function declaration *) ) - >>= M.throw_if_none Logging.(make_msg (fst e.info) @@ Fmt.str "compiler error : var '%s' not found" s) + >>= M.throw_if_none Logging.(make_msg exp.tag.loc @@ Fmt.str "compiler error : var '%s' not found" s) | Literal l -> return (sailtype_of_literal l) - | ArrayRead (e, idx) -> + | ArrayRead a -> begin - let* l,t = aux e in - match t with + let* t = aux a.array in + match t.value with | ArrayType (t, _) -> - let+ idx_t = aux idx in - let _ = resolveType idx_t (l,Int 32) [] [] in + let+ idx_t = aux a.idx in + let _ = resolveType idx_t (mk_locatable t.loc @@ Int 32) [] [] in t | _ -> failwith "cannot happen" end | UnOp (_, e) -> aux e - | BinOp (_, e1, e2) -> - let* t1 = aux e1 in - let+ t2 = aux e2 in - let _ = resolveType t1 t2 [] [] in - t1 + | BinOp bop -> + let* left = aux bop.left in + let+ right = aux bop.right in + let _ = resolveType left right [] [] in + left | Ref (m, e) -> let+ t = aux e in - dummy_pos,RefType (t, m) + mk_locatable exp.tag.loc @@ RefType (t, m) | Deref e -> ( - let+ l,t = aux e in - match t with - | RefType _ -> l,t + let+ t = aux e in + match t.value with + | RefType _ -> t | _ -> failwith "cannot happen" ) @@ -70,15 +70,14 @@ module Pass = Pass.Make (struct next_t ) t h in - dummy_pos,ArrayType (t, List.length (e :: h)) + mk_locatable exp.tag.loc @@ ArrayType (t, List.length (e :: h)) | ArrayStatic [] -> failwith "error : empty array" - | StructAlloc (_, _, _) -> failwith "todo: struct alloc" + | StructAlloc2 _ -> failwith "todo: struct alloc" | EnumAlloc (_, _) -> failwith "todo: enum alloc" - | StructRead (_, _, _) -> failwith "todo: struct read" - | MethodCall _ -> failwith "no method call at this stage" + | StructRead2 _ -> failwith "todo: struct read" in - aux e + aux exp in let construct_call (calle : string) (el : expression list) decls : (string * sailtype option) M.t = @@ -111,47 +110,47 @@ module Pass = Pass.Make (struct begin let* f = find_callable calle sm |> M.lift in match f with - | None -> (*import *) return (mname,Some (dummy_pos,Int 32) (*fixme*)) + | None -> (*import *) return (mname,Some (mk_locatable dummy_pos @@ Int 32) (*fixme*)) | Some f -> - begin - Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); - match f.m_body with - | Right _ -> - (* process and method - - we make sure they correspond to what the callable wants - if the callable is generic we check all the generic types are present at least once - - we build a (string*sailtype) list of generic to type correspondance - if the generic is not found in the list, we add it with the corresponding type - if the generic already exists with the same type as the new one, we are good else we fail - *) - let* resolved_generics = check_args call_args f |> M.lift in - List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics; - - let* () = M.push_monos calle resolved_generics in - - let* rtype = - match f.m_proto.rtype with - | Some t -> - (* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *) - let+ t = (degenerifyType t resolved_generics|> M.lift) in - (* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *) - Some t - | None -> return None - in - - let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in - let name = mname in - let methd = { f with m_proto = { f.m_proto with rtype ; params } } in - let+ () = - let* f = M.get_decl name (Self Method) in - if Option.is_none f then - M.add_decl name ((dummy_pos,name),(defn_to_proto (Method methd))) Method - else return () - in - mname,rtype - | Left _ -> (* external method *) return (calle,f.m_proto.rtype) + begin + Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); + match f.m_body with + | Right _ -> + (* process and method + + we make sure they correspond to what the callable wants + if the callable is generic we check all the generic types are present at least once + + we build a (string*sailtype) list of generic to type correspondance + if the generic is not found in the list, we add it with the corresponding type + if the generic already exists with the same type as the new one, we are good else we fail + *) + let* resolved_generics = check_args call_args f |> M.lift in + List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics; + + let* () = M.push_monos calle resolved_generics in + + let* rtype = + match f.m_proto.rtype with + | Some t -> + (* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *) + let+ t = (degenerifyType t resolved_generics|> M.lift) in + (* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *) + Some t + | None -> return None + in + + let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in + let name = mname in + let methd = { f with m_proto = { f.m_proto with rtype ; params } } in + let+ () = + let* f = M.get_decl name (Self Method) in + if Option.is_none f then + M.add_decl name ((mk_locatable dummy_pos name),(defn_to_proto (Method methd))) Method + else return () + in + mname,rtype + | Left _ -> (* external method *) return (calle,f.m_proto.rtype) end end in diff --git a/src/passes/monomorphization/monomorphizationUtils.ml b/src/passes/monomorphization/monomorphizationUtils.ml index 3511a66..5f39551 100644 --- a/src/passes/monomorphization/monomorphizationUtils.ml +++ b/src/passes/monomorphization/monomorphizationUtils.ml @@ -3,10 +3,10 @@ open TypesCommon open Monad open IrHir module E = Logging.Logger -module Env = SailModule.SailEnv(IrMir.AstMir.V) +module Env = SailModule.SailEnv(IrMir.MirAst.V) open UseMonad(E) -type in_body = IrMir.AstMir.mir_function +type in_body = IrMir.MirAst.mir_function type out_body = { monomorphics : in_body method_defn list; polymorphics : in_body method_defn list; @@ -35,39 +35,46 @@ let print_method_proto (name : string) (methd : in_body sailor_method) = let resolveType (arg : sailtype) (m_param : sailtype) (generics : string list) (resolved_generics : sailor_args) : (sailtype * sailor_args) E.t = - let rec aux ((aloc, a) : sailtype) ((mloc, m) : sailtype) (g : sailor_args) : (sailtype * sailor_args) E.t = - match a,m with - | Bool, Bool -> return ((aloc,Bool), g) - | Int x, Int y when x = y -> return ((aloc,Int x), g) - | Float, Float -> return ((aloc,Float), g) - | Char, Char -> return ((aloc,Char), g) - | String, String -> return ((aloc,String), g) - | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in (aloc,ArrayType (t, s)), g - | GenericType _g1, GenericType _g2 -> return ((aloc,Int 32),g) + let rec aux (at : sailtype) (mt : sailtype) (g : sailor_args) : (sailtype * sailor_args) E.t = + match at.value,mt.value with + | Bool, Bool -> return ((mk_locatable at.loc Bool), g) + | Int x, Int y when x = y -> return ((mk_locatable at.loc @@ Int x), g) + | Float, Float -> return ((mk_locatable at.loc Float), g) + | Char, Char -> return ((mk_locatable at.loc Char), g) + | String, String -> return ((mk_locatable at.loc String), g) + | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in (mk_locatable at.loc @@ ArrayType (t, s)), g + | GenericType _g1, GenericType _g2 -> return ((mk_locatable at.loc @@ Int 32),g) (* E.throw Logging.(make_msg dummy_pos @@ Fmt.str "resolveType between generic %s and %s" g1 g2) *) | _, GenericType gt -> - let* () = E.throw_if Logging.(make_msg mloc @@ Fmt.str "generic type %s not declared" gt) (not @@ List.mem gt generics) in + let* () = E.throw_if + Logging.(make_msg mt.loc @@ Fmt.str "generic type %s not declared" gt) + (not @@ List.mem gt generics) in begin match List.assoc_opt gt g with - | None -> return ((aloc,a), (gt, (aloc,a)) :: g) - | Some (lt,t) -> + | None -> return (at, (gt, at) :: g) + | Some t -> E.throw_if - Logging.(make_msg lt @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt (string_of_sailtype (Some (lt,t))) (string_of_sailtype (Some (aloc,a)))) - (t <> a) - >>| fun () -> (aloc,a), g + Logging.(make_msg t.loc @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt + (string_of_sailtype (Some t)) + (string_of_sailtype (Some at))) + (t.value <> at.value) + >>| fun () -> at, g end + | RefType (at, _), RefType (mt, _) -> aux at mt g | CompoundType _, CompoundType _ -> failwith "todocompoundtype" | Box _at, Box _mt -> failwith "todobox" - | _ -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "cannot happen : %s vs %s" (string_of_sailtype (Some (aloc,a))) (string_of_sailtype (Some (mloc,m)))) + | _ -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "cannot happen : %s vs %s" + (string_of_sailtype (Some at)) + (string_of_sailtype (Some mt))) in aux arg m_param resolved_generics let degenerifyType (t : sailtype) (generics : sailor_args) : sailtype E.t = - let rec aux (l,t) = - let+ t = match t with + let rec aux t = + let+ t' = match t.value with | Bool -> return Bool | Int n -> return (Int n) | Float -> return Float @@ -77,10 +84,13 @@ let degenerifyType (t : sailtype) (generics : sailor_args) : sailtype E.t = | Box t -> let+ t = aux t in Box t | RefType (t, m) -> let+ t = aux t in RefType (t, m) | GenericType n -> - let+ t = E.throw_if_none Logging.(make_msg dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) (List.assoc_opt n generics) in - snd t + let+ t = E.throw_if_none + Logging.(make_msg dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) + (List.assoc_opt n generics) + in + t.value | CompoundType _ -> failwith "todo compoundtype" - in l,t + in mk_locatable t.loc t' in aux t diff --git a/src/passes/process/process.ml b/src/passes/process/process.ml index dbaa83d..0373abd 100644 --- a/src/passes/process/process.ml +++ b/src/passes/process/process.ml @@ -4,7 +4,7 @@ open IrHir open SailParser open ProcessUtils module H = HirUtils -module HirS = AstHir.Syntax +module HirS = HirAst.Syntax module E = Logging.Logger open ProcessMonad open Monad.UseMonad(M) @@ -19,7 +19,7 @@ module Pass = Pass.Make(struct let rec compute_tree closed (l,pi:loc * _ proc_init) : H.statement M.t = let closed = FieldSet.add pi.proc closed in (* no cycle *) - let* p = find_process_source (l,pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) in + let* p = find_process_source (mk_locatable l pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) in let* p = M.throw_if_none Logging.(make_msg l @@ Fmt.str "process '%s' is unknown" pi.proc) p in let* tag = M.fresh_prefix p.p_name in let prefix = (Fmt.str "%s_%s_" tag) in @@ -33,48 +33,51 @@ module Pass = Pass.Make(struct let* () = param_arg_mismatch "init" p.p_interface.p_params pi.params in - let rename_l = List.map2 (fun (_,subx) (_,(x,_)) -> (x,subx) ) (pi.read @ pi.write) (fst p.p_interface.p_shared_vars @ snd p.p_interface.p_shared_vars) in + let rename_l = List.map2 (fun subx x -> (fst x.value,subx.value) ) + (pi.read @ pi.write) + (fst p.p_interface.p_shared_vars @ snd p.p_interface.p_shared_vars) in let rename = fun id -> match List.assoc_opt id rename_l with Some v -> v | None -> id in (* add process local (but persistant) vars *) - ListM.iter (fun ((l,id),ty) -> + ListM.iter (fun (id,ty) -> let* ty,_ = HirUtils.follow_type ty sm.declEnv |> M.EC.lift |> M.ECW.lift |> M.lift in - M.(write_decls HirS.(var (l,prefix id,ty))) + M.(write_decls HirS.(var l (prefix id.value) ty None)) ) p.p_body.locals >>= fun () -> let* params = ListM.fold_right2 (fun (p:param) arg params -> let param = prefix p.id in (* add process parameters to the decls *) - M.(write_decls HirS.(var (p.loc,param,p.ty))) >>| fun () -> + M.(write_decls HirS.(var p.loc param p.ty None)) >>| fun () -> HirS.(params && !@param = arg) ) p.p_interface.p_params pi.params M.SeqMonoid.empty in (* add process init *) - let init = H.rename_var_stmt prefix p.p_body.init in + let init = HirUtils.rename_var_stmt prefix p.p_body.init in M.write_init HirS.(!! (params && init)) >>= fun () -> (* inline process calls *) - let rec aux ((_,s) : (H.statement, H.expression) AstParser.p_statement) (_ty:AstParser.pgroup_ty) : H.statement M.t = + let rec aux (stmt : (H.statement, H.expression) AstParser.p_statement) (_ty:AstParser.pgroup_ty) : H.statement M.t = let replace_or_prefix = fun id -> let new_id = rename id in if new_id <> id then new_id else prefix id in - let process_cond c s = match c with Some c -> HirS.(_if (H.rename_var_exp replace_or_prefix c) s skip) | None -> s in + let process_cond c s = match c with Some c -> HirS.(if_ (HirUtils.rename_var_exp replace_or_prefix c) s (skip ())) | None -> s in - match s with + match stmt.value with | Statement (s,cond) -> - let s = H.rename_var_stmt replace_or_prefix s in + let s = HirUtils.rename_var_stmt replace_or_prefix s in return (process_cond cond s) - | Run ((l,id),cond) -> - M.throw_if Logging.(make_msg l "not allowed to call Main process explicitely") (id = Constants.main_process) >>= fun () -> - M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem id closed) >>= fun () -> - let* l,pi = M.throw_if_none Logging.(make_msg l @@ Fmt.str "no proc init called '%s'" id) (List.find_opt (fun (_,p: loc * _ proc_init) -> p.id = id) p.p_body.proc_init) in - let read = List.map (fun (l,id) -> l,prefix id) pi.read in - let write = List.map (fun (l,id) -> l,prefix id) pi.write in - let params = List.map (H.rename_var_exp prefix) pi.params in - compute_tree closed (l,{pi with read ; write ; params}) >>| process_cond cond + | Run (id,cond) -> + M.throw_if Logging.(make_msg l "not allowed to call Main process explicitely") (id.value = Constants.main_process) >>= fun () -> + M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem id.value closed) >>= fun () -> + let* pi = M.throw_if_none Logging.(make_msg l @@ Fmt.str "no proc init called '%s'" id.value) + (List.find_opt (fun p -> p.value.id = id.value) p.p_body.proc_init) in + let read = List.map (fun (id:l_str) -> mk_locatable id.loc @@ prefix id.value) pi.value.read in + let write = List.map (fun (id:l_str) -> mk_locatable id.loc @@ prefix id.value) pi.value.write in + let params = List.map (HirUtils.rename_var_exp prefix) pi.value.params in + compute_tree closed (l,{pi.value with read ; write ; params}) >>| process_cond cond | PGroup g -> - ListM.fold_right (fun child s -> let+ res = aux child g.p_ty in HirS.(s && res)) g.children HirS.skip >>| process_cond g.cond + ListM.fold_right (fun child s -> let+ res = aux child g.p_ty in HirS.(s && res)) g.children (HirS.skip ()) >>| process_cond g.cond in aux p.p_body.loop Parallel diff --git a/src/passes/process/processMonad.ml b/src/passes/process/processMonad.ml index e4fcc1c..722349a 100644 --- a/src/passes/process/processMonad.ml +++ b/src/passes/process/processMonad.ml @@ -12,7 +12,7 @@ module V = ( module M = struct - open AstHir + open HirAst module E = Logging.Logger module Env = Env.VariableDeclEnv(SailModule.Declarations)(V) @@ -21,11 +21,11 @@ module M = struct type t = {decls : HirUtils.statement; init : HirUtils.statement ; loop : HirUtils.statement} - let empty = {info=dummy_pos ; stmt=Skip} - let concat s1 s2 = match s1.stmt,s2.stmt with + let empty = {tag=dummy_pos ; node=Skip} + let concat s1 s2 = match s1.node,s2.node with | Skip, Skip -> empty - | Skip,stmt | stmt,Skip -> {info=dummy_pos ; stmt} - | _ -> {info=dummy_pos ; stmt=Seq (s1,s2)} + | Skip,node | node,Skip -> {tag=dummy_pos ; node} + | _ -> {tag=dummy_pos ; node=Seq (s1,s2)} let mempty = {decls=empty; init=empty ; loop=empty} let mconcat s1 s2 = {decls=concat s1.decls s2.decls; init = concat s1.init s2.init ; loop = concat s1.loop s2.loop} diff --git a/src/passes/process/processUtils.ml b/src/passes/process/processUtils.ml index cff855a..e9c1f4b 100644 --- a/src/passes/process/processUtils.ml +++ b/src/passes/process/processUtils.ml @@ -6,7 +6,7 @@ open IrHir module E = Logging.Logger module D = SailModule.Declarations -type body = (Hir.statement,(Hir.statement,Hir.expression) SailParser.AstParser.process_body) SailModule.methods_processes +type body = (HirUtils.statement,(HirUtils.statement,HirUtils.expression) SailParser.AstParser.process_body) SailModule.methods_processes let method_of_main_process (p : 'a process_defn): 'a method_defn = let m_proto = {pos=p.p_pos; name="main"; generics = p.p_generics; params = p.p_interface.p_params; variadic=false; rtype=None; extern=false} @@ -15,7 +15,7 @@ let method_of_main_process (p : 'a process_defn): 'a method_defn = let finalize (proc_def,(new_body: M.ECW.elt)) = - let open AstHir in + let open HirAst in let (++) = M.SeqMonoid.concat in let main = method_of_main_process proc_def in @@ -38,13 +38,13 @@ let ppPrintModule (pf : Format.formatter) (m : body SailModule.t ) : unit = let find_process_source (name: l_str) (import : l_str option) procs : 'a process_defn option M.t = let* _,env = M.get in - let* (_,origin),_ = HirUtils.find_symbol_source ~filt:[P()] name import env |> M.from_error in + let* origin,_ = HirUtils.find_symbol_source ~filt:[P()] name import env |> M.from_error in let+ procs = - if origin = HirUtils.D.get_name env then return procs + if origin.value = HirUtils.D.get_name env then return procs else - let find_import = List.find_opt (fun i -> i.mname = origin) (HirUtils.D.get_imports env) in + let find_import = List.find_opt (fun i -> i.mname = origin.value) (HirUtils.D.get_imports env) in let+ i = M.throw_if_none Logging.(make_msg dummy_pos "can't happen") find_import in let sm = In_channel.with_open_bin (i.dir ^ i.mname ^ Constants.mir_file_ext) @@ fun c -> (Marshal.from_channel c : Mono.MonomorphizationUtils.out_body SailModule.t) in sm.body.processes in - List.find_opt (fun (p:_ process_defn) -> p.p_name = snd name) procs + List.find_opt (fun (p:_ process_defn) -> p.p_name = name.value) procs