Skip to content

Commit

Permalink
Constrain generic AST
Browse files Browse the repository at this point in the history
  • Loading branch information
terencode committed Sep 4, 2023
1 parent eacb0f1 commit 1b57858
Show file tree
Hide file tree
Showing 29 changed files with 923 additions and 804 deletions.
30 changes: 15 additions & 15 deletions src/codegen/codegenEnv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ open MakeOrderedFunctions(ImportCmp)
module Declarations = struct
include SailModule.Declarations
type process_decl = unit
type method_decl = {defn : AstMir.mir_function method_defn ; llval : llvalue ; extern : bool}
type method_decl = {defn : MirAst.mir_function method_defn ; llval : llvalue ; extern : bool}
type struct_decl = {defn : struct_proto; ty : lltype}
type enum_decl = unit
end
Expand All @@ -34,23 +34,23 @@ open Declarations
type in_body = Monomorphization.Pass.out_body


let getLLVMBasicType f t llc llm : lltype E.t =
let rec aux t =
match snd t with
let getLLVMBasicType f ty llc llm : lltype E.t =
let rec aux ty =
match ty.value with
| Bool -> i1_type llc |> return
| Int n -> integer_type llc n |> return
| Float -> double_type llc |> return
| Char -> i8_type llc |> return
| String -> i8_type llc |> pointer_type |> return
| ArrayType (t,s) -> let+ t = aux t in array_type t s
| Box t | RefType (t,_) -> aux t <&> pointer_type
| GenericType _ -> E.throw Logging.(make_msg (fst t) "no generic type in codegen")
| CompoundType {name=(_,name); _} when name = "_value" -> i64_type llc |> return (* for extern functions *)
| GenericType _ -> E.throw Logging.(make_msg ty.loc "no generic type in codegen")
| CompoundType {name; _} when name.value = "_value" -> i64_type llc |> return (* for extern functions *)
| CompoundType {origin=None;_}
| CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg (fst t) "compound type with no origin or decl_ty")
| CompoundType {origin=Some (_,mname); name=(_,name); decl_ty=Some d;_} ->
f (mname,name,d) llc llm aux
in aux t
| CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg ty.loc "compound type with no origin or decl_ty")
| CompoundType {origin=Some mname; name; decl_ty=Some d;_} ->
f (mname.value,name.value,d) llc llm aux
in aux ty


let handle_compound_type_codegen env (mname,name,d) llc _llm (aux : sailtype -> lltype E.t) : lltype E.t =
Expand Down Expand Up @@ -80,7 +80,7 @@ let getLLVMBasicType f t llc llm : lltype E.t =
| Some (E _enum) -> failwith "todo enum"
| Some (S (_,defn)) ->
let _,f_types = List.split defn.fields in
let* elts = ListM.map (fun (_,t,_) -> aux t) f_types <&> Array.of_list in
let* elts = ListM.map (fun ty -> aux (fst ty.value)) f_types <&> Array.of_list in
begin
match type_by_name llm ("struct." ^ name) with
| Some ty -> return ty
Expand Down Expand Up @@ -129,16 +129,16 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t =
);

let valueify_method_sig (m:method_sig) : method_sig =
let value = fun pos -> dummy_pos,CompoundType{origin=None;name=(pos,"_value");generic_instances=[];decl_ty=None} in
let value = fun pos -> mk_locatable dummy_pos @@ CompoundType{origin=None;name=(mk_locatable pos "_value");generic_instances=[];decl_ty=None} in
let rtype = m.rtype in (* keep the current type *)
let params = List.map (fun (p:param) -> {p with ty=(value p.loc)}) m.params in
{m with params; rtype}
in

(* because the imports are at the mir stage, we also have to do some codegen for them *)

let load_methods (methods: IrMir.AstMir.mir_function method_defn list) is_import env =
ListM.fold_left ( fun d (m:IrMir.AstMir.mir_function method_defn) ->
let load_methods (methods: IrMir.MirAst.mir_function method_defn list) is_import env =
ListM.fold_left ( fun d (m:IrMir.MirAst.mir_function method_defn) ->
let extern,proto =
if (Either.is_left m.m_body) then (* extern method, all parameters must be of type value *)
true,valueify_method_sig m.m_proto
Expand Down Expand Up @@ -167,7 +167,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t =
let load_structs structs write_env =
SEnv.fold (fun acc (name,(_,defn)) ->
let _,f_types = List.split defn.fields in
let* elts = ListM.map (fun (_,t,_) -> _getLLVMType sm.declEnv t llc llm) f_types <&> Array.of_list in
let* elts = ListM.map (fun ty-> _getLLVMType sm.declEnv (fst ty.value) llc llm) f_types <&> Array.of_list in
let ty = match type_by_name llm ("struct." ^ name) with
| Some ty -> ty
| None -> let ty = named_struct_type llc ("struct." ^ name) in
Expand Down
20 changes: 10 additions & 10 deletions src/codegen/codegenUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,24 @@ let getLLVMLiteral (l:literal) (llvm:llvm_args) : llvalue =
| LChar c -> const_int (i8_type llvm.c) (Char.code c)
| LString s -> build_global_stringptr s ".str" llvm.b

let ty_of_alias(t:sailtype) env : sailtype =
match snd t with
| CompoundType {origin=Some (_,mname); name=(_,name);decl_ty=Some T ();_} ->
let ty_of_alias(ty:sailtype) env : sailtype =
match ty.value with
| CompoundType {origin=Some mname; name;decl_ty=Some T ();_} ->
begin
match DeclEnv.find_decl name (Specific (mname,Type)) env with
match DeclEnv.find_decl name.value (Specific (mname.value,Type)) env with
| Some {ty=Some t';_} -> t'
| Some {ty=None;_} -> t
| None -> failwith @@ Fmt.str "ty_of_alias : '%s' not found in %s" (string_of_sailtype (Some t)) mname
| Some {ty=None;_} -> ty
| None -> failwith @@ Fmt.str "ty_of_alias : '%s' not found in %s" (string_of_sailtype (Some ty)) mname.value
end
| _ -> t
| _ -> ty

let unary (op:unOp) (t,v) : llbuilder -> llvalue =
let f =
match snd t,op with
| Float,Neg -> build_fneg
| Int _,Neg -> build_neg
| _,Not -> build_not
| _ -> Printf.sprintf "bad unary operand type : '%s'. Only double and int are supported" (string_of_sailtype (Some t)) |> failwith
| _ -> Printf.sprintf "bad unary operand type : '%s'. Only double and int are supported" (string_of_sailtype (Some (mk_locatable (fst t) (snd t)))) |> failwith
in f v ""


Expand Down Expand Up @@ -76,8 +76,8 @@ let binary (op:binOp) (t:sailtype) (l1:llvalue) (l2:llvalue) : llbuilder -> llva
| And -> "and" | Or -> "or" | Le -> "le" | Lt -> "lt" | Ge -> "ge" | Gt -> "gt" | Mul -> "mul"
| NEq -> "neq" | Div -> "div"
in
let t = if snd t = Bool then fst t,Int 1 else t in (* thir will have checked for correctness *)
let l = operators (snd t) in
let t = if t.value = Bool then mk_locatable t.loc @@ Int 1 else t in (* thir will have checked for correctness *)
let l = operators t.value in
let open Common.Monad.MonadOperator(Common.MonadOption.M) in
match l >>| List.assoc_opt op |> Option.join with
| Some oper -> oper l1 l2 ""
Expand Down
103 changes: 54 additions & 49 deletions src/codegen/codegen_.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ open IrMir
open Monad.UseMonad(E)
module L = Llvm
module E = Logging.Logger
let get_type (e:AstMir.expression) = snd e.info
let get_type (e:MirAst.expression) = e.tag.ty

let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x: AstMir.expression) : L.llvalue E.t =
match x.exp with
let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp: MirAst.expression) : L.llvalue E.t =
match exp.node with
| Variable x ->
let+ _,v = match (SailEnv.get_var x venv) with
| Some (_,n) -> return n
Expand All @@ -18,50 +18,57 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:

| Deref x -> eval_r env llvm x

| ArrayRead (array_exp, index_exp) ->
let* array_val = eval_l env llvm array_exp in
let+ index = eval_r env llvm index_exp in
| ArrayRead a ->
let* array_val = eval_l env llvm a.array in
let+ index = eval_r env llvm a.idx in
let llvm_array = L.build_in_bounds_gep array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in
llvm_array

| StructRead ((_,mname),struct_exp,(_,field)) ->
let* st = eval_l env llvm struct_exp in
let+ st_type_name = Env.TypeEnv.get_from_id struct_exp.info tenv >>= function _,CompoundType c -> return (snd c.name) | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") in
let fields = (SailEnv.get_decl st_type_name (Specific (mname,Struct)) venv |> Option.get).defn.fields in
let _,_,idx = List.assoc field fields in
| StructRead2 s ->
let* st = eval_l env llvm s.value.strct in
let* st_type_name = Env.TypeEnv.get_from_id (mk_locatable s.value.strct.tag.loc s.value.strct.tag.ty) tenv >>= function
| {value=CompoundType c;_} -> return c.name.value
| _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type")
in
let+ decl = (SailEnv.get_decl st_type_name (Specific (s.import.value,Struct)) venv
|> E.throw_if_none Logging.(make_msg exp.tag.loc @@ Fmt.str "compiler error : no decl '%s' found" st_type_name)) in

let fields = decl.defn.fields in
let {value=_,idx;_} = List.assoc s.value.field.value fields in
L.build_struct_gep st idx "" llvm.b

| StructAlloc (_,(_,name),fields) ->
let _,fieldlist = fields |> List.split in
let* strct_ty = match L.type_by_name llvm.m ("struct." ^ name) with
| StructAlloc2 s ->
let _,fieldlist = s.value.fields |> List.split in
let* strct_ty = match L.type_by_name llvm.m ("struct." ^ s.value.name.value) with
| Some s -> return s
| None ->
E.throw Logging.(make_msg (fst x.info) @@ "unknown structure : " ^ ("struct." ^ name))
E.throw Logging.(make_msg exp.tag.loc @@ "unknown structure : " ^ ("struct." ^ s.value.name.value))
in
let struct_v = L.build_alloca strct_ty "" llvm.b in
let+ () = ListM.iteri ( fun i (_,f) ->
let+ v = eval_r env llvm f in
let+ () = ListM.iteri ( fun i f ->
let+ v = eval_r env llvm f.value in
let v_f = L.build_struct_gep struct_v i "" llvm.b in
L.build_store v v_f llvm.b |> ignore
) fieldlist in
struct_v

| _ -> E.throw Logging.(make_msg (fst x.info) "unexpected rvalue for codegen")
| Literal _ | UnOp _ | BinOp _ | Ref _ | ArrayStatic _ | EnumAlloc _ ->
E.throw Logging.(make_msg exp.tag.loc "unexpected lvalue for codegen")


and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMir.expression) : L.llvalue E.t =
let* ty = Env.TypeEnv.get_from_id x.info tenv in
match x.exp with
| Variable _ | StructRead _ | ArrayRead _ | StructAlloc _ -> let+ v = eval_l env llvm x in L.build_load v "" llvm.b
and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:MirAst.expression) : L.llvalue E.t =
let* ty = Env.TypeEnv.get_from_id (mk_locatable exp.tag.loc exp.tag.ty) tenv in
match exp.node with
| Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load v "" llvm.b

| Literal l -> return @@ getLLVMLiteral l llvm

| UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (ty_of_alias ty (snd venv),l) llvm.b
| UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (let ty = ty_of_alias ty (snd venv) in (ty.loc,ty.value),l) llvm.b

| BinOp (op,e1, e2) ->
let+ l1 = eval_r env llvm e1
and* l2 = eval_r env llvm e2
in binary op (ty_of_alias ty (snd venv)) l1 l2 llvm.b
| BinOp bop ->
let+ l1 = eval_r env llvm bop.left
and* l2 = eval_r env llvm bop.right
in binary bop.op (ty_of_alias ty (snd venv)) l1 l2 llvm.b
| Ref (_,e) -> eval_l env llvm e

| Deref e -> let+ v = eval_l env llvm e in L.build_load v "" llvm.b
Expand All @@ -80,32 +87,31 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMi
L.build_load array "" llvm.b
end

| EnumAlloc _ -> E.throw Logging.(make_msg (fst x.info) "enum allocation unimplemented")
| EnumAlloc _ -> E.throw Logging.(make_msg exp.tag.loc "enum allocation unimplemented")

| _ -> E.throw Logging.(make_msg (fst x.info) "problem with thir")

and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t =
let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.info,r) args >>| List.split
and construct_call (name:string) (mname:l_str) (args:MirAst.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t =
let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.tag,r) args >>| List.split
in
(* let mname = mangle_method_name name origin.mname args_type in *)
let mangled_name = "_" ^ mname ^ "_" ^ name in
let mangled_name = "_" ^ mname.value ^ "_" ^ name in
Logs.debug (fun m -> m "constructing call to %s" name);
let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname,Method)) venv with
let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname.value,Method)) venv with
| None ->
begin
match SailEnv.get_decl name (Specific (mname,Method)) venv with
match SailEnv.get_decl name (Specific (mname.value,Method)) venv with
| Some {llval;extern;_} -> return (llval,extern)
| None -> E.throw Logging.(make_msg loc @@ Printf.sprintf "implementation of %s not found" mangled_name )
| None -> E.throw Logging.(make_msg mname.loc @@ Printf.sprintf "implementation of %s not found" mangled_name )
end
| Some {llval;extern;_} -> return (llval,extern)
in

let+ args =
if ext then
ListM.map2 (fun t v ->
let+ t = Env.TypeEnv.get_from_id t tenv in
ListM.map2 (fun (t:IrThir.ThirUtils.exp_tag) v ->
let+ t = Env.TypeEnv.get_from_id (mk_locatable t.loc t.ty) tenv in
let builder =
match snd (ty_of_alias t (snd venv)) with
match (ty_of_alias t (snd venv)).value with
| Bool | Int _ | Char -> L.build_zext
| Float -> L.build_bitcast
| CompoundType _ -> fun v _ _ _ -> v
Expand All @@ -118,16 +124,16 @@ and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression lis
in
L.build_call llval (Array.of_list args) "" llvm.b

open AstMir
open MirAst

let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,tenv : SailEnv.t*Env.TypeEnv.t) : unit E.t =
let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:AstMir.expression option) (venv : SailEnv.t) : SailEnv.t E.t =
let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:MirAst.expression option) (venv : SailEnv.t) : SailEnv.t E.t =
let _ = mut in (* todo manage mutable types *)
let entry_b = L.(entry_block proto |> instr_begin |> builder_at llvm.c) in
let* v =
match exp with
| Some e ->
let* t = Env.TypeEnv.get_from_id e.info tenv
let* t = Env.TypeEnv.get_from_id (mk_locatable e.tag.loc e.tag.ty) tenv
and* v = eval_r (venv,tenv) llvm e in
let+ ty = getLLVMType (snd venv) t llvm.c llvm.m in
let x = L.build_alloca ty name entry_b in
Expand All @@ -154,20 +160,21 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t
let llvm_bbs = BlockMap.add lbl llvm_bb llvm_bbs in
L.position_at_end llvm_bb llvm.b;
let* () = ListM.iter (fun x -> assign_var x.target x.expression (venv,tenv)) bb.assignments in
match bb.terminator with
| Some (Return e) ->
let* terminator = E.throw_if_none Logging.(make_msg bb.location "no terminator : mir is broken") bb.terminator in
match terminator with
| Return e ->
let+ ret = match e with
| Some r -> let+ v = eval_r (venv,tenv) llvm r in L.build_ret v
| None -> return L.build_ret_void
in ret llvm.b |> ignore; llvm_bbs

| Some (Goto lbl) ->
| Goto lbl ->
let+ llvm_bbs = aux lbl llvm_bbs venv in
L.position_at_end llvm_bb llvm.b;
let _ = L.build_br (BlockMap.find lbl llvm_bbs) llvm.b in
llvm_bbs

| Some (Invoke f) ->
| Invoke f ->
let* c = construct_call f.id f.origin f.params (venv,tenv) llvm in
begin
match f.target with
Expand All @@ -179,7 +186,7 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t
L.build_br (BlockMap.find f.next llvm_bbs) llvm.b |> ignore;
llvm_bbs

| Some (SwitchInt si) ->
| SwitchInt si ->
let* sw_val = eval_r (venv,tenv) llvm si.choice in
let sw_val = L.build_intcast sw_val (L.i32_type llvm.c) "" llvm.b in (* for condition, expression val will be bool *)
let* llvm_bbs = aux si.default llvm_bbs venv in
Expand All @@ -192,9 +199,7 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t
in L.add_case sw n (BlockMap.find lbl bm);
bm
) llvm_bbs si.paths

| None -> E.throw Logging.(make_msg bb.location "no terminator : mir is broken")
| Some Break -> E.throw Logging.(make_msg bb.location "no break should be there")
| Break -> E.throw Logging.(make_msg bb.location "no break should be there")
end
in
(
Expand Down
2 changes: 1 addition & 1 deletion src/common/builtins.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ let register_builtin name generics p rtype variadic l: method_sig list =

let get_builtins () : method_sig list =
[]
|> register_builtin "box" ["T"] [dummy_pos,GenericType "T"] (Some (dummy_pos,Box (dummy_pos,GenericType "T"))) false
|> register_builtin "box" ["T"] [mk_locatable dummy_pos (GenericType "T")] (Some (mk_locatable dummy_pos (Box (mk_locatable dummy_pos (GenericType "T"))))) false
6 changes: 3 additions & 3 deletions src/common/env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -388,20 +388,20 @@ module TypeEnv = struct
let empty = FieldMap.empty
let get_id ty (te :t) : string * t =
let add_if_no_exists s = FieldMap.update s (Option.fold ~none:(Some ty) ~some:Option.some) in
let s = match snd ty with
let s = match ty.value with
| Bool -> "bool"
| Float -> "float"
| Char -> "char"
| String -> "string"
| Int n -> "int" ^ string_of_int n
| ArrayType _ -> "array"
| GenericType t -> t
| CompoundType t -> snd t.name
| CompoundType t -> t.name.value
| Box _ -> "box"
| RefType _ -> "ref"
in s,add_if_no_exists s te

let get_from_id (lid,id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg lid @@ Fmt.str "id '%s' not found" id) (FieldMap.find_opt id te)
let get_from_id (id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg id.loc @@ Fmt.str "id '%s' not found" id.value) (FieldMap.find_opt id.value te)

end

Expand Down
8 changes: 4 additions & 4 deletions src/common/ppCommon.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ let pp_binop pf b =



let rec pp_type (pf : formatter) (_,t : sailtype) : unit =
match t with
let rec pp_type (pf : formatter) (t : sailtype) : unit =
match t.value with
Bool -> pp_print_string pf "bool"
| Int n -> Format.fprintf pf "i%i" n
| Float -> pp_print_string pf "float"
| Char -> pp_print_string pf "char"
| String -> pp_print_string pf "string"
| ArrayType (t,s) -> Format.fprintf pf "array<%a;%d>" pp_type t s
| CompoundType {name=(_,x); generic_instances;_} ->
Format.fprintf pf "%s<%a>" x (pp_print_list ~pp_sep:pp_comma pp_type) generic_instances
| CompoundType {name; generic_instances;_} ->
Format.fprintf pf "%s<%a>" name.value (pp_print_list ~pp_sep:pp_comma pp_type) generic_instances
| Box(t) -> Format.fprintf pf "ref<%a>" pp_type t
| RefType (t,b) ->
if b then Format.fprintf pf "&mut %a" pp_type t
Expand Down
2 changes: 1 addition & 1 deletion src/common/sailModule.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ let method_decl_of_defn (d : 'a method_defn) : Declarations.method_decl =
and args = d.m_proto.params
and generics = d.m_proto.generics
and variadic = d.m_proto.variadic in
((pos,name),{ret;args;generics;variadic})
(mk_locatable pos name,{ret;args;generics;variadic})
Loading

0 comments on commit 1b57858

Please sign in to comment.