Skip to content

Commit

Permalink
Merge pull request #62 from janestreet/function-arg-destruct
Browse files Browse the repository at this point in the history
Backport fix for destructing function args
  • Loading branch information
goldfirere authored May 14, 2024
2 parents 7e41681 + a870c36 commit 292713b
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 192 deletions.
367 changes: 188 additions & 179 deletions src/analysis/destruct.ml
Original file line number Diff line number Diff line change
Expand Up @@ -255,57 +255,79 @@ let rec get_match = function
let s = Mbrowse.print_node () parent in
raise (Not_allowed s)

let rec get_every_pattern = function

let collect_every_pattern_for_expression parent =
let patterns =
Mbrowse.fold_node (fun env node acc ->
match node with
| Pattern _ -> (* Not expected here *) raise Nothing_to_do
| Case _ ->
Mbrowse.fold_node (fun _env node acc ->
match node with
| Pattern p ->
let ill_typed_pred = Typedtree.{ f = fun p ->
List.memq Msupport.incorrect_attribute ~set:p.pat_attributes }
in
if Typedtree.exists_general_pattern ill_typed_pred p
then raise Ill_typed
else begin
match Typedtree.classify_pattern p with
| Value -> (p : Typedtree.pattern) :: acc
| Computation ->
begin
match Typedtree.split_pattern p with
| Some p, _ -> (p : Typedtree.pattern) :: acc
| None, _ -> acc
end
end
| _ -> acc
) env node acc
| _ -> acc
) Env.empty parent []
in
let loc = Mbrowse.fold_node (fun _ node acc ->
let open Location in
let loc = Mbrowse.node_loc node in
if Lexing.compare_pos loc.loc_end acc.loc_end > 0 then loc else acc
) Env.empty parent Location.none
in loc, patterns

let collect_function_pattern loc param_pattern =
match param_pattern.Typedtree.fp_kind with
| Typedtree.Tparam_pat pattern ->
loc, [pattern]
| Typedtree.Tparam_optional_default _ ->
raise (Not_allowed "value_binding")

let rec get_every_pattern loc = function
| [] -> assert false
| parent :: parents ->
match parent with
| Case _
| Pattern _ ->
(* We are still in the same branch, going up. *)
get_every_pattern parents
get_every_pattern loc parents
| Expression { exp_desc = Typedtree.Texp_ident (Path.Pident id, _, _, _, _) ; _}
when Ident.name id = "*type-error*" ->
raise (Ill_typed)
| Expression { exp_desc = Typedtree.Texp_function {params; _}; _ } ->
(* So we need to deal with the case where we're either in the body of a
function, or in a function parameter. *)
begin
match
List.find_some ~f:(fun param ->
Location_aux.included ~into:param.Typedtree.fp_loc loc
) params with
| Some pattern ->
(* In parameter case *)
collect_function_pattern loc pattern
| None ->
(* In function body *)
collect_every_pattern_for_expression parent
end
| Expression _ ->
(* We are on the right node *)
let patterns : Typedtree.pattern list =
Mbrowse.fold_node (fun env node acc ->
match node with
| Pattern _ ->
raise (Not_allowed ("pattern in function argument"))
| Case _ ->
Mbrowse.fold_node (fun _env node acc ->
match node with
| Pattern p ->
let ill_typed_pred : Typedtree.pattern_predicate =
{ f = fun p ->
List.memq Msupport.incorrect_attribute
~set:p.pat_attributes }
in
if Typedtree.exists_general_pattern ill_typed_pred p then
raise Ill_typed;
begin match Typedtree.classify_pattern p with
| Value -> let p : Typedtree.pattern = p in p :: acc
| Computation -> let val_p, _ = Typedtree.split_pattern p in
(* We ignore computation patterns *)
begin match val_p with
| Some val_p -> val_p :: acc
| None -> acc
end
end
| _ -> acc
) env node acc
| _ -> acc
) Env.empty parent []
in
let loc =
Mbrowse.fold_node (fun _ node acc ->
let open Location in
let loc = Mbrowse.node_loc node in
if Lexing.compare_pos loc.loc_end acc.loc_end > 0 then loc else acc
) Env.empty parent Location.none
in
loc, patterns
collect_every_pattern_for_expression parent
| _ ->
(* We were not in a match *)
let s = Mbrowse.print_node () parent in
Expand Down Expand Up @@ -577,149 +599,136 @@ module Conv = struct
(ps, constrs, labels)
end

let destruct_expression loc config source parents expr =
let ty = expr.Typedtree.exp_type in
let pexp = filter_expr_attr (Untypeast.untype_expression expr) in
let () =
log ~title:"node_expression" "%a"
Logger.fmt (fun fmt -> Printast.expression 0 fmt pexp)
in
let needs_parentheses, result =
if is_package (Types.Transient_expr.repr ty) then
let mode = Ast_helper.Mod.unpack pexp in
false, Ast_helper.Exp.letmodule_no_opt "M" mode placeholder
else
let ps = gen_patterns expr.Typedtree.exp_env ty in
let cases = List.map ps ~f:(fun patt ->
let pc_lhs = filter_pat_attr (Untypeast.untype_pattern patt) in
{ Parsetree. pc_lhs ; pc_guard = None ; pc_rhs = placeholder }
) in
needs_parentheses parents, Ast_helper.Exp.match_ pexp cases
in
let str = Mreader.print_pretty config source (Pretty_expression result) in
let str = if needs_parentheses then "(" ^ str ^ ")" else str in
loc, str

let refine_partial_match last_case_loc config source patterns =
let cases = List.map patterns ~f:(fun pat ->
let _pat, constrs, labels = Conv.conv pat in
let unmangling_tables = constrs, labels in
(* Unmangling and prefixing *)
let pat = qualify_constructors ~unmangling_tables Printtyp.shorten_type_path pat in
(* Untyping and casing *)
let ppat = filter_pat_attr (Untypeast.untype_pattern pat) in
Ast_helper.Exp.case ppat placeholder
) in
let loc = Location.{ last_case_loc with loc_start = last_case_loc.loc_end } in
let str = Mreader.print_pretty config source (Pretty_case_list cases) in
loc, str

let filter_new_branches new_branches patterns =
let unused = Parmatch.return_unused patterns in
List.fold_left unused ~init:new_branches ~f:(fun branches u ->
match u with
| `Unused p -> List.remove ~phys:true p branches
| `Unused_subs (p, lst) ->
List.map branches ~f:(fun branch ->
if branch != p then branch else
List.fold_left lst ~init:branch ~f:rm_sub))

let refine_current_pattern patt config source generated_pattern =
let ppat = filter_pat_attr (Untypeast.untype_pattern generated_pattern) in
let str = Mreader.print_pretty config source (Pretty_pattern ppat) in
patt.Typedtree.pat_loc, str

let refine_and_generate_branches patt config source patterns sub_patterns =
let rev_before, after, top_patt = find_branch patterns patt in
let new_branches =
List.map sub_patterns ~f:(fun by -> subst_patt patt ~by top_patt)
in
let patterns = after @ rev_before @ new_branches in
match filter_new_branches new_branches patterns with
| [] -> raise Useless_refine
| p :: ps ->
let p = List.fold_left ps ~init:p ~f:(fun acc p ->
Tast_helper.Pat.pat_or
top_patt.Typedtree.pat_env
top_patt.Typedtree.pat_type acc p)
in
(* Format.eprintf "por %a \n%!" (Printtyped.pattern 0) p; *)
let ppat = filter_pat_attr (Untypeast.untype_pattern p) in
(* Format.eprintf "ppor %a \n%!" (Pprintast.pattern) ppat; *)
let str = Mreader.print_pretty config source (Pretty_pattern ppat) in
(* Format.eprintf "STR: %s \n %!" str; *)
top_patt.Typedtree.pat_loc, str

let refine_complete_match
(type a) (patt: a Typedtree.general_pattern)
config source patterns =
match Typedtree.classify_pattern patt with
| Computation -> raise (Not_allowed ("computation pattern"))
| Value ->
let _: Typedtree.value Typedtree.general_pattern = patt in
if not (destructible patt) then raise Nothing_to_do
else
let ty = patt.Typedtree.pat_type in
begin match gen_patterns patt.Typedtree.pat_env ty with
| [] -> assert false
| [more_precise_pattern] ->
(* If only one pattern is generated, then we're only refining the
current pattern, not generating new branches. *)
refine_current_pattern patt config source more_precise_pattern
| sub_patterns ->
(* If more than one pattern is generated, then we're generating new
branches. *)
refine_and_generate_branches patt config source patterns sub_patterns
end

let destruct_pattern
(type a) (patt: a Typedtree.general_pattern)
config source loc parents =
let last_case_loc, patterns = get_every_pattern loc parents in
(* Printf.eprintf "tot %d o%!"(List.length patterns); *)
let () = List.iter patterns ~f:(fun p ->
let p = filter_pat_attr (Untypeast.untype_pattern p) in
log ~title:"EXISTING" "%t"
(fun () -> Mreader.print_pretty config source (Pretty_pattern p)))
in
let pss = List.map patterns ~f:(fun x -> [ x ]) in
let m, e_typ = get_match parents in
let pred = Typecore.partial_pred ~lev:Btype.generic_level m.Typedtree.exp_env e_typ in
match Parmatch.complete_partial ~pred pss with
| [] ->
(* The match is already complete, we try to refine it *)
refine_complete_match patt config source patterns
| patterns ->
refine_partial_match last_case_loc config source patterns

let rec destruct_record config source selected_node = function
| Expression { exp_desc = Texp_field _; _ } as parent :: rest ->
node config source parent rest
| Expression e :: rest ->
node config source (Expression e) rest
| _ ->
raise (Not_allowed (string_of_node selected_node))


let rec node config source selected_node parents =
let open Extend_protocol.Reader in
and node config source selected_node parents =
let loc = Mbrowse.node_loc selected_node in
match selected_node with
| Record_field (`Expression _, _, _) ->
begin match parents with
| Expression { exp_desc = Texp_field _; _ } as parent :: rest ->
node config source parent rest
| Expression e :: rest ->
node config source (Expression e) rest
| _ ->
raise (Not_allowed (string_of_node selected_node))
end
destruct_record config source selected_node parents
| Expression expr ->
let ty = expr.Typedtree.exp_type in
let pexp = filter_expr_attr (Untypeast.untype_expression expr) in
log ~title:"node_expression" "%a"
Logger.fmt (fun fmt -> Printast.expression 0 fmt pexp);
let needs_parentheses, result =
if is_package (Types.Transient_expr.repr ty) then (
let mode = Ast_helper.Mod.unpack pexp in
false, Ast_helper.Exp.letmodule_no_opt "M" mode placeholder
) else (
let ps = gen_patterns expr.Typedtree.exp_env ty in
let cases =
List.map ps ~f:(fun patt ->
let pc_lhs = filter_pat_attr (Untypeast.untype_pattern patt) in
{ Parsetree. pc_lhs ; pc_guard = None ; pc_rhs = placeholder }
)
in
needs_parentheses parents, Ast_helper.Exp.match_ pexp cases
)
in
let str = Mreader.print_pretty
config source (Pretty_expression result) in
let str = if needs_parentheses then "(" ^ str ^ ")" else str in
loc, str
| Pattern patt ->
begin let last_case_loc, patterns = get_every_pattern parents in
(* Printf.eprintf "tot %d o%!"(List.length patterns); *)
List.iter patterns ~f:(fun p ->
let p = filter_pat_attr (Untypeast.untype_pattern p) in
log ~title:"EXISTING" "%t"
(fun () -> Mreader.print_pretty config source (Pretty_pattern p))
) ;
let pss = List.map patterns ~f:(fun x -> [ x ]) in
let m, e_typ = get_match parents in
let pred = Typecore.partial_pred
~lev:Btype.generic_level
m.Typedtree.exp_env
e_typ
in
begin match Parmatch.complete_partial ~pred pss with
| _ :: _ as patterns ->
let cases =
List.map patterns ~f:(fun pat ->
let _pat, constrs, labels = Conv.conv pat in
let unmangling_tables = constrs, labels in
(* Unmangling and prefixing *)
let pat =
qualify_constructors ~unmangling_tables
Printtyp.shorten_type_path pat
in

(* Untyping and casing *)
let ppat = filter_pat_attr (Untypeast.untype_pattern pat) in
Ast_helper.Exp.case ppat placeholder
)
in
let loc =
let open Location in
{ last_case_loc with loc_start = last_case_loc.loc_end }
in

(* Pretty printing *)
let str = Mreader.print_pretty config source (Pretty_case_list cases) in
loc, str
| [] ->
(* The match is already complete, we try to refine it *)
begin match Typedtree.classify_pattern patt with
| Computation -> raise (Not_allowed ("computation pattern"));
| Value ->
let _patt : Typedtree.value Typedtree.general_pattern = patt in
if not (destructible patt) then raise Nothing_to_do else
let ty = patt.Typedtree.pat_type in
begin match gen_patterns patt.Typedtree.pat_env ty with
| [] ->
(* gen_patterns might raise Not_allowed, but should never return [] *)
assert false
| [ more_precise ] ->
(* If only one pattern is generated, then we're only refining the
current pattern, not generating new branches. *)
let ppat = filter_pat_attr (Untypeast.untype_pattern more_precise) in
let str = Mreader.print_pretty
config source (Pretty_pattern ppat) in
patt.Typedtree.pat_loc, str
| sub_patterns ->
let rev_before, after, top_patt =
find_branch patterns patt
in
let new_branches =
List.map sub_patterns ~f:(fun by ->
subst_patt patt ~by top_patt
)
in
let patterns =
List.rev_append rev_before
(List.append new_branches after)
in
let unused = Parmatch.return_unused patterns in
let new_branches =
List.fold_left unused ~init:new_branches ~f:(fun branches u ->
match u with
| `Unused p -> List.remove ~phys:true p branches
| `Unused_subs (p, lst) ->
List.map branches ~f:(fun branch ->
if branch != p then branch else
List.fold_left lst ~init:branch ~f:rm_sub
)
)
in
(* List.iter ~f:(Format.eprintf "multi cp %a \n%!" (Printtyped.pattern 0)) new_branches ; *)
match new_branches with
| [] -> raise Useless_refine
| p :: ps ->
let p =
List.fold_left ps ~init:p ~f:(fun acc p ->
Tast_helper.Pat.pat_or top_patt.Typedtree.pat_env
top_patt.Typedtree.pat_type acc p
)
in
(* Format.eprintf "por %a \n%!" (Printtyped.pattern 0) p; *)
let ppat = filter_pat_attr (Untypeast.untype_pattern p) in
(* Format.eprintf "ppor %a \n%!" (Pprintast.pattern) ppat; *)
let str = Mreader.print_pretty
config source (Pretty_pattern ppat) in
(* Format.eprintf "STR: %s \n %!" str; *)
top_patt.Typedtree.pat_loc, str
end
end
end
end
destruct_expression loc config source parents expr
| Pattern patt -> destruct_pattern patt config source loc parents
| node ->
raise (Not_allowed (string_of_node node))
Loading

0 comments on commit 292713b

Please sign in to comment.