Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rigid unification option for hint solve/exact #680

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/ecCommands.ml
Original file line number Diff line number Diff line change
Expand Up @@ -264,17 +264,22 @@ module HiPrinting = struct
let ppe0 = EcPrinting.PPEnv.ofenv env in
EcPrinting.pp_by_theory ppe0 (EcPrinting.pp_axiom) fmt ax

(* ------------------------------------------------------------------ *)
let pr_hint_solve (fmt : Format.formatter) (env : EcEnv.env) =
let hint_solve = EcEnv.Auto.all env in
let hint_solve = List.map (fun p ->
(p, EcEnv.Ax.by_path p env)
let hint_solve = List.map (fun (p, mode) ->
let ax = EcEnv.Ax.by_path p env in
(p, (ax, mode))
) hint_solve in

let ppe = EcPrinting.PPEnv.ofenv env in

let pp_hint_solve ppe fmt pax =
Format.fprintf fmt "%a" (EcPrinting.pp_axiom ppe) pax
let pp_hint_solve ppe fmt = (fun (p, (ax, mode)) ->
let mode =
match mode with
| `Default -> ""
| `Rigid -> "(rigid)" in
Format.fprintf fmt "%a %s" (EcPrinting.pp_axiom ppe) (p, ax) mode
)
in

EcPrinting.pp_by_theory ppe pp_hint_solve fmt hint_solve
Expand Down
59 changes: 40 additions & 19 deletions src/ecEnv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ type preenv = {
env_tci : ((ty_params * ty) * tcinstance) list;
env_tc : TC.graph;
env_rwbase : Sp.t Mip.t;
env_atbase : (path list Mint.t) Msym.t;
env_atbase : atbase Msym.t;
env_redbase : mredinfo;
env_ntbase : ntbase Mop.t;
env_modlcs : Sid.t; (* declared modules *)
Expand Down Expand Up @@ -221,6 +221,10 @@ and env_notation = ty_params * EcDecl.notation

and ntbase = (path * env_notation) list

and atbase0 = path * [`Rigid | `Default]

and atbase = atbase0 list Mint.t

(* -------------------------------------------------------------------- *)
type env = preenv

Expand Down Expand Up @@ -1516,39 +1520,53 @@ end

(* -------------------------------------------------------------------- *)
module Auto = struct
type base0 = path * [`Rigid | `Default]

let dname : symbol = ""

let updatedb ~level ?base (ps : path list) (db : (path list Mint.t) Msym.t) =
let updatedb
~(level : int)
?(base : symbol option)
(ps : atbase0 list)
(db : atbase Msym.t)
=
let nbase = (odfl dname base) in
let ps' = Msym.find_def Mint.empty nbase db in
let ps' =
let base = Msym.find_def Mint.empty nbase db in
let levels =
let doit x = Some (ofold (fun x ps -> ps @ x) ps x) in
Mint.change doit level ps' in
Msym.add nbase ps' db

let add ?(import = import0) ~level ?base (ps : path list) lc (env : env) =
Mint.change doit level base in
Msym.add nbase levels db

let add
?(import = import0)
~(level : int)
?(base : symbol option)
(axioms : atbase0 list)
(locality : is_local)
(env : env)
=
let env =
if import.im_immediate then
{ env with
env_atbase = updatedb ?base ~level ps env.env_atbase; }
env_atbase = updatedb ?base ~level axioms env.env_atbase; }
else env
in
{ env with env_item = mkitem import
(Th_auto (level, base, ps, lc)) :: env.env_item; }
(Th_auto { level; base; axioms; locality; }) :: env.env_item; }

let add1 ?import ~level ?base (p : path) lc (env : env) =
let add1 ?import ~level ?base (p : atbase0) lc (env : env) =
add ?import ?base ~level [p] lc env

let get_core ?base (env : env) =
Msym.find_def Mint.empty (odfl dname base) env.env_atbase

let flatten_db (db : path list Mint.t) =
let flatten_db (db : atbase) =
Mint.fold_left (fun ps _ ps' -> ps @ ps') [] db

let get ?base (env : env) =
flatten_db (get_core ?base env)

let getall (bases : symbol list) (env : env) =
let getall (bases : symbol list) (env : env) : atbase0 list =
let dbs = List.map (fun base -> get_core ~base env) bases in
let dbs =
List.fold_left (fun db mi ->
Expand All @@ -1560,7 +1578,7 @@ module Auto = struct
let db = Msym.find_def Mint.empty base env.env_atbase in
Mint.bindings db

let all (env : env) : path list =
let all (env : env) : atbase0 list =
Msym.values env.env_atbase |> List.map flatten_db |> List.flatten
end

Expand Down Expand Up @@ -2951,8 +2969,8 @@ module Theory = struct
(* ------------------------------------------------------------------ *)
let bind_at_th =
let for1 _path db = function
| Th_auto (level, base, ps, _) ->
Some (Auto.updatedb ?base ~level ps db)
| Th_auto {level; base; axioms; _} ->
Some (Auto.updatedb ?base ~level axioms db)
| _ -> None

in bind_base_th for1
Expand Down Expand Up @@ -3125,9 +3143,12 @@ module Theory = struct
let ps = List.filter ((not) |- inclear |- oget |- EcPath.prefix) ps in
if List.is_empty ps then None else Some (Th_addrw (p, ps,lc))

| Th_auto (lvl, base, ps, lc) ->
let ps = List.filter ((not) |- inclear |- oget |- EcPath.prefix) ps in
if List.is_empty ps then None else Some (Th_auto (lvl, base, ps, lc))
| Th_auto ({ axioms } as auto_rl) ->
let axioms = List.filter (fun (p, _) ->
let p = oget (EcPath.prefix p) in
not (inclear p)
) axioms in
if List.is_empty axioms then None else Some (Th_auto {auto_rl with axioms})

| (Th_export (p, _)) as item ->
if Sp.mem p cleared then None else Some item
Expand Down
14 changes: 8 additions & 6 deletions src/ecEnv.mli
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,15 @@ end

(* -------------------------------------------------------------------- *)
module Auto : sig
type base0 = path * [`Rigid | `Default]

val dname : symbol
val add1 : ?import:import -> level:int -> ?base:symbol -> path -> is_local -> env -> env
val add : ?import:import -> level:int -> ?base:symbol -> path list -> is_local -> env -> env
val get : ?base:symbol -> env -> path list
val getall : symbol list -> env -> path list
val getx : symbol -> env -> (int * path list) list
val all : env -> path list
val add1 : ?import:import -> level:int -> ?base:symbol -> base0 -> is_local -> env -> env
val add : ?import:import -> level:int -> ?base:symbol -> base0 list -> is_local -> env -> env
val get : ?base:symbol -> env -> base0 list
val getall : symbol list -> env -> base0 list
val getx : symbol -> env -> (int * base0 list) list
val all : env -> base0 list
end

(* -------------------------------------------------------------------- *)
Expand Down
33 changes: 19 additions & 14 deletions src/ecLowGoal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ module Apply = struct

exception NoInstance of (bool * reason * PT.pt_env * (form * form))

let t_apply_bwd_r ?(mode = fmdelta) ?(canview = true) pt (tc : tcenv1) =
let t_apply_bwd_r ?(ri = EcReduction.full_compat) ?(mode = fmdelta) ?(canview = true) pt (tc : tcenv1) =
let ((hyps, concl), pterr) = (FApi.tc1_flat tc, PT.copy pt.ptev_env) in

let noinstance ?(dpe = false) reason =
Expand All @@ -736,7 +736,7 @@ module Apply = struct
match istop && PT.can_concretize pt.PT.ptev_env with
| true ->
let ax = PT.concretize_form pt.PT.ptev_env pt.PT.ptev_ax in
if EcReduction.is_conv ~ri:EcReduction.full_compat hyps ax concl
if EcReduction.is_conv ~ri hyps ax concl
then pt
else instantiate canview false pt

Expand All @@ -747,7 +747,7 @@ module Apply = struct
noinstance `IncompleteInference;
pt
with EcMatching.MatchFailure ->
match TTC.destruct_product hyps pt.PT.ptev_ax with
match TTC.destruct_product ~reduce:(mode.fm_conv) hyps pt.PT.ptev_ax with
| Some _ ->
(* FIXME: add internal marker *)
instantiate canview false (PT.apply_pterm_to_hole pt)
Expand Down Expand Up @@ -800,15 +800,15 @@ module Apply = struct

t_apply pt tc

let t_apply_bwd ?mode ?canview pt (tc : tcenv1) =
let t_apply_bwd ?(ri : EcReduction.reduction_info option) ?mode ?canview pt (tc : tcenv1) =
let hyps = FApi.tc1_hyps tc in
let pt, ax = LowApply.check `Elim pt (`Hyps (hyps, !!tc)) in
let ptenv = ptenv_of_penv hyps !!tc in
let pt = { ptev_env = ptenv; ptev_pt = pt; ptev_ax = ax; } in
t_apply_bwd_r ?mode ?canview pt tc
t_apply_bwd_r ?ri ?mode ?canview pt tc

let t_apply_bwd_hi ?(dpe = true) ?mode ?canview pt (tc : tcenv1) =
try t_apply_bwd ?mode ?canview pt tc
let t_apply_bwd_hi ?(ri : EcReduction.reduction_info option) ?(dpe = true) ?mode ?canview pt (tc : tcenv1) =
try t_apply_bwd ?ri ?mode ?canview pt tc
with (NoInstance (_, r, pt, f)) ->
tc_error_exn !!tc (NoInstance (dpe, r, pt, f))
end
Expand Down Expand Up @@ -2582,22 +2582,27 @@ let t_coq
let t_solve ?(canfail = true) ?(bases = [EcEnv.Auto.dname]) ?(mode = fmdelta) ?(depth = 1) (tc : tcenv1) =
let bases = EcEnv.Auto.getall bases (FApi.tc1_env tc) in

let t_apply1 p tc =

let t_apply1 ((p, rigid): Auto.base0) tc =
let ri, mode =
match rigid with
| `Rigid -> EcReduction.no_red, fmsearch
| `Default -> EcReduction.full_compat, mode in
let pt = PT.pt_of_uglobal !!tc (FApi.tc1_hyps tc) p in
try
Apply.t_apply_bwd_r ~mode ~canview:false pt tc
with Apply.NoInstance _ -> t_fail tc in
Apply.t_apply_bwd_r ~ri ~mode ~canview:false pt tc
with Apply.NoInstance _ ->
t_fail tc
in

let rec t_apply ctn p tc =
let rec t_apply ctn ip tc =
if ctn > depth
then t_fail tc
else (t_apply1 p @! t_trivial @! t_solve (ctn + 1) bases) tc
else (t_apply1 ip @! t_trivial @! t_solve (ctn + 1) bases) tc

and t_solve ctn bases tc =
match bases with
| [] -> t_abort tc
| p::bases -> (FApi.t_or (t_apply ctn p) (t_solve ctn bases)) tc in
| ip::bases -> (FApi.t_or (t_apply ctn ip) (t_solve ctn bases)) tc in

let t = t_solve 0 bases in
let t = if canfail then FApi.t_try t else t in
Expand Down
7 changes: 3 additions & 4 deletions src/ecLowGoal.mli
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,13 @@ module Apply : sig
exception NoInstance of (bool * reason * pt_env * (form * form))

val t_apply_bwd_r :
?mode:fmoptions -> ?canview:bool -> pt_ev -> FApi.backward
?ri:EcReduction.reduction_info -> ?mode:fmoptions -> ?canview:bool -> pt_ev -> FApi.backward

val t_apply_bwd :
?mode:fmoptions -> ?canview:bool -> proofterm -> FApi.backward
?ri:EcReduction.reduction_info -> ?mode:fmoptions -> ?canview:bool -> proofterm -> FApi.backward

val t_apply_bwd_hi:
?dpe:bool -> ?mode:fmoptions -> ?canview:bool
-> proofterm -> FApi.backward
?ri:EcReduction.reduction_info -> ?dpe:bool -> ?mode:fmoptions -> ?canview:bool -> proofterm -> FApi.backward
end

(* -------------------------------------------------------------------- *)
Expand Down
25 changes: 18 additions & 7 deletions src/ecParser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -3699,14 +3699,25 @@ addrw:
| local=is_local HINT REWRITE p=lqident COLON l=lqident*
{ (local, p, l) }

hint:
| local=is_local HINT EXACT base=lident? COLON l=qident*
{ { ht_local = local; ht_prio = 0;
ht_base = base ; ht_names = l; } }
hintoption:
| x=lident {
match unloc x with
| "rigid" -> `Rigid
| _ ->
parse_error x.pl_loc
(Some ("invalid option: " ^ (unloc x)))
}

| local=is_local HINT SOLVE i=word base=lident? COLON l=qident*
{ { ht_local = local; ht_prio = i;
ht_base = base ; ht_names = l; } }
hint:
| local=is_local
HINT opts=ioption(bracket(hintoption)+)
prio=ID(EXACT { 0 } | SOLVE i=word { i })
base=lident? COLON l=qident*
{ { ht_local = local;
ht_prio = prio;
ht_base = base ;
ht_names = l;
ht_options = odfl [] opts; } }

(* -------------------------------------------------------------------- *)
(* User reduction *)
Expand Down
12 changes: 8 additions & 4 deletions src/ecParsetree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1228,12 +1228,16 @@ type save = [ `Qed | `Admit | `Abort ]
(* -------------------------------------------------------------------- *)
type theory_clear = (pqsymbol option) list

(* -------------------------------------------------------------------- *)
type phintoption = [ `Rigid ]

(* -------------------------------------------------------------------- *)
type phint = {
ht_local : is_local;
ht_prio : int;
ht_base : psymbol option;
ht_names : pqsymbol list;
ht_local : is_local;
ht_prio : int;
ht_base : psymbol option;
ht_names : pqsymbol list;
ht_options : phintoption list;
}

(* -------------------------------------------------------------------- *)
Expand Down
35 changes: 25 additions & 10 deletions src/ecPrinting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,13 @@ let pp_rwname ppe fmt p =
let pp_axname ppe fmt p =
Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.ax_symb ppe p)

let pp_axhnt ppe fmt (p, b) =
let b =
match b with
| `Default -> ""
| `Rigid -> " (rigid)" in
Format.fprintf fmt "%a%s" (pp_axname ppe) p b

(* -------------------------------------------------------------------- *)
let pp_thname ppe fmt p =
EcSymbols.pp_qsymbol fmt (PPEnv.th_symb ppe p)
Expand Down Expand Up @@ -3020,23 +3027,31 @@ let pp_rwbase ppe fmt (p, rws) =
(pp_rwname ppe) p (pp_list ", " (pp_axname ppe)) (Sp.elements rws)

(* -------------------------------------------------------------------- *)
let pp_solvedb ppe fmt db =
let pp_solvedb ppe fmt (db: (int * (P.path * _) list) list) =
List.iter (fun (lvl, ps) ->
Format.fprintf fmt "[%3d] %a\n%!"
lvl (pp_list ", " (pp_axname ppe)) ps)
lvl
(pp_list ", " (pp_axhnt ppe))
ps)
db;

let lemmas = List.flatten (List.map snd db) in
let lemmas = List.pmap (fun p ->
let lemmas = List.pmap (fun (p, ir) ->
let ax = EcEnv.Ax.by_path_opt p ppe.PPEnv.ppe_env in
(omap (fun ax -> (p, ax)) ax))
lemmas
(omap (fun ax -> (ir, (p, ax))) ax)
) lemmas
in

if not (List.is_empty lemmas) then begin
Format.fprintf fmt "\n%!";
List.iter
(fun ax -> Format.fprintf fmt "%a\n\n%!" (pp_axiom ppe) ax)
(fun (ir, ax) ->
let ir =
match ir with
| `Default -> ""
| `Rigid -> " (rigid)" in

Format.fprintf fmt "%a%s\n\n%!" (pp_axiom ppe) ax ir)
lemmas
end

Expand Down Expand Up @@ -3526,11 +3541,11 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) =
(* FIXME: section we should add the lemma in the reduction *)
Format.fprintf fmt "hint simplify."

| EcTheory.Th_auto (lvl, base, p, lc) ->
| EcTheory.Th_auto { level; base; axioms; locality; } ->
Format.fprintf fmt "%ahint solve %d %s : %a."
pp_locality lc
lvl (odfl "" base)
(pp_list "@ " (pp_axname ppe)) p
pp_locality locality
level (odfl "" base)
(pp_list "@ " (pp_axhnt ppe)) axioms

(* -------------------------------------------------------------------- *)
let pp_stmt_with_nums (ppe : PPEnv.t) fmt stmt =
Expand Down
Loading