diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 238a8d162..21155b490 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -7,6 +7,8 @@ import Mathlib.Data.Finset.Basic import Mathlib.Data.Fintype.Basic import Mathlib.Tactic.Linarith import Mathlib.Tactic.Ring +import SSA.Projects.MLIRSyntax.AST -- TODO post-merge: bring into Core +import SSA.Projects.MLIRSyntax.EDSL -- TODO post-merge: bring into Core open Ctxt (Var VarSet Valuation) open Goedel (toType) @@ -1196,22 +1198,21 @@ theorem denote_rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) section SimpPeephole - /- repeatedly apply peephole on program. -/ section SimpPeepholeApplier /-- rewrite with `pr` to `target` program, at location `ix` and later, running at most `fuel` steps. -/ def rewritePeephole_go (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) - (ix : ℕ) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := + (ix : ℕ) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := match fuel with | 0 => target - | fuel' + 1 => + | fuel' + 1 => let target' := rewritePeepholeAt pr ix target rewritePeephole_go fuel' pr (ix + 1) target' /-- rewrite with `pr` to `target` program, running at most `fuel` steps. -/ def rewritePeephole (fuel : ℕ) - (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := + (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := rewritePeephole_go fuel pr 0 target /-- `rewritePeephole_go` preserve semantics -/ @@ -1219,9 +1220,9 @@ theorem denote_rewritePeephole_go (pr : PeepholeRewrite Op Γ t) (pos : ℕ) (target : Com Op Γ₂ t₂) : (rewritePeephole_go fuel pr pos target).denote = target.denote := by induction fuel generalizing pr pos target - case zero => + case zero => simp[rewritePeephole_go, denote_rewritePeepholeAt] - case succ fuel' hfuel => + case succ fuel' hfuel => simp[rewritePeephole_go, denote_rewritePeepholeAt, hfuel] /-- `rewritePeephole` preserves semantics. -/ @@ -1239,10 +1240,13 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : `(tactic| ( try simp (config := {decide := false}) only [ + Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk, + DialectMorphism.mapOp, DialectMorphism.mapTy, List.map, Ctxt.snoc, List.map, $ts,*] generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; @@ -1270,7 +1274,6 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : /-- `simp_peephole` with no extra user defined theorems. -/ macro "simp_peephole" "at" ll:ident : tactic => `(tactic| simp_peephole [] at $ll) - end SimpPeephole @@ -1580,7 +1583,7 @@ instance : Goedel ExTy where inductive ExOp : Type | add : ExOp | runK : ℕ → ExOp - deriving DecidableEq + deriving DecidableEq, Repr instance : OpSignature ExOp ExTy where signature @@ -1631,6 +1634,15 @@ def p1 : PeepholeRewrite ExOp [.nat] .nat:= done } +def p1_run : Com ExOp [.nat] .nat := + rewritePeepholeAt p1 0 ex1_lhs + +/- +RegionExamples.ExOp.runK 0[[%0]] +return %1 +-/ +-- #eval p1_run + /-- running `f(x) = x + x` 1 times does return `x + x`. -/ def ex2_lhs : Com ExOp [.nat] .nat := Com.lete (rgn (k := 1) ⟨0, by simp[Ctxt.snoc]⟩ ( diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index 787dee130..c3308cf18 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -12,7 +12,48 @@ namespace AliveAutoGenerated set_option pp.proofs false set_option pp.proofs.withType false +def eg0 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + "llvm.return" (%C0) : (_) -> () +}] + +def eg1 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + "llvm.return" (%C0) : (_) -> () +}] + +theorem eg0_eq_eg1 : eg0 w ⊑ eg1 w := by + unfold eg0 + unfold eg1 + simp[DialectMorphism.mapTy, + InstcombineTransformDialect.MOp.instantiateCom, + InstcombineTransformDialect.instantiateMTy, + ConcreteOrMVar.instantiate, Com.Refinement] + +def eg2 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + %v1 = "llvm.add" (%C0, %C0) : (_, _) -> (_) + "llvm.return" (%C0) : (_) -> () +}] + + +def eg3 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + %v1 = "llvm.mlir.constant" () { value = 2 : _ } :() -> (_) + %v2 = "llvm.mul" (%C0, %v1) : (_, _) -> (_) + "llvm.return" (%C0) : (_) -> () +}] +open MLIR AST in +theorem eg2_eq_eg3 : eg2 w ⊑ eg3 w := by + unfold eg2 + unfold eg3 + simp_alive_peephole + simp -- Name:AddSub:1043 -- precondition: true @@ -31,7 +72,7 @@ set_option pp.proofs.withType false -/ def alive_AddSub_1043_src (w : Nat) := -[mlir_icom ( w )| { +[alive_icom ( w )| { ^bb0(%C1 : _, %Z : _, %RHS : _): %v1 = "llvm.and" (%Z,%C1) : (_, _) -> (_) %v2 = "llvm.xor" (%v1,%C1) : (_, _) -> (_) @@ -42,7 +83,7 @@ def alive_AddSub_1043_src (w : Nat) := }] def alive_AddSub_1043_tgt (w : Nat) := -[mlir_icom ( w )| { +[alive_icom ( w )| { ^bb0(%C1 : _, %Z : _, %RHS : _): %v1 = "llvm.not" (%C1) : (_) -> (_) %v2 = "llvm.or" (%Z,%v1) : (_, _) -> (_) @@ -56,8 +97,13 @@ def alive_AddSub_1043_tgt (w : Nat) := theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddSub_1043_tgt w := by unfold alive_AddSub_1043_src alive_AddSub_1043_tgt simp_alive_peephole - apply bitvec_AddSub_1043 + sorry +/-# Early Exit +delete this to check the rest of the file. The `#exit` + is an early exit to allow our `lake` builds to complete in sensible amounts of time. +-/ +#exit /-# Early Exit delete this to check the rest of the file. The `#exit` diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index 49e28a115..982a7e224 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -40,6 +40,9 @@ instance : Repr (MTy φ) where | .bitvec (.concrete w), _ => "i" ++ repr w | .bitvec (.mvar ⟨i, _⟩), _ => f!"i$\{%{i}}" +instance : Lean.ToFormat (MTy φ) where + format := repr + instance : ToString (MTy φ) where toString t := repr t |>.pretty diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 0c4a4acf1..b2dc08875 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -1,28 +1,405 @@ import Qq +import SSA.Projects.InstCombine.Base +import Std.Data.BitVec +import SSA.Projects.MLIRSyntax.AST import SSA.Projects.MLIRSyntax.EDSL import SSA.Projects.InstCombine.LLVM.Transform +-- import SSA.Projects.InstCombine.LLVM.Transform.Dialects.InstCombine -open Qq Lean Meta Elab.Term +open Qq Lean Meta Elab.Term Elab Command +open InstCombine (MOp MTy Width) +open MLIR + +namespace InstcombineTransformDialect + +def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) + (e : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly + | .neg w' => if h : w = w' + then return ⟨ + .neg w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | .not w' => if h : w = w' + then return ⟨ + .not w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | .copy w' => if h : w = w' + then return ⟨ + .copy w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedUnaryOp + +def mkBinOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly + | .add w' => if h : w = w' + then return ⟨ + .add w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .and w' => if h : w = w' + then return ⟨ + .and w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .or w' => if h : w = w' + then return ⟨ + .or w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .xor w' => if h : w = w' + then return ⟨ + .xor w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .shl w' => if h : w = w' + then return ⟨ + .shl w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .lshr w' => if h : w = w' + then return ⟨ + .lshr w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .ashr w' => if h : w = w' + then return ⟨ + .ashr w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .urem w' => if h : w = w' + then return ⟨ + .urem w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .srem w' => if h : w = w' + then return ⟨ + .srem w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .mul w' => if h : w = w' + then return ⟨ + .mul w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .sub w' => if h : w = w' + then return ⟨ + .sub w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .sdiv w' => if h : w = w' + then return ⟨ + .sdiv w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .udiv w' => if h : w = w' + then return ⟨ + .udiv w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | op => throw <| .unsupportedBinaryOp s!"unsupported binary operation {op}" + +def mkIcmp {Γ : Ctxt _} {ty : (MTy φ)} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ (.bitvec 1) := + match ty with + | .bitvec w => + match op with + | .icmp p w' => if h : w = w' + then return ⟨ + .icmp p w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedOp "unsupported icmp operation" + +def mkSelect {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) + (c : Ctxt.Var Γ (.bitvec 1)) (e₁ e₂ : Ctxt.Var Γ ty) : + MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + match ty with + | .bitvec w => + match op with + | .select w' => if h : w = w' + then return ⟨ + .select w', + by simp [OpSignature.outTy, signature, h], + .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedOp "Unsupported select operation" + +def mkOpExpr {Γ : Ctxt (MTy φ)} (op : MOp φ) + (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : + MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ (OpSignature.outTy op) := + match op with + | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ + | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ + | .srem _ | .urem _ => + let (e₁, e₂) := arg.toTuple + mkBinOp op e₁ e₂ + | .icmp _ _ => + let (e₁, e₂) := arg.toTuple + mkIcmp op e₁ e₂ + | .not _ | .neg _ | .copy _ => + mkUnaryOp op arg.head + | .select _ => + let (c, e₁, e₂) := arg.toTuple + mkSelect op c e₁ e₂ + | .const .. => throw <| .unsupportedOp "Tried to build (MOp φ) expression from constant" + +def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM (MOp φ) (MTy φ) + | MLIR.AST.MLIRType.int MLIR.AST.Signedness.Signless w => do + return .bitvec w + | _ => throw .unsupportedType -- "Unsupported type" + +instance instTransformTy : MLIR.AST.TransformTy (MOp φ) (MTy φ) φ where + mkTy := mkTy + +def mkExpr (Γ : Ctxt (MTy φ)) (opStx : MLIR.AST.Op φ) : AST.ReaderM (MOp φ) (Σ ty, Expr (MOp φ) Γ ty) := do + match opStx.args with + | v₁Stx::v₂Stx::v₃Stx::[] => + let ⟨.bitvec w₁, v₁⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₁Stx + let ⟨.bitvec w₂, v₂⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₂Stx + let ⟨.bitvec w₃, v₃⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₃Stx + match opStx.name with + | "llvm.select" => + if hw1 : w₁ = 1 then + if hw23 : w₂ = w₃ then + let selectOp ← mkSelect (MOp.select w₂) (hw1 ▸ v₁) v₂ (hw23 ▸ v₃) + return ⟨.bitvec w₂, selectOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + else throw <| .unsupportedOp s!"expected select condtion to have width 1, found width '{w₁}'" + | op => throw <| .unsupportedOp s!"Unsuported ternary operation or invalid arguments '{op}'" + | v₁Stx::v₂Stx::[] => + let ⟨.bitvec w₁, v₁⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₁Stx + let ⟨.bitvec w₂, v₂⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₂Stx + -- let ty₁ := ty₁.instantiave + let op ← match opStx.name with + | "llvm.and" => pure (MOp.and w₁) + | "llvm.or" => pure (MOp.or w₁) + | "llvm.xor" => pure (MOp.xor w₁) + | "llvm.shl" => pure (MOp.shl w₁) + | "llvm.lshr" => pure (MOp.lshr w₁) + | "llvm.ashr" => pure (MOp.ashr w₁) + | "llvm.urem" => pure (MOp.urem w₁) + | "llvm.srem" => pure (MOp.srem w₁) + | "llvm.add" => pure (MOp.add w₁) + | "llvm.mul" => pure (MOp.mul w₁) + | "llvm.sub" => pure (MOp.sub w₁) + | "llvm.sdiv" => pure (MOp.sdiv w₁) + | "llvm.udiv" => pure (MOp.udiv w₁) + | "llvm.icmp.eq" => pure (MOp.icmp LLVM.IntPredicate.eq w₁) + | "llvm.icmp.ne" => pure (MOp.icmp LLVM.IntPredicate.ne w₁) + | "llvm.icmp.ugt" => pure (MOp.icmp LLVM.IntPredicate.ugt w₁) + | "llvm.icmp.uge" => pure (MOp.icmp LLVM.IntPredicate.uge w₁) + | "llvm.icmp.ult" => pure (MOp.icmp LLVM.IntPredicate.ult w₁) + | "llvm.icmp.ule" => pure (MOp.icmp LLVM.IntPredicate.ule w₁) + | "llvm.icmp.sgt" => pure (MOp.icmp LLVM.IntPredicate.sgt w₁) + | "llvm.icmp.sge" => pure (MOp.icmp LLVM.IntPredicate.sge w₁) + | "llvm.icmp.slt" => pure (MOp.icmp LLVM.IntPredicate.slt w₁) + | "llvm.icmp.sle" => pure (MOp.icmp LLVM.IntPredicate.sle w₁) + | opstr => throw <| .unsupportedOp s!"Unsuported binary operation or invalid arguments '{opstr}'" + match op with + | .icmp .. => + if hty : w₁ = w₂ then + let icmpOp ← mkIcmp op v₁ (hty ▸ v₂) + return ⟨.bitvec 1, icmpOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + | _ => + if hty : w₁ = w₂ then + let binOp ← mkBinOp op v₁ (hty ▸ v₂) + return ⟨.bitvec w₁, binOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + | vStx::[] => + let ⟨.bitvec w, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx + let op ← match opStx.name with + | "llvm.not" => pure <| MOp.not w + | "llvm.neg" => pure <| MOp.neg w + | "llvm.copy" => pure <| MOp.copy w + | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" + let op ← mkUnaryOp op v + return ⟨.bitvec w, op⟩ + | [] => + if opStx.name == "llvm.mlir.constant" + then do + let some att := opStx.attrs.getAttr "value" + | throw <| .generic "tried to resolve constant without 'value' attribute" + match att with + | .int val ty => + let opTy@(MTy.bitvec w) ← mkTy ty -- ty.mkTy + return ⟨opTy, ⟨ + MOp.const w val, + by simp [OpSignature.outTy, signature, *], + HVector.nil, + HVector.nil + ⟩⟩ + | _ => throw <| .generic "invalid constant attribute" + else + throw <| .generic s!"invalid (0-ary) expression {opStx.name}" + | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" + +instance : AST.TransformExpr (MOp φ) (MTy φ) φ where + mkExpr := mkExpr + +def mkReturn (Γ : Ctxt (MTy φ)) (opStx : MLIR.AST.Op φ) : MLIR.AST.ReaderM (MOp φ) (Σ ty, Com (MOp φ) Γ ty) := + if opStx.name == "llvm.return" + then match opStx.args with + | vStx::[] => do + let ⟨ty, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx + return ⟨ty, _root_.Com.ret v⟩ + | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" + else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" + +instance : AST.TransformReturn (MOp φ) (MTy φ) φ where + mkReturn := mkReturn + + +/-! + ## Instantiation + Finally, we show how to instantiate a family of programs to a concrete program +-/ + +def instantiateMTy (vals : Vector Nat φ) : (MTy φ) → InstCombine.Ty + | .bitvec w => .bitvec <| .concrete <| w.instantiate vals + +def instantiateMOp (vals : Vector Nat φ) : MOp φ → InstCombine.Op + | .and w => .and (w.instantiate vals) + | .or w => .or (w.instantiate vals) + | .not w => .not (w.instantiate vals) + | .xor w => .xor (w.instantiate vals) + | .shl w => .shl (w.instantiate vals) + | .lshr w => .lshr (w.instantiate vals) + | .ashr w => .ashr (w.instantiate vals) + | .urem w => .urem (w.instantiate vals) + | .srem w => .srem (w.instantiate vals) + | .select w => .select (w.instantiate vals) + | .add w => .add (w.instantiate vals) + | .mul w => .mul (w.instantiate vals) + | .sub w => .sub (w.instantiate vals) + | .neg w => .neg (w.instantiate vals) + | .copy w => .copy (w.instantiate vals) + | .sdiv w => .sdiv (w.instantiate vals) + | .udiv w => .udiv (w.instantiate vals) + | .icmp c w => .icmp c (w.instantiate vals) + | .const w val => .const (w.instantiate vals) val + +def instantiateCtxt (vals : Vector Nat φ) (Γ : Ctxt (MTy φ)) : Ctxt InstCombine.Ty := + Γ.map (instantiateMTy vals) + +def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where + mapOp := instantiateMOp vals + mapTy := instantiateMTy vals + preserves_signature op := by + simp only [instantiateMTy, instantiateMOp, ConcreteOrMVar.instantiate, (· <$> ·), signature, + InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, + true_and] + cases op <;> simp only [List.map, and_self, List.cons.injEq] + + +open InstCombine (Op Ty) in +def mkComInstantiate (reg : MLIR.AST.Region φ) : + MLIR.AST.ExceptM (MOp φ) (Vector Nat φ → Σ (Γ : Ctxt InstCombine.Ty) (ty : InstCombine.Ty), Com InstCombine.Op Γ ty) := do + let ⟨Γ, ty, com⟩ ← MLIR.AST.mkCom reg + return fun vals => + ⟨instantiateCtxt vals Γ, instantiateMTy vals ty, com.map (MOp.instantiateCom vals)⟩ + +end InstcombineTransformDialect + + +/- +https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/Cannot.20Find.20.60Real.2Eadd.60/near/402089561 +> I would recommend avoiding Qq for pattern matching. +> That part of the Qq implementation is spicy. + +Therefore, we choose to match on raw `Expr`. +-/ open MLIR.AST InstCombine in -elab "[mlir_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do +elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let ast_stx ← `([mlir_region| $reg]) let φ : Nat := mvars.getElems.size let ast ← elabTermEnsuringTypeQ ast_stx q(Region $φ) let mvalues ← `(⟨[$mvars,*], by rfl⟩) let mvalues : Q(Vector Nat $φ) ← elabTermEnsuringType mvalues q(Vector Nat $φ) - let com := q(mkComInstantiate $ast |>.map (· $mvalues)) + let com := q(InstcombineTransformDialect.mkComInstantiate (φ := $φ) $ast |>.map (· $mvalues)) synthesizeSyntheticMVarsNoPostponing - let com : Q(ExceptM (Σ (Γ' : Ctxt Ty) (ty : InstCombine.Ty), Com Γ' ty)) ← + let com : Q(MLIR.AST.ExceptM (MOp $φ) (Σ (Γ' : Ctxt (MTy $φ)) (ty : (MTy $φ)), Com (MOp $φ) Γ' ty)) ← withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do withTransparency (mode := TransparencyMode.all) <| return ←reduce com + let comExpr : Expr := com trace[Meta] com - match com with - | ~q(Except.ok $comOk) => return q(($comOk).snd.snd) - | ~q(Except.error $err) => do - let err ← unsafe evalExpr TransformError q(TransformError) err - throwError "Translation failed with error:\n\t{repr err}" - | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t{e}" - -macro "[mlir_icom| " reg:mlir_region "]" : term => `([mlir_icom ()| $reg]) + trace[Meta] comExpr + match comExpr.app3? ``Except.ok with + | .some (_εexpr, _αexpr, aexpr) => + match aexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + match sndexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + return sndexpr + | .none => throwError "Found `Except.ok (Sigma.mk _ WRONG)`, Expected (Except.ok (Sigma.mk _ (Sigma.mk _ _))" + | .none => throwError "Found `Except.ok WRONG`, Expected (Except.ok (Sigma.mk _ _))" + | .none => throwError "Expected `Except.ok`, found {comExpr}" +macro "[alive_icom| " reg:mlir_region "]" : term => `([alive_icom ()| $reg]) diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index 9bad81a21..b2f621aa3 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -1,7 +1,9 @@ -- should replace with Lean import once Pure is upstream import SSA.Projects.MLIRSyntax.AST -import SSA.Projects.InstCombine.Base +import SSA.Projects.InstCombine.LLVM.Transform.NameMapping +import SSA.Projects.InstCombine.LLVM.Transform.TransformError import SSA.Core.Framework +import SSA.Core.ErasedContext import Std.Data.BitVec @@ -9,114 +11,56 @@ universe u namespace MLIR.AST -open InstCombine (MOp MTy Width) open Std (BitVec) +open Ctxt -abbrev Context (φ) := List (MTy φ) - -abbrev Expr (Γ : Context φ) (ty : MTy φ) := _root_.Expr (MOp φ) Γ ty -abbrev Com (Γ : Context φ) (ty : MTy φ) := _root_.Com (MOp φ) Γ ty -abbrev Var (Γ : Context φ) (ty : MTy φ) := _root_.Ctxt.Var Γ ty - -abbrev Com.lete (body : Expr Γ ty₁) (rest : Com (ty₁::Γ) ty₂) : Com Γ ty₂ := - _root_.Com.lete body rest - -inductive TransformError - | nameAlreadyDeclared (var : String) - | undeclaredName (var : String) - | indexOutOfBounds (name : String) (index len : Nat) - | typeError {φ} (expected got : MTy φ) - | widthError {φ} (expected got : Width φ) - | unsupportedUnaryOp - | unsupportedBinaryOp (name : String) - | unsupportedOp (error : String) - | unsupportedType - | generic (error : String) - -namespace TransformError - -instance : Lean.ToFormat (MTy φ) where - format := repr - -instance : Repr TransformError where - reprPrec err _ := match err with - | nameAlreadyDeclared var => f!"Already declared {var}, shadowing is not allowed" - | undeclaredName name => f!"Undeclared name '{name}'" - | indexOutOfBounds name index len => - f!"Index of '{name}' out of bounds of the given context (index was {index}, but context has length {len})" - | typeError expected got => f!"Type mismatch: expected '{expected}', but 'name' has type '{got}'" - | widthError expected got => f!"Type mismatch: {expected} ≠ {got}" - | unsupportedUnaryOp => f!"Unsuported unary operation" - | unsupportedBinaryOp name => f!"Unsuported binary operation {name}" - | unsupportedOp err => f!"Unsuported operation '{err}'" - | unsupportedType => f!"Unsuported type" - | generic err => err +instance {Op Ty : Type} [OpSignature Op Ty] {t : Ty} {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where + coe e := e.changeVars Γ'.diff.toHom -end TransformError -/-- -Store the names of the raw SSA variables (as strings). -The order in the list should match the order in which they appear in the code. --/ -abbrev NameMapping := List String +section Monads -def NameMapping.lookup (nm : NameMapping) (name : String) : Option Nat := - nm.indexOf? name - -/-- - Add a new name to the mapping, assuming the name is not present in the list yet. - If the name is already present, return `none` +/-! + Even though we technically only need to know the `Ty` type, + our typeclass hierarchy is fully based on `Op`. + It is thus more convenient to incorporate `Op` in these types, so that there will be no ambiguity + errors. -/ -def NameMapping.add (nm : NameMapping) (name : String) : Option NameMapping := - match nm.lookup name with - | none => some <| name::nm - | some _ => none - -example : (ExceptT ε <| ReaderM ρ) = (ReaderT ρ <| Except ε) := rfl -abbrev ExceptM := Except TransformError -abbrev BuilderM := StateT NameMapping ExceptM -abbrev ReaderM := ReaderT NameMapping ExceptM +abbrev ExceptM Op [OpSignature Op Ty] := Except (TransformError Ty) +abbrev BuilderM Op [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) +abbrev ReaderM Op [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) -instance : MonadLift ReaderM BuilderM where +instance {Op : Type} [OpSignature Op Ty] : MonadLift (ReaderM Op) (BuilderM Op) where monadLift x := do (ReaderT.run x (←get) : ExceptM ..) -instance : MonadLift ExceptM ReaderM where +instance {Op : Type} [OpSignature Op Ty] : MonadLift (ExceptM Op) (ReaderM Op) where monadLift x := do return ←x -def BuilderM.runWithNewMapping (k : BuilderM α) : ExceptM α := +def BuilderM.runWithEmptyMapping {Op : Type} [OpSignature Op Ty] (k : BuilderM Op α) : ExceptM Op α := Prod.fst <$> StateT.run k [] -structure DerivedContext (Γ : Context φ) where - ctxt : Context φ - diff : Ctxt.Diff Γ ctxt - -namespace DerivedContext - -/-- Every context is trivially derived from itself -/ -abbrev ofContext (Γ : Context φ) : DerivedContext Γ := ⟨Γ, .zero _⟩ - -/-- `snoc` of a derived context applies `snoc` to the underlying context, and updates the diff -/ -def snoc {Γ : Context φ} : DerivedContext Γ → MTy φ → DerivedContext Γ - | ⟨ctxt, diff⟩, ty => ⟨ty::ctxt, diff.toSnoc⟩ - -instance {Γ : Context φ} : CoeHead (DerivedContext Γ) (Context φ) where - coe := fun ⟨Γ', _⟩ => Γ' +end Monads -instance {Γ : Context φ} : CoeDep (Context φ) Γ (DerivedContext Γ) where - coe := ⟨Γ, .zero _⟩ +/-! + These typeclasses provide a natural flow to how users should implement `TransformDialect`. + - First declare how to transform types with `TransformTy`. + - Second, using `TransformTy`, declare how to transform expressions with `TransformExpr`. + - Third, using both type and expression conversion, declare how to transform returns with `TransformReturn`. + - These three automatically give an instance of `TransformDialect`. +-/ +class TransformTy (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) [OpSignature Op Ty] where + mkTy : MLIRType φ → ExceptM Op Ty -instance {Γ : Context φ} {Γ' : DerivedContext Γ} : - CoeHead (DerivedContext (Γ' : Context φ)) (DerivedContext Γ) where - coe := fun ⟨Γ'', diff⟩ => ⟨Γ'', Γ'.diff + diff⟩ +class TransformExpr (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) [OpSignature Op Ty] [TransformTy Op Ty φ] where + mkExpr : (Γ : List Ty) → (opStx : AST.Op φ) → ReaderM Op (Σ ty, Expr Op Γ ty) -instance {Γ : Context φ} {Γ' : DerivedContext Γ} : Coe (Expr Γ t) (Expr Γ' t) where - coe e := e.changeVars Γ'.diff.toHom +class TransformReturn (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) + [OpSignature Op Ty] [TransformTy Op Ty φ] where + mkReturn : (Γ : List Ty) → (opStx : AST.Op φ) → ReaderM Op (Σ ty, Com Op Γ ty) -instance {Γ' : DerivedContext Γ} : Coe (Var Γ t) (Var Γ' t) where - coe v := Γ'.diff.toHom v - -end DerivedContext +/- instance of the transform dialect, plus data needed about `Op` and `Ty`. -/ +variable {Op Ty φ} [OpSignature Op Ty] [DecidableEq Ty] [DecidableEq Op] /-- Add a new variable to the context, and record it's (absolute) index in the name mapping @@ -124,12 +68,12 @@ end DerivedContext Throws an error if the variable name already exists in the mapping, essentially disallowing shadowing -/ -def addValToMapping (Γ : Context φ) (name : String) (ty : MTy φ) : - BuilderM (Σ (Γ' : DerivedContext Γ), Var Γ' ty) := do +def addValToMapping (Γ : Ctxt Ty) (name : String) (ty : Ty) : + BuilderM Op (Σ (Γ' : DerivedCtxt Γ), Ctxt.Var Γ'.ctxt ty) := do let some nm := (←get).add name | throw <| .nameAlreadyDeclared name set nm - return ⟨DerivedContext.ofContext Γ |>.snoc ty, Ctxt.Var.last ..⟩ + return ⟨DerivedCtxt.ofCtxt Γ |>.snoc ty, Ctxt.Var.last ..⟩ /-- Look up a name from the name mapping, and return the corresponding variable in the given context. @@ -137,8 +81,8 @@ def addValToMapping (Γ : Context φ) (name : String) (ty : MTy φ) : Throws an error if the name is not present in the mapping (this indicates the name may be free), or if the type of the variable in the context is different from `expectedType` -/ -def getValFromContext (Γ : Context φ) (name : String) (expectedType : MTy φ) : - ReaderM (Ctxt.Var Γ expectedType) := do +def getValFromCtxt (Γ : Ctxt Ty) (name : String) (expectedType : Ty) : + ReaderM Op (Ctxt.Var Γ expectedType) := do let index := (←read).lookup name let some index := index | throw <| .undeclaredName name let n := Γ.length @@ -153,348 +97,65 @@ def getValFromContext (Γ : Context φ) (name : String) (expectedType : MTy φ) else throw <| .typeError expectedType t -def BuilderM.isOk {α : Type} (x : BuilderM α) : Bool := +def BuilderM.isOk {α : Type} (x : BuilderM Op α) : Bool := match x.run [] with | Except.ok _ => true | Except.error _ => false -def BuilderM.isErr {α : Type} (x : BuilderM α) : Bool := +def BuilderM.isErr {α : Type} (x : BuilderM Op α) : Bool := match x.run [] with | Except.ok _ => true | Except.error _ => false -def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (e : Var Γ ty) : ExceptM <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .neg w' => if h : w = w' - then return ⟨ - .neg w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .not w' => if h : w = w' - then return ⟨ - .not w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .copy w' => if h : w = w' - then return ⟨ - .copy w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedUnaryOp - -def mkBinOp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .add w' => if h : w = w' - then return ⟨ - .add w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .and w' => if h : w = w' - then return ⟨ - .and w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .or w' => if h : w = w' - then return ⟨ - .or w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .xor w' => if h : w = w' - then return ⟨ - .xor w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .shl w' => if h : w = w' - then return ⟨ - .shl w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .lshr w' => if h : w = w' - then return ⟨ - .lshr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .ashr w' => if h : w = w' - then return ⟨ - .ashr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .urem w' => if h : w = w' - then return ⟨ - .urem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .srem w' => if h : w = w' - then return ⟨ - .srem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .mul w' => if h : w = w' - then return ⟨ - .mul w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sub w' => if h : w = w' - then return ⟨ - .sub w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sdiv w' => if h : w = w' - then return ⟨ - .sdiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .udiv w' => if h : w = w' - then return ⟨ - .udiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | op => throw <| .unsupportedBinaryOp s!"unsupported binary operation {op}" - -def mkIcmp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM <| Expr Γ (.bitvec 1) := - match ty with - | .bitvec w => - match op with - | .icmp p w' => if h : w = w' - then return ⟨ - .icmp p w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedOp "unsupported icmp operation" - -def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : - ExceptM <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - | .select w' => if h : w = w' - then return ⟨ - .select w', - by simp [OpSignature.outTy, signature, h], - .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedOp "Unsupported select operation" - -def mkOpExpr {Γ : Context φ} (op : MOp φ) - (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : - ExceptM <| Expr Γ (OpSignature.outTy op) := - match op with - | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ - | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ - | .srem _ | .urem _ => - let (e₁, e₂) := arg.toTuple - mkBinOp op e₁ e₂ - | .icmp _ _ => - let (e₁, e₂) := arg.toTuple - mkIcmp op e₁ e₂ - | .not _ | .neg _ | .copy _ => - mkUnaryOp op arg.head - | .select _ => - let (c, e₁, e₂) := arg.toTuple - mkSelect op c e₁ e₂ - | .const .. => throw <| .unsupportedOp "Tried to build Op expression from constant" - -def MLIRType.mkTy : MLIRType φ → ExceptM (MTy φ) - | MLIRType.int Signedness.Signless w => do - return .bitvec w - | _ => throw .unsupportedType -- "Unsupported type" - -def TypedSSAVal.mkTy : TypedSSAVal φ → ExceptM (MTy φ) - | (.SSAVal _, ty) => ty.mkTy - -def mkVal (ty : InstCombine.Ty) : Int → BitVec ty.width - | val => BitVec.ofInt ty.width val +def TypedSSAVal.mkTy [TransformTy Op Ty φ] : TypedSSAVal φ → ExceptM Op Ty + | (.SSAVal _, ty) => TransformTy.mkTy ty /-- Translate a `TypedSSAVal` (a name with an expected type), to a variable in the context. This expects the name to have already been declared before -/ -def TypedSSAVal.mkVal (Γ : Context φ) : TypedSSAVal φ → - ReaderM (Σ (ty : MTy φ), Var Γ ty) +def TypedSSAVal.mkVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → + ReaderM Op (Σ (ty : Ty), Ctxt.Var Γ ty) | (.SSAVal valStx, tyStx) => do - let ty ← tyStx.mkTy - let var ← getValFromContext Γ valStx ty + let ty ← instTransformTy.mkTy tyStx + let var ← getValFromCtxt Γ valStx ty + return ⟨ty, var⟩ + +/-- A variant of `TypedSSAVal.mkVal` that takes the function `mkTy` as an argument + instead of using the typeclass `TransformDialect`. + This is useful when trying to implement an instance of `TransformDialect` itself, + to cut infinite regress. -/ +def TypedSSAVal.mkVal' [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → + ReaderM Op (Σ (ty : Ty), Ctxt.Var Γ ty) +| (.SSAVal valStx, tyStx) => do + let ty ← instTransformTy.mkTy tyStx + let var ← getValFromCtxt Γ valStx ty return ⟨ty, var⟩ /-- Declare a new variable, by adding the passed name to the name mapping stored in the monad state -/ -def TypedSSAVal.newVal (Γ : Context φ) : TypedSSAVal φ → - BuilderM (Σ (Γ' : DerivedContext Γ) (ty : MTy φ), Var Γ' ty) +def TypedSSAVal.newVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → + BuilderM Op (Σ (Γ' : DerivedCtxt Γ) (ty : Ty), Ctxt.Var Γ'.ctxt ty) | (.SSAVal valStx, tyStx) => do - let ty ← tyStx.mkTy + let ty ← instTransformTy.mkTy tyStx let ⟨Γ, var⟩ ← addValToMapping Γ valStx ty return ⟨Γ, ty, var⟩ -def mkExpr (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Expr Γ ty) := do - match opStx.args with - | v₁Stx::v₂Stx::v₃Stx::[] => - let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx - let ⟨.bitvec w₂, v₂⟩ ← TypedSSAVal.mkVal Γ v₂Stx - let ⟨.bitvec w₃, v₃⟩ ← TypedSSAVal.mkVal Γ v₃Stx - match opStx.name with - | "llvm.select" => - if hw1 : w₁ = 1 then - if hw23 : w₂ = w₃ then - let selectOp ← mkSelect (MOp.select w₂) (hw1 ▸ v₁) v₂ (hw23 ▸ v₃) - return ⟨.bitvec w₂, selectOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - else throw <| .unsupportedOp s!"expected select condtion to have width 1, found width '{w₁}'" - | op => throw <| .unsupportedOp s!"Unsuported ternary operation or invalid arguments '{op}'" - | v₁Stx::v₂Stx::[] => - let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx - let ⟨.bitvec w₂, v₂⟩ ← TypedSSAVal.mkVal Γ v₂Stx - -- let ty₁ := ty₁.instantiave - let op ← match opStx.name with - | "llvm.and" => pure (MOp.and w₁) - | "llvm.or" => pure (MOp.or w₁) - | "llvm.xor" => pure (MOp.xor w₁) - | "llvm.shl" => pure (MOp.shl w₁) - | "llvm.lshr" => pure (MOp.lshr w₁) - | "llvm.ashr" => pure (MOp.ashr w₁) - | "llvm.urem" => pure (MOp.urem w₁) - | "llvm.srem" => pure (MOp.srem w₁) - | "llvm.add" => pure (MOp.add w₁) - | "llvm.mul" => pure (MOp.mul w₁) - | "llvm.sub" => pure (MOp.sub w₁) - | "llvm.sdiv" => pure (MOp.sdiv w₁) - | "llvm.udiv" => pure (MOp.udiv w₁) - | "llvm.icmp.eq" => pure (MOp.icmp LLVM.IntPredicate.eq w₁) - | "llvm.icmp.ne" => pure (MOp.icmp LLVM.IntPredicate.ne w₁) - | "llvm.icmp.ugt" => pure (MOp.icmp LLVM.IntPredicate.ugt w₁) - | "llvm.icmp.uge" => pure (MOp.icmp LLVM.IntPredicate.uge w₁) - | "llvm.icmp.ult" => pure (MOp.icmp LLVM.IntPredicate.ult w₁) - | "llvm.icmp.ule" => pure (MOp.icmp LLVM.IntPredicate.ule w₁) - | "llvm.icmp.sgt" => pure (MOp.icmp LLVM.IntPredicate.sgt w₁) - | "llvm.icmp.sge" => pure (MOp.icmp LLVM.IntPredicate.sge w₁) - | "llvm.icmp.slt" => pure (MOp.icmp LLVM.IntPredicate.slt w₁) - | "llvm.icmp.sle" => pure (MOp.icmp LLVM.IntPredicate.sle w₁) - | opstr => throw <| .unsupportedOp s!"Unsuported binary operation or invalid arguments '{opstr}'" - match op with - | .icmp .. => - if hty : w₁ = w₂ then - let icmpOp ← (mkIcmp op v₁ (hty ▸ v₂) : ExceptM _) - return ⟨.bitvec 1, icmpOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - | _ => - if hty : w₁ = w₂ then - let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM _) - return ⟨.bitvec w₁, binOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - | vStx::[] => - let ⟨.bitvec w, v⟩ ← TypedSSAVal.mkVal Γ vStx - let op ← match opStx.name with - | "llvm.not" => pure <| MOp.not w - | "llvm.neg" => pure <| MOp.neg w - | "llvm.copy" => pure <| MOp.copy w - | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" - let op ← mkUnaryOp op v - return ⟨.bitvec w, op⟩ - | [] => - if opStx.name == "llvm.mlir.constant" - then do - let some att := opStx.attrs.getAttr "value" - | throw <| .generic "tried to resolve constant without 'value' attribute" - match att with - | .int val ty => - let opTy@(MTy.bitvec w) ← ty.mkTy - return ⟨opTy, ⟨ - MOp.const w val, - by simp [OpSignature.outTy, signature, *], - HVector.nil, - HVector.nil - ⟩⟩ - | _ => throw <| .generic "invalid constant attribute" - else - throw <| .generic s!"invalid (0-ary) expression {opStx.name}" - | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" - -def mkReturn (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Com Γ ty) := - if opStx.name == "llvm.return" - then match opStx.args with - | vStx::[] => do - let ⟨ty, v⟩ ← vStx.mkVal Γ - return ⟨ty, _root_.Com.ret v⟩ - | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" - else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" - /-- Given a list of `TypedSSAVal`s, treat each as a binder and declare a new variable with the given name and type -/ -private def declareBindings (Γ : Context φ) (vals : List (TypedSSAVal φ)) : - BuilderM (DerivedContext Γ) := do +private def declareBindings [TransformTy Op Ty φ] (Γ : Ctxt Ty) (vals : List (TypedSSAVal φ)) : + BuilderM Op (DerivedCtxt Γ) := do vals.foldlM (fun Γ' ssaVal => do - let ⟨Γ'', _⟩ ← TypedSSAVal.newVal Γ' ssaVal + let ⟨Γ'', _⟩ ← TypedSSAVal.newVal Γ'.ctxt ssaVal return Γ'' - ) (.ofContext Γ) - -private def mkComHelper (Γ : Context φ) : - List (Op φ) → BuilderM (Σ (ty : _), Com Γ ty) - | [retStx] => mkReturn Γ retStx + ) (.ofCtxt Γ) + +private def mkComHelper + [TransformTy Op Ty φ] [instTransformExpr : TransformExpr Op Ty φ] [instTransformReturn : TransformReturn Op Ty φ] + (Γ : Ctxt Ty) : + List (MLIR.AST.Op φ) → BuilderM Op (Σ (ty : _), Com Op Γ ty) + | [retStx] => do + instTransformReturn.mkReturn Γ retStx | lete::rest => do - let ⟨ty₁, expr⟩ ← mkExpr Γ lete + let ⟨ty₁, expr⟩ ← (instTransformExpr.mkExpr Γ lete) if h : lete.res.length != 1 then throw <| .generic s!"Each let-binding must have exactly one name on the left-hand side. Operations with multiple, or no, results are not yet supported.\n\tExpected a list of length one, found `{repr lete}`" else @@ -503,61 +164,13 @@ private def mkComHelper (Γ : Context φ) : return ⟨ty₂, Com.lete expr body⟩ | [] => throw <| .generic "Ill-formed (empty) block" -def mkCom (reg : Region φ) : ExceptM (Σ (Γ : Context φ) (ty : MTy φ), Com Γ ty) := +def mkCom [TransformTy Op Ty φ] [TransformExpr Op Ty φ] [TransformReturn Op Ty φ] + (reg : MLIR.AST.Region φ) : ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := match reg.ops with | [] => throw <| .generic "Ill-formed region (empty)" - | coms => BuilderM.runWithNewMapping <| do + | coms => BuilderM.runWithEmptyMapping <| do let Γ ← declareBindings ∅ reg.args let com ← mkComHelper Γ coms return ⟨Γ, com⟩ -/-! - ## Instantiation - Finally, we show how to instantiate a family of programs to a concrete program --/ - -def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : MTy φ → InstCombine.Ty - | .bitvec w => .bitvec <| .concrete <| w.instantiate vals - -def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op - | .and w => .and (w.instantiate vals) - | .or w => .or (w.instantiate vals) - | .not w => .not (w.instantiate vals) - | .xor w => .xor (w.instantiate vals) - | .shl w => .shl (w.instantiate vals) - | .lshr w => .lshr (w.instantiate vals) - | .ashr w => .ashr (w.instantiate vals) - | .urem w => .urem (w.instantiate vals) - | .srem w => .srem (w.instantiate vals) - | .select w => .select (w.instantiate vals) - | .add w => .add (w.instantiate vals) - | .mul w => .mul (w.instantiate vals) - | .sub w => .sub (w.instantiate vals) - | .neg w => .neg (w.instantiate vals) - | .copy w => .copy (w.instantiate vals) - | .sdiv w => .sdiv (w.instantiate vals) - | .udiv w => .udiv (w.instantiate vals) - | .icmp c w => .icmp c (w.instantiate vals) - | .const w val => .const (w.instantiate vals) val - -def Context.instantiate (vals : Vector Nat φ) (Γ : Context φ) : Ctxt InstCombine.Ty := - Γ.map (MTy.instantiate vals) - -def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where - mapOp := MOp.instantiate vals - mapTy := MTy.instantiate vals - preserves_signature op := by - simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, - InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, - true_and] - cases op <;> simp only [List.map, and_self, List.cons.injEq] - - -open InstCombine (Op Ty) in -def mkComInstantiate (reg : Region φ) : - ExceptM (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), _root_.Com InstCombine.Op Γ ty) := do - let ⟨Γ, ty, com⟩ ← mkCom reg - return fun vals => - ⟨Γ.instantiate vals, ty.instantiate vals, com.map (MOp.instantiateCom vals)⟩ - end MLIR.AST diff --git a/SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean b/SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean new file mode 100644 index 000000000..d9140e77e --- /dev/null +++ b/SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean @@ -0,0 +1,27 @@ +import Std.Data.List.Basic + +namespace MLIR.AST + +/-- +Store the names of the raw SSA variables (as strings). +The order in the list should match the order in which they appear in the code. +-/ +abbrev NameMapping := List String + +namespace NameMapping + +def lookup (nm : NameMapping) (name : String) : Option Nat := + nm.indexOf? name + +/-- + Add a new name to the mapping, assuming the name is not present in the list yet. + If the name is already present, return `none` +-/ +def add (nm : NameMapping) (name : String) : Option NameMapping := + match nm.lookup name with + | none => some <| name::nm + | some _ => none + +end NameMapping + +end MLIR.AST \ No newline at end of file diff --git a/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean b/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean new file mode 100644 index 000000000..eadaed68b --- /dev/null +++ b/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean @@ -0,0 +1,35 @@ +import SSA.Projects.MLIRSyntax.AST + +namespace MLIR.AST + +inductive TransformError (Ty : Type) + | nameAlreadyDeclared (var : String) + | undeclaredName (var : String) + | indexOutOfBounds (name : String) (index len : Nat) + | typeError (expected got : Ty) + | widthError {φ} (expected got : Width φ) + | unsupportedUnaryOp + | unsupportedBinaryOp (error : String) + | unsupportedOp (error : String) + | unsupportedType + | generic (error : String) + +namespace TransformError + +instance [Repr Ty] : Repr (TransformError Ty) where + reprPrec err _ := match err with + | nameAlreadyDeclared var => f!"Already declared {var}, shadowing is not allowed" + | undeclaredName name => f!"Undeclared name '{name}'" + | indexOutOfBounds name index len => + f!"Index of '{name}' out of bounds of the given context (index was {index}, but context has length {len})" + | typeError expected got => f!"Type mismatch: expected '{repr expected}', but 'name' has type '{repr got}'" + | widthError expected got => f!"Type mismatch: {expected} ≠ {got}" + | unsupportedUnaryOp => f!"Unsuported unary operation" + | unsupportedBinaryOp err => f!"Unsuported binary operation 's!{err}'" + | unsupportedOp err => f!"Unsuported operation 's!{err}'" + | unsupportedType => f!"Unsuported type" + | generic err => err + +end TransformError + +end MLIR.AST diff --git a/SSA/Projects/InstCombine/Refinement.lean b/SSA/Projects/InstCombine/Refinement.lean index 91f5cd2fa..7f213c1fc 100644 --- a/SSA/Projects/InstCombine/Refinement.lean +++ b/SSA/Projects/InstCombine/Refinement.lean @@ -1,9 +1,10 @@ +import SSA.Projects.InstCombine.Base import SSA.Projects.InstCombine.LLVM.EDSL import SSA.Projects.InstCombine.AliveStatements open MLIR AST -abbrev Com.Refinement (src tgt : Com (φ:=0) Γ t) (h : Goedel.toType t = Option α := by rfl) : Prop := +abbrev Com.Refinement (src tgt : Com (InstCombine.MOp 0) Γ t) (h : Goedel.toType t = Option α := by rfl) : Prop := ∀ Γv, (h ▸ src.denote Γv) ⊑ (h ▸ tgt.denote Γv) infixr:90 " ⊑ " => Com.Refinement diff --git a/SSA/Projects/InstCombine/Tactic.lean b/SSA/Projects/InstCombine/Tactic.lean index 53093456b..9cb26903c 100644 --- a/SSA/Projects/InstCombine/Tactic.lean +++ b/SSA/Projects/InstCombine/Tactic.lean @@ -2,9 +2,11 @@ import SSA.Projects.InstCombine.LLVM.EDSL import SSA.Projects.InstCombine.AliveStatements import SSA.Projects.InstCombine.Refinement import Mathlib.Tactic +import SSA.Core.ErasedContext open MLIR AST open Std (BitVec) +open Ctxt theorem bitvec_minus_one : BitVec.ofInt w (Int.negSucc 0) = (-1 : BitVec w) := by simp[BitVec.ofInt, BitVec.ofNat,Neg.neg, @@ -19,6 +21,8 @@ theorem bitvec_minus_one : BitVec.ofInt w (Int.negSucc 0) = (-1 : BitVec w) := b simp rw[ONE] + +open MLIR AST in /-- - We first simplify `Com.refinement` to see the context `Γv`. - We `simp_peephole Γv` to simplify context accesses by variables. @@ -32,19 +36,30 @@ macro "simp_alive_peephole" : tactic => ( dsimp only [Com.Refinement] intros Γv + simp [InstcombineTransformDialect.MOp.instantiateCom, InstcombineTransformDialect.instantiateMOp, + ConcreteOrMVar.instantiate, Vector.get, List.nthLe, List.length_singleton, Fin.coe_fin_one, Fin.zero_eta, + List.get_cons_zero, Function.comp_apply, InstcombineTransformDialect.instantiateMTy, Ctxt.empty_eq, Ctxt.DerivedCtxt.snoc, + Ctxt.DerivedCtxt.ofCtxt, List.map_eq_map, List.map, DialectMorphism.mapTy] at Γv simp_peephole at Γv /- note that we need the `HVector.toPair`, `HVector.toSingle`, `HVector.toTriple` lemmas since it's used in `InstCombine.Op.denote` We need `HVector.toTuple` since it's used in `MLIR.AST.mkOpExpr`. -/ - try simp (config := {decide := false}) only [OpDenote.denote, + simp (config := {decide := false}) only [OpDenote.denote, InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, - bind, Option.bind, pure, DerivedContext.ofContext, DerivedContext.snoc, - Ctxt.snoc, MOp.instantiateCom, InstCombine.MTy.instantiate, + bind, Option.bind, pure, Ctxt.DerivedCtxt.ofCtxt, Ctxt.DerivedCtxt.snoc, + Ctxt.snoc, ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, LLVM.const?, LLVM.icmp?, - HVector.toTuple, List.nthLe, bitvec_minus_one] + HVector.toTuple, List.nthLe, bitvec_minus_one, + DialectMorphism.mapTy, + InstcombineTransformDialect.instantiateMTy, + InstcombineTransformDialect.instantiateMOp, + InstcombineTransformDialect.MOp.instantiateCom, + InstcombineTransformDialect.instantiateCtxt, + ConcreteOrMVar.instantiate, Com.Refinement, + DialectMorphism.mapTy] try intros v0 try intros v1 try intros v2 diff --git a/SSA/Projects/MLIRSyntax/AST.lean b/SSA/Projects/MLIRSyntax/AST.lean index de4038ab8..31b5894bf 100644 --- a/SSA/Projects/MLIRSyntax/AST.lean +++ b/SSA/Projects/MLIRSyntax/AST.lean @@ -1,5 +1,4 @@ import SSA.Core.Util.ConcreteOrMVar -import SSA.Core.Framework open Lean PrettyPrinter namespace MLIR.AST diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index fd0d9f724..0492ec705 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -1,6 +1,11 @@ +import Qq +import Lean import Mathlib.Logic.Function.Iterate import SSA.Core.Framework import SSA.Core.Util +import SSA.Projects.InstCombine.LLVM.Transform +import SSA.Projects.MLIRSyntax.AST +import SSA.Projects.MLIRSyntax.EDSL import Mathlib.Data.StdBitVec.Lemmas set_option pp.proofs false @@ -57,24 +62,114 @@ def add {Γ : Ctxt _} (e₁ e₂ : Var Γ .int) : Expr Op Γ .int := attribute [local simp] Ctxt.snoc +namespace MLIR2Simple + +def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM Op Ty + | MLIR.AST.MLIRType.int MLIR.AST.Signedness.Signless w => do + return .int + | _ => throw .unsupportedType + +instance instTransformTy : MLIR.AST.TransformTy Op Ty 0 where + mkTy := mkTy + +def mkExpr (Γ : Ctxt Ty) (opStx : MLIR.AST.Op 0) : MLIR.AST.ReaderM Op (Σ ty, Expr Op Γ ty) := do + match opStx.name with + | "const" => + match opStx.attrs.find_int "value" with + | .some (v, _ty) => return ⟨.int, cst v⟩ + | .none => throw <| .generic s!"expected 'const' to have int attr 'value', found: {repr opStx}" + | "add" => + match opStx.args with + | v₁Stx::v₂Stx::[] => + let ⟨.int, v₁⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₁Stx + let ⟨.int, v₂⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₂Stx + return ⟨.int, add v₁ v₂⟩ + | _ => throw <| .generic s!"expected two operands for `add`, found #'{opStx.args.length}' in '{repr opStx.args}'" + | _ => throw <| .unsupportedOp s!"unsupported operation {repr opStx}" + +instance : MLIR.AST.TransformExpr Op Ty 0 where + mkExpr := mkExpr + +def mkReturn (Γ : Ctxt Ty) (opStx : MLIR.AST.Op 0) : MLIR.AST.ReaderM Op (Σ ty, Com Op Γ ty) := + if opStx.name == "return" + then match opStx.args with + | vStx::[] => do + let ⟨ty, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx + return ⟨ty, Com.ret v⟩ + | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" + else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" + +instance : MLIR.AST.TransformReturn Op Ty 0 where + mkReturn := mkReturn + +def mlir2simple (reg : MLIR.AST.Region 0) : + MLIR.AST.ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := MLIR.AST.mkCom reg + +open Qq MLIR AST Lean Elab Term Meta in +elab "[simple_com| " reg:mlir_region "]" : term => do + let ast_stx ← `([mlir_region| $reg]) + let ast ← elabTermEnsuringTypeQ ast_stx q(Region 0) + let mvalues ← `(⟨[], by rfl⟩) + -- let mvalues : Q(Vector Nat 0) ← elabTermEnsuringType mvalues q(Vector Nat 0) + let com := q(ToyNoRegion.MLIR2Simple.mlir2simple $ast) + synthesizeSyntheticMVarsNoPostponing + let com : Q(MLIR.AST.ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← + withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do + withTransparency (mode := TransparencyMode.all) <| + return ←reduce com + let comExpr : Expr := com + trace[Meta] com + trace[Meta] comExpr + match comExpr.app3? ``Except.ok with + | .some (_εexpr, _αexpr, aexpr) => + match aexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + match sndexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + return sndexpr + | .none => throwError "Found `Except.ok (Sigma.mk _ WRONG)`, Expected (Except.ok (Sigma.mk _ (Sigma.mk _ _))" + | .none => throwError "Found `Except.ok WRONG`, Expected (Except.ok (Sigma.mk _ _))" + | .none => throwError "Expected `Except.ok`, found {comExpr}" + +end MLIR2Simple + +open MLIR AST MLIR2Simple in +def eg₀ : Com Op (Ctxt.ofList []) .int := + [simple_com| { + %c2= "const"() {value = 2} : () -> i32 + %c4 = "const"() {value = 4} : () -> i32 + %c6 = "add"(%c2, %c4) : (i32, i32) -> i32 + %c8 = "add"(%c6, %c2) : (i32, i32) -> i32 + "return"(%c8) : (i32) -> () + }] + +def eg₀val := Com.denote eg₀ Ctxt.Valuation.nil +#eval eg₀val -- 0x00000008#32 + +open MLIR AST MLIR2Simple in /-- x + 0 -/ def lhs : Com Op (Ctxt.ofList [.int]) .int := - -- %c0 = 0 - Com.lete (cst 0) <| - -- %out = %x + %c0 - Com.lete (add ⟨1, by simp [Ctxt.snoc]⟩ ⟨0, by simp [Ctxt.snoc]⟩ ) <| - -- return %out - Com.ret ⟨0, by simp [Ctxt.snoc]⟩ - + [simple_com| { + ^bb0(%x : i32): + %c0 = "const" () { value = 0 : i32 } : () -> i32 + %out = "add" (%x, %c0) : (i32, i32) -> i32 + "return" (%out) : (i32) -> (i32) + }] + +open MLIR AST MLIR2Simple in /-- x -/ def rhs : Com Op (Ctxt.ofList [.int]) .int := - Com.ret ⟨0, by simp⟩ + [simple_com| { + ^bb0(%x : i32): + "return" (%x) : (i32) -> (i32) + }] +open MLIR AST MLIR2Simple in def p1 : PeepholeRewrite Op [.int] .int := { lhs := lhs, rhs := rhs, correct := by rw [lhs, rhs] - /- + /-: Com.denote (Com.lete (cst 0) (Com.lete (add { val := 1, property := _ } { val := 0, property := _ }) @@ -91,10 +186,20 @@ def p1 : PeepholeRewrite Op [.int] .int := done } -def ex1' : Com Op (Ctxt.ofList [.int]) .int := rewritePeepholeAt p1 1 lhs +def ex1_rewritePeepholeAt : Com Op (Ctxt.ofList [.int]) .int := rewritePeepholeAt p1 1 lhs + +theorem hex1_rewritePeephole : ex1_rewritePeepholeAt = ( + -- %c0 = 0 + Com.lete (cst 0) <| + -- %out_dead = %x + %c0 + Com.lete (add ⟨1, by simp [Ctxt.snoc]⟩ ⟨0, by simp [Ctxt.snoc]⟩ ) <| -- %out = %x + %c0 + -- ret %c0 + Com.ret ⟨2, by simp [Ctxt.snoc]⟩) + := by rfl +def ex1_rewritePeephole : Com Op (Ctxt.ofList [.int]) .int := rewritePeephole (fuel := 100) p1 lhs -theorem EX1' : ex1' = ( +theorem Hex1_rewritePeephole : ex1_rewritePeephole = ( -- %c0 = 0 Com.lete (cst 0) <| -- %out_dead = %x + %c0 @@ -103,6 +208,7 @@ theorem EX1' : ex1' = ( Com.ret ⟨2, by simp [Ctxt.snoc]⟩) := by rfl + end ToyNoRegion namespace ToyRegion