From 73466280e61e2a8464befda09712734bee3d88d0 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Tue, 10 Oct 2023 17:37:05 +0100 Subject: [PATCH 01/28] Refactor: make the AST->ICom code dialect-generic --- SSA/Projects/InstCombine/LLVM/Transform.lean | 520 ++++-------------- .../LLVM/Transform/Dialects/InstCombine.lean | 283 ++++++++++ .../LLVM/Transform/NameMapping.lean | 27 + .../LLVM/Transform/TransformError.lean | 35 ++ 4 files changed, 467 insertions(+), 398 deletions(-) create mode 100644 SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean create mode 100644 SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean create mode 100644 SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index 763cf91c7..6b3867696 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -1,116 +1,93 @@ -- should replace with Lean import once Pure is upstream import SSA.Projects.MLIRSyntax.AST import SSA.Projects.InstCombine.Base +import SSA.Projects.InstCombine.LLVM.Transform.NameMapping +import SSA.Projects.InstCombine.LLVM.Transform.TransformError import SSA.Core.Framework universe u namespace MLIR.AST -open InstCombine (MOp MTy Width) +section Monads -abbrev Context (φ) := List (MTy φ) +/-! + Even though we technically only need to know the `Ty` type, + our typeclass hierarchy is fully based on `Op`. + It is thus more convenient to incorporate `Op` in these types, so that there will be no ambiguity + errors. +-/ -abbrev Expr (Γ : Context φ) (ty : MTy φ) := IExpr (MOp φ) Γ ty -abbrev Com (Γ : Context φ) (ty : MTy φ) := ICom (MOp φ) Γ ty -abbrev Var (Γ : Context φ) (ty : MTy φ) := Ctxt.Var Γ ty +abbrev ExceptM (Op) {Ty} [OpSignature Op Ty] := Except (TransformError Ty) +abbrev BuilderM (Op) {Ty} [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) +abbrev ReaderM (Op) {Ty} [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) -abbrev Com.lete (body : Expr Γ ty₁) (rest : Com (ty₁::Γ) ty₂) : Com Γ ty₂ := - ICom.lete body rest +variable {Op Ty} [OpSignature Op Ty] -inductive TransformError - | nameAlreadyDeclared (var : String) - | undeclaredName (var : String) - | indexOutOfBounds (name : String) (index len : Nat) - | typeError {φ} (expected got : MTy φ) - | widthError {φ} (expected got : Width φ) - | unsupportedUnaryOp - | unsupportedBinaryOp - | unsupportedOp - | unsupportedType - | generic (error : String) - -namespace TransformError - -instance : Lean.ToFormat (MTy φ) where - format := repr - -instance : Repr TransformError where - reprPrec err _ := match err with - | nameAlreadyDeclared var => f!"Already declared {var}, shadowing is not allowed" - | undeclaredName name => f!"Undeclared name '{name}'" - | indexOutOfBounds name index len => - f!"Index of '{name}' out of bounds of the given context (index was {index}, but context has length {len})" - | typeError expected got => f!"Type mismatch: expected '{expected}', but 'name' has type '{got}'" - | widthError expected got => f!"Type mismatch: {expected} ≠ {got}" - | unsupportedUnaryOp => f!"Unsuported unary operation" - | unsupportedBinaryOp => f!"Unsuported binary operation" - | unsupportedOp => f!"Unsuported operation" - | unsupportedType => f!"Unsuported type" - | generic err => err - -end TransformError +instance : MonadLift (ReaderM Op) (BuilderM Op) where + monadLift x := do (ReaderT.run x (←get) : ExceptM ..) -/-- -Store the names of the raw SSA variables (as strings). -The order in the list should match the order in which they appear in the code. --/ -abbrev NameMapping := List String +instance : MonadLift (ExceptM Op) (ReaderM Op) where + monadLift x := do return ←x -def NameMapping.lookup (nm : NameMapping) (name : String) : Option Nat := - nm.indexOf? name +def BuilderM.runWithNewMapping (k : BuilderM Op α) : ExceptM Op α := + Prod.fst <$> StateT.run k [] -/-- - Add a new name to the mapping, assuming the name is not present in the list yet. - If the name is already present, return `none` --/ -def NameMapping.add (nm : NameMapping) (name : String) : Option NameMapping := - match nm.lookup name with - | none => some <| name::nm - | some _ => none +end Monads -example : (ExceptT ε <| ReaderM ρ) = (ReaderT ρ <| Except ε) := rfl +class TransformDialect (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) extends OpSignature Op Ty where + mkType : MLIRType φ → ExceptM Op Ty + mkReturn : (Γ : List Ty) → (opStx : AST.Op φ) → (args : List (Σ (ty : Ty), Ctxt.Var Γ ty)) + → ReaderM Op (Σ ty, ICom Op Γ ty) + mkExpr : (Γ : List Ty) → (opStx : AST.Op φ) → (args : List (Σ (ty : Ty), Ctxt.Var Γ ty)) + → ReaderM Op (Σ ty, IExpr Op Γ ty) -abbrev ExceptM := Except TransformError -abbrev BuilderM := StateT NameMapping ExceptM -abbrev ReaderM := ReaderT NameMapping ExceptM +variable (Op) {Ty φ} [d : TransformDialect Op Ty φ] -instance : MonadLift ReaderM BuilderM where - monadLift x := do (ReaderT.run x (←get) : ExceptM ..) +abbrev Context (Ty) := List (Ty) -instance : MonadLift ExceptM ReaderM where - monadLift x := do return ←x +abbrev Expr (Γ : Context Ty) (ty : Ty) := IExpr Op Γ ty +abbrev Com (Γ : Context Ty) (ty : Ty) := ICom Op Γ ty +abbrev Var (Γ : Context Ty) (ty : Ty) := Ctxt.Var Γ ty + + + +variable {Op} [d : TransformDialect Op Ty φ] [DecidableEq Ty] + +abbrev Com.lete (body : Expr Op Γ ty₁) (rest : Com Op (ty₁::Γ) ty₂) : Com Op Γ ty₂ := + ICom.lete body rest -def BuilderM.runWithNewMapping (k : BuilderM α) : ExceptM α := - Prod.fst <$> StateT.run k [] -structure DerivedContext (Γ : Context φ) where - ctxt : Context φ + + + +structure DerivedContext (Γ : Context Ty) where + ctxt : Context Ty diff : Ctxt.Diff Γ ctxt namespace DerivedContext /-- Every context is trivially derived from itself -/ -abbrev ofContext (Γ : Context φ) : DerivedContext Γ := ⟨Γ, .zero _⟩ +abbrev ofContext (Γ : Context Ty) : DerivedContext Γ := ⟨Γ, .zero _⟩ /-- `snoc` of a derived context applies `snoc` to the underlying context, and updates the diff -/ -def snoc {Γ : Context φ} : DerivedContext Γ → MTy φ → DerivedContext Γ +def snoc {Γ : Context Ty} : DerivedContext Γ → Ty → DerivedContext Γ | ⟨ctxt, diff⟩, ty => ⟨ty::ctxt, diff.toSnoc⟩ -instance {Γ : Context φ} : CoeHead (DerivedContext Γ) (Context φ) where +instance {Γ : Context Ty} : CoeHead (DerivedContext Γ) (Context Ty) where coe := fun ⟨Γ', _⟩ => Γ' -instance {Γ : Context φ} : CoeDep (Context φ) Γ (DerivedContext Γ) where +instance {Γ : Context Ty} : CoeDep (Context Ty) Γ (DerivedContext Γ) where coe := ⟨Γ, .zero _⟩ -instance {Γ : Context φ} {Γ' : DerivedContext Γ} : - CoeHead (DerivedContext (Γ' : Context φ)) (DerivedContext Γ) where +instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : + CoeHead (DerivedContext (Γ' : Context Ty)) (DerivedContext Γ) where coe := fun ⟨Γ'', diff⟩ => ⟨Γ'', Γ'.diff + diff⟩ -instance {Γ : Context φ} {Γ' : DerivedContext Γ} : Coe (Expr Γ t) (Expr Γ' t) where +instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : Coe (Expr Op Γ t) (Expr Op Γ' t) where coe e := e.changeVars Γ'.diff.toHom -instance {Γ' : DerivedContext Γ} : Coe (Var Γ t) (Var Γ' t) where +instance {Γ' : DerivedContext Γ} : Coe (Var Γ t) (Var (Γ' : Context Ty) t) where coe v := Γ'.diff.toHom v end DerivedContext @@ -121,8 +98,8 @@ end DerivedContext Throws an error if the variable name already exists in the mapping, essentially disallowing shadowing -/ -def addValToMapping (Γ : Context φ) (name : String) (ty : MTy φ) : - BuilderM (Σ (Γ' : DerivedContext Γ), Var Γ' ty) := do +def addValToMapping (Γ : Context Ty) (name : String) (ty : Ty) : + BuilderM Op (Σ (Γ' : DerivedContext Γ), Var Γ' ty) := do let some nm := (←get).add name | throw <| .nameAlreadyDeclared name set nm @@ -134,8 +111,8 @@ def addValToMapping (Γ : Context φ) (name : String) (ty : MTy φ) : Throws an error if the name is not present in the mapping (this indicates the name may be free), or if the type of the variable in the context is different from `expectedType` -/ -def getValFromContext (Γ : Context φ) (name : String) (expectedType : MTy φ) : - ReaderM (Ctxt.Var Γ expectedType) := do +def getValFromContext (Γ : Context Ty) (name : String) (expectedType : Ty) : + ReaderM Op (Ctxt.Var Γ expectedType) := do let index := (←read).lookup name let some index := index | throw <| .undeclaredName name let n := Γ.length @@ -150,216 +127,21 @@ def getValFromContext (Γ : Context φ) (name : String) (expectedType : MTy φ) else throw <| .typeError expectedType t -def BuilderM.isOk {α : Type} (x : BuilderM α) : Bool := +def BuilderM.isOk {α : Type} (x : BuilderM Op α) : Bool := match x.run [] with | Except.ok _ => true | Except.error _ => false -def BuilderM.isErr {α : Type} (x : BuilderM α) : Bool := +def BuilderM.isErr {α : Type} (x : BuilderM Op α) : Bool := match x.run [] with | Except.ok _ => true | Except.error _ => false -def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (e : Var Γ ty) : ExceptM <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .neg w' => if h : w = w' - then return ⟨ - .neg w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .not w' => if h : w = w' - then return ⟨ - .not w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .copy w' => if h : w = w' - then return ⟨ - .copy w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedUnaryOp - -def mkBinOp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .add w' => if h : w = w' - then return ⟨ - .add w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .and w' => if h : w = w' - then return ⟨ - .and w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .or w' => if h : w = w' - then return ⟨ - .or w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .xor w' => if h : w = w' - then return ⟨ - .xor w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .shl w' => if h : w = w' - then return ⟨ - .shl w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .lshr w' => if h : w = w' - then return ⟨ - .lshr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .ashr w' => if h : w = w' - then return ⟨ - .ashr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .urem w' => if h : w = w' - then return ⟨ - .urem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .srem w' => if h : w = w' - then return ⟨ - .srem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .mul w' => if h : w = w' - then return ⟨ - .mul w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sub w' => if h : w = w' - then return ⟨ - .sub w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sdiv w' => if h : w = w' - then return ⟨ - .sdiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .udiv w' => if h : w = w' - then return ⟨ - .udiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedBinaryOp - -def mkIcmp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM <| Expr Γ (.bitvec 1) := - match ty with - | .bitvec w => - match op with - | .icmp p w' => if h : w = w' - then return ⟨ - .icmp p w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil -⟩ - else throw <| .widthError w w' - | _ => throw .unsupportedOp -- unsupported icmp operation - -def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : - ExceptM <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - | .select w' => if h : w = w' - then return ⟨ - .select w', - by simp [OpSignature.outTy, signature, h], - .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw .unsupportedOp -- "Unsupported select operation" - -def mkOpExpr {Γ : Context φ} (op : MOp φ) - (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : - ExceptM <| Expr Γ (OpSignature.outTy op) := - match op with - | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ - | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ - | .srem _ | .urem _ => - let (e₁, e₂) := arg.toTuple - mkBinOp op e₁ e₂ - | .icmp _ _ => - let (e₁, e₂) := arg.toTuple - mkIcmp op e₁ e₂ - | .not _ | .neg _ | .copy _ => - mkUnaryOp op arg.head - | .select _ => - let (c, e₁, e₂) := arg.toTuple - mkSelect op c e₁ e₂ - | .const .. => throw .unsupportedOp -- "Tried to build Op expression from constant" - -def MLIRType.mkTy : MLIRType φ → ExceptM (MTy φ) - | MLIRType.int Signedness.Signless w => do - return .bitvec w - | _ => throw .unsupportedType -- "Unsupported type" - -def TypedSSAVal.mkTy : TypedSSAVal φ → ExceptM (MTy φ) +def MLIRType.mkTy : MLIRType φ → ExceptM Op Ty := + d.mkType + + +def TypedSSAVal.mkTy : TypedSSAVal φ → ExceptM Op Ty | (.SSAVal _, ty) => ty.mkTy def mkVal (ty : InstCombine.Ty) : Int → Bitvec ty.width @@ -367,100 +149,42 @@ def mkVal (ty : InstCombine.Ty) : Int → Bitvec ty.width /-- Translate a `TypedSSAVal` (a name with an expected type), to a variable in the context. This expects the name to have already been declared before -/ -def TypedSSAVal.mkVal (Γ : Context φ) : TypedSSAVal φ → - ReaderM (Σ (ty : MTy φ), Var Γ ty) +def TypedSSAVal.mkVal (Γ : Context Ty) : TypedSSAVal φ → + ReaderM Op (Σ (ty : Ty), Var Γ ty) | (.SSAVal valStx, tyStx) => do - let ty ← tyStx.mkTy + let ty ← (tyStx.mkTy : ExceptM Op ..) let var ← getValFromContext Γ valStx ty return ⟨ty, var⟩ /-- Declare a new variable, by adding the passed name to the name mapping stored in the monad state -/ -def TypedSSAVal.newVal (Γ : Context φ) : TypedSSAVal φ → - BuilderM (Σ (Γ' : DerivedContext Γ) (ty : MTy φ), Var Γ' ty) +def TypedSSAVal.newVal (Γ : Context Ty) : TypedSSAVal φ → + BuilderM Op (Σ (Γ' : DerivedContext Γ) (ty : Ty), Var Γ' ty) | (.SSAVal valStx, tyStx) => do - let ty ← tyStx.mkTy + let ty ← (tyStx.mkTy : ExceptM Op ..) let ⟨Γ, var⟩ ← addValToMapping Γ valStx ty return ⟨Γ, ty, var⟩ -def mkExpr (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Expr Γ ty) := do - match opStx.args with - | v₁Stx::v₂Stx::[] => - let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx - let ⟨.bitvec w₂, v₂⟩ ← TypedSSAVal.mkVal Γ v₂Stx - -- let ty₁ := ty₁.instantiave - let op ← match opStx.name with - | "llvm.and" => pure (MOp.and w₁) - | "llvm.or" => pure (MOp.or w₁) - | "llvm.xor" => pure (MOp.xor w₁) - | "llvm.shl" => pure (MOp.shl w₁) - | "llvm.lshr" => pure (MOp.lshr w₁) - | "llvm.ashr" => pure (MOp.ashr w₁) - | "llvm.urem" => pure (MOp.urem w₁) - | "llvm.srem" => pure (MOp.srem w₁) - | "llvm.select" => pure (MOp.select w₁) - | "llvm.add" => pure (MOp.add w₁) - | "llvm.mul" => pure (MOp.mul w₁) - | "llvm.sub" => pure (MOp.sub w₁) - | "llvm.sdiv" => pure (MOp.sdiv w₁) - | "llvm.udiv" => pure (MOp.udiv w₁) - --| "llvm.icmp" => return InstCombine.Op.icmp v₁.width - | _ => throw .unsupportedOp -- "Unsuported operation or invalid arguments" - if hty : w₁ = w₂ then - let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM _) - return ⟨.bitvec w₁, binOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - | vStx::[] => - let ⟨.bitvec w, v⟩ ← TypedSSAVal.mkVal Γ vStx - let op ← match opStx.name with - | "llvm.not" => pure <| MOp.not w - | "llvm.neg" => pure <| MOp.neg w - | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" - let op ← mkUnaryOp op v - return ⟨.bitvec w, op⟩ - | [] => - if opStx.name == "llvm.mlir.constant" - then do - let some att := opStx.attrs.getAttr "value" - | throw <| .generic "tried to resolve constant without 'value' attribute" - match att with - | .int val ty => - let opTy@(MTy.bitvec w) ← ty.mkTy - return ⟨opTy, ⟨ - MOp.const w val, - by simp [OpSignature.outTy, signature, *], - HVector.nil, - HVector.nil - ⟩⟩ - | _ => throw <| .generic "invalid constant attribute" - else - throw <| .generic s!"invalid (0-ary) expression {opStx.name}" - | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" - -def mkReturn (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Com Γ ty) := - if opStx.name == "llvm.return" - then match opStx.args with - | vStx::[] => do - let ⟨ty, v⟩ ← vStx.mkVal Γ - return ⟨ty, ICom.ret v⟩ - | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" - else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" - /-- Given a list of `TypedSSAVal`s, treat each as a binder and declare a new variable with the given name and type -/ -private def declareBindings (Γ : Context φ) (vals : List (TypedSSAVal φ)) : - BuilderM (DerivedContext Γ) := do +private def declareBindings (Γ : Context Ty) (vals : List (TypedSSAVal φ)) : + BuilderM Op (DerivedContext Γ) := do vals.foldlM (fun Γ' ssaVal => do let ⟨Γ'', _⟩ ← TypedSSAVal.newVal Γ' ssaVal return Γ'' ) (.ofContext Γ) -private def mkComHelper (Γ : Context φ) : - List (Op φ) → BuilderM (Σ (ty : _), Com Γ ty) - | [retStx] => mkReturn Γ retStx +def mkExpr (Γ : Context Ty) (opStx : AST.Op φ) : ReaderM Op (Σ ty, Expr Op Γ ty) := do + let args ← opStx.args.mapM (TypedSSAVal.mkVal Γ) + d.mkExpr Γ opStx args + +private def mkComHelper (Γ : Context Ty) : + List (AST.Op φ) → BuilderM Op (Σ (ty : _), Com Op Γ ty) + | [retStx] => do + let args ← (retStx.args.mapM (TypedSSAVal.mkVal Γ) : ReaderM Op ..) + d.mkReturn Γ retStx args | lete::rest => do - let ⟨ty₁, expr⟩ ← mkExpr Γ lete + let ⟨ty₁, expr⟩ ← (mkExpr Γ lete : ReaderM Op ..) if h : lete.res.length != 1 then throw <| .generic s!"Each let-binding must have exactly one name on the left-hand side. Operations with multiple, or no, results are not yet supported.\n\tExpected a list of length one, found `{repr lete}`" else @@ -469,7 +193,7 @@ private def mkComHelper (Γ : Context φ) : return ⟨ty₂, Com.lete expr body⟩ | [] => throw <| .generic "Ill-formed (empty) block" -def mkCom (reg : Region φ) : ExceptM (Σ (Γ : Context φ) (ty : MTy φ), Com Γ ty) := +def mkCom (reg : Region φ) : ExceptM Op (Σ (Γ : Context Ty) (ty : Ty), Com Op Γ ty) := match reg.ops with | [] => throw <| .generic "Ill-formed region (empty)" | coms => BuilderM.runWithNewMapping <| do @@ -482,48 +206,48 @@ def mkCom (reg : Region φ) : ExceptM (Σ (Γ : Context φ) (ty : MTy φ), Com Finally, we show how to instantiate a family of programs to a concrete program -/ -def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : MTy φ → InstCombine.Ty - | .bitvec w => .bitvec <| .concrete <| w.instantiate vals - -def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op - | .and w => .and (w.instantiate vals) - | .or w => .or (w.instantiate vals) - | .not w => .not (w.instantiate vals) - | .xor w => .xor (w.instantiate vals) - | .shl w => .shl (w.instantiate vals) - | .lshr w => .lshr (w.instantiate vals) - | .ashr w => .ashr (w.instantiate vals) - | .urem w => .urem (w.instantiate vals) - | .srem w => .srem (w.instantiate vals) - | .select w => .select (w.instantiate vals) - | .add w => .add (w.instantiate vals) - | .mul w => .mul (w.instantiate vals) - | .sub w => .sub (w.instantiate vals) - | .neg w => .neg (w.instantiate vals) - | .copy w => .copy (w.instantiate vals) - | .sdiv w => .sdiv (w.instantiate vals) - | .udiv w => .udiv (w.instantiate vals) - | .icmp c w => .icmp c (w.instantiate vals) - | .const w val => .const (w.instantiate vals) val - -def Context.instantiate (vals : Vector Nat φ) (Γ : Context φ) : Ctxt InstCombine.Ty := - Γ.map (MTy.instantiate vals) - -def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where - mapOp := MOp.instantiate vals - mapTy := MTy.instantiate vals - preserves_signature op := by - simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, - InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, - true_and] - cases op <;> simp only [List.map, and_self, List.cons.injEq] +-- def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : Ty → InstCombine.Ty +-- | .bitvec w => .bitvec <| .concrete <| w.instantiate vals + +-- def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op +-- | .and w => .and (w.instantiate vals) +-- | .or w => .or (w.instantiate vals) +-- | .not w => .not (w.instantiate vals) +-- | .xor w => .xor (w.instantiate vals) +-- | .shl w => .shl (w.instantiate vals) +-- | .lshr w => .lshr (w.instantiate vals) +-- | .ashr w => .ashr (w.instantiate vals) +-- | .urem w => .urem (w.instantiate vals) +-- | .srem w => .srem (w.instantiate vals) +-- | .select w => .select (w.instantiate vals) +-- | .add w => .add (w.instantiate vals) +-- | .mul w => .mul (w.instantiate vals) +-- | .sub w => .sub (w.instantiate vals) +-- | .neg w => .neg (w.instantiate vals) +-- | .copy w => .copy (w.instantiate vals) +-- | .sdiv w => .sdiv (w.instantiate vals) +-- | .udiv w => .udiv (w.instantiate vals) +-- | .icmp c w => .icmp c (w.instantiate vals) +-- | .const w val => .const (w.instantiate vals) val + +-- def Context.instantiate (vals : Vector Nat φ) (Γ : Context Ty) : Ctxt InstCombine.Ty := +-- Γ.map (MTy.instantiate vals) + +-- def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where +-- mapOp := MOp.instantiate vals +-- mapTy := MTy.instantiate vals +-- preserves_signature op := by +-- simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, +-- InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, +-- true_and] +-- cases op <;> simp only [List.map, and_self, List.cons.injEq] -open InstCombine (Op Ty) in -def mkComInstantiate (reg : Region φ) : - ExceptM (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom InstCombine.Op Γ ty) := do - let ⟨Γ, ty, icom⟩ ← mkCom reg - return fun vals => - ⟨Γ.instantiate vals, ty.instantiate vals, icom.map (MOp.instantiateCom vals)⟩ +-- open InstCombine (Op Ty) in +-- def mkComInstantiate (reg : Region φ) : +-- ExceptM (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom InstCombine.Op Γ ty) := do +-- let ⟨Γ, ty, icom⟩ ← mkCom reg +-- return fun vals => +-- ⟨Γ.instantiate vals, ty.instantiate vals, icom.map (MOp.instantiateCom vals)⟩ end MLIR.AST diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean new file mode 100644 index 000000000..d7e0a9a80 --- /dev/null +++ b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean @@ -0,0 +1,283 @@ +import SSA.Projects.InstCombine.LLVM.Transform + +namespace InstCombine + +open MLIR +open AST (TypedSSAVal) +open Ctxt (Var) + +protected abbrev ExceptM (φ) := AST.ExceptM (MOp φ) +protected abbrev ReaderM (φ) := AST.ReaderM (MOp φ) +protected abbrev BuilderM (φ) := AST.BuilderM (MOp φ) + +protected abbrev Context (φ) := List (MTy φ) +protected abbrev Expr {φ} := IExpr (MOp φ) + +open InstCombine (ExceptM ReaderM BuilderM Context Expr) + +def mkType : AST.MLIRType φ → ExceptM φ (MTy φ) + | .int .Signless w => return .bitvec w + | _ => throw .unsupportedType -- "Unsupported type" + +def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) + (e : Var Γ ty) : ExceptM φ <| Expr Γ ty := + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly + | .neg w' => if h : w = w' + then return ⟨ + .neg w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | .not w' => if h : w = w' + then return ⟨ + .not w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | .copy w' => if h : w = w' + then return ⟨ + .copy w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedUnaryOp + +def mkBinOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) + (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ ty := + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly + | .add w' => if h : w = w' + then return ⟨ + .add w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .and w' => if h : w = w' + then return ⟨ + .and w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .or w' => if h : w = w' + then return ⟨ + .or w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .xor w' => if h : w = w' + then return ⟨ + .xor w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .shl w' => if h : w = w' + then return ⟨ + .shl w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .lshr w' => if h : w = w' + then return ⟨ + .lshr w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .ashr w' => if h : w = w' + then return ⟨ + .ashr w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .urem w' => if h : w = w' + then return ⟨ + .urem w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .srem w' => if h : w = w' + then return ⟨ + .srem w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .mul w' => if h : w = w' + then return ⟨ + .mul w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .sub w' => if h : w = w' + then return ⟨ + .sub w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .sdiv w' => if h : w = w' + then return ⟨ + .sdiv w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .udiv w' => if h : w = w' + then return ⟨ + .udiv w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedBinaryOp + +def mkIcmp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) + (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ (.bitvec 1) := + match ty with + | .bitvec w => + match op with + | .icmp p w' => if h : w = w' + then return ⟨ + .icmp p w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil +⟩ + else throw <| .widthError w w' + | _ => throw .unsupportedOp -- unsupported icmp operation + +def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) + (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : + ExceptM φ <| Expr Γ ty := + match ty with + | .bitvec w => + match op with + | .select w' => if h : w = w' + then return ⟨ + .select w', + by simp [OpSignature.outTy, signature, h], + .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw .unsupportedOp -- "Unsupported select operation" + +def mkOpExpr {Γ : Context φ} (op : MOp φ) + (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : + ExceptM φ <| Expr Γ (OpSignature.outTy op) := + match op with + | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ + | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ + | .srem _ | .urem _ => + let (e₁, e₂) := arg.toTuple + mkBinOp op e₁ e₂ + | .icmp _ _ => + let (e₁, e₂) := arg.toTuple + mkIcmp op e₁ e₂ + | .not _ | .neg _ | .copy _ => + mkUnaryOp op arg.head + | .select _ => + let (c, e₁, e₂) := arg.toTuple + mkSelect op c e₁ e₂ + | .const .. => throw .unsupportedOp -- "Tried to build Op expression from constant" + +def mkExpr (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : + ReaderM φ (Σ ty, Expr Γ ty) := do + match args with + | ⟨.bitvec w₁, v₁⟩::⟨.bitvec w₂, v₂⟩::[] => + -- let ty₁ := ty₁.instantiave + let op ← match opStx.name with + | "llvm.and" => pure (MOp.and w₁) + | "llvm.or" => pure (MOp.or w₁) + | "llvm.xor" => pure (MOp.xor w₁) + | "llvm.shl" => pure (MOp.shl w₁) + | "llvm.lshr" => pure (MOp.lshr w₁) + | "llvm.ashr" => pure (MOp.ashr w₁) + | "llvm.urem" => pure (MOp.urem w₁) + | "llvm.srem" => pure (MOp.srem w₁) + | "llvm.select" => pure (MOp.select w₁) + | "llvm.add" => pure (MOp.add w₁) + | "llvm.mul" => pure (MOp.mul w₁) + | "llvm.sub" => pure (MOp.sub w₁) + | "llvm.sdiv" => pure (MOp.sdiv w₁) + | "llvm.udiv" => pure (MOp.udiv w₁) + --| "llvm.icmp" => return InstCombine.Op.icmp v₁.width + | _ => throw .unsupportedOp -- "Unsuported operation or invalid arguments" + if hty : w₁ = w₂ then + let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM ..) + return ⟨.bitvec w₁, binOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + | ⟨.bitvec w, v⟩::[] => + let op ← match opStx.name with + | "llvm.not" => pure <| MOp.not w + | "llvm.neg" => pure <| MOp.neg w + | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" + let op ← mkUnaryOp op v + return ⟨.bitvec w, op⟩ + | [] => + if opStx.name == "llvm.mlir.constant" + then do + let some att := opStx.attrs.getAttr "value" + | throw <| .generic "tried to resolve constant without 'value' attribute" + match att with + | .int val ty => + let opTy@(MTy.bitvec w) ← mkType ty + return ⟨opTy, ⟨ + MOp.const w val, + by simp [OpSignature.outTy, signature, *], + HVector.nil, + HVector.nil + ⟩⟩ + | _ => throw <| .generic "invalid constant attribute" + else + throw <| .generic s!"invalid (0-ary) expression {opStx.name}" + | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" + +def mkReturn (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : + ReaderM φ (Σ ty, ICom (MOp φ) Γ ty) := + if opStx.name == "llvm.return" + then match args with + | ⟨ty, v⟩::[] => do + return ⟨ty, ICom.ret v⟩ + | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {args.length})" + else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" + + + +instance : AST.TransformDialect (MOp φ) (MTy φ) φ where + mkType := mkType + mkReturn := mkReturn + mkExpr := mkExpr \ No newline at end of file diff --git a/SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean b/SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean new file mode 100644 index 000000000..d9140e77e --- /dev/null +++ b/SSA/Projects/InstCombine/LLVM/Transform/NameMapping.lean @@ -0,0 +1,27 @@ +import Std.Data.List.Basic + +namespace MLIR.AST + +/-- +Store the names of the raw SSA variables (as strings). +The order in the list should match the order in which they appear in the code. +-/ +abbrev NameMapping := List String + +namespace NameMapping + +def lookup (nm : NameMapping) (name : String) : Option Nat := + nm.indexOf? name + +/-- + Add a new name to the mapping, assuming the name is not present in the list yet. + If the name is already present, return `none` +-/ +def add (nm : NameMapping) (name : String) : Option NameMapping := + match nm.lookup name with + | none => some <| name::nm + | some _ => none + +end NameMapping + +end MLIR.AST \ No newline at end of file diff --git a/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean b/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean new file mode 100644 index 000000000..010e788c7 --- /dev/null +++ b/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean @@ -0,0 +1,35 @@ +import SSA.Projects.MLIRSyntax.AST + +namespace MLIR.AST + +inductive TransformError (Ty : Type) + | nameAlreadyDeclared (var : String) + | undeclaredName (var : String) + | indexOutOfBounds (name : String) (index len : Nat) + | typeError (expected got : Ty) + | widthError {φ} (expected got : Width φ) + | unsupportedUnaryOp + | unsupportedBinaryOp + | unsupportedOp + | unsupportedType + | generic (error : String) + +namespace TransformError + +instance [Repr Ty] : Repr (TransformError Ty) where + reprPrec err _ := match err with + | nameAlreadyDeclared var => f!"Already declared {var}, shadowing is not allowed" + | undeclaredName name => f!"Undeclared name '{name}'" + | indexOutOfBounds name index len => + f!"Index of '{name}' out of bounds of the given context (index was {index}, but context has length {len})" + | typeError expected got => f!"Type mismatch: expected '{repr expected}', but 'name' has type '{repr got}'" + | widthError expected got => f!"Type mismatch: {expected} ≠ {got}" + | unsupportedUnaryOp => f!"Unsuported unary operation" + | unsupportedBinaryOp => f!"Unsuported binary operation" + | unsupportedOp => f!"Unsuported operation" + | unsupportedType => f!"Unsuported type" + | generic err => err + +end TransformError + +end MLIR.AST \ No newline at end of file From b42347e353e5c1070b5a3be7f5d0c363bf8f9b73 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Tue, 10 Oct 2023 18:06:33 +0100 Subject: [PATCH 02/28] TransformDialectInstantiate class --- SSA/Projects/InstCombine/LLVM/EDSL.lean | 2 + SSA/Projects/InstCombine/LLVM/Transform.lean | 49 +-------------- .../LLVM/Transform/Instantiate.lean | 63 +++++++++++++++++++ 3 files changed, 67 insertions(+), 47 deletions(-) create mode 100644 SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 2f115310a..7356482aa 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -2,6 +2,8 @@ import Qq import SSA.Projects.MLIRSyntax.EDSL import SSA.Projects.InstCombine.LLVM.Transform +import SSA.Projects.InstCombine.LLVM.Transform.Dialects.InstCombine + open Qq Lean Meta Elab.Term open MLIR.AST InstCombine in diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index 6b3867696..b322a2dfa 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -193,6 +193,8 @@ private def mkComHelper (Γ : Context Ty) : return ⟨ty₂, Com.lete expr body⟩ | [] => throw <| .generic "Ill-formed (empty) block" +variable (Op) + def mkCom (reg : Region φ) : ExceptM Op (Σ (Γ : Context Ty) (ty : Ty), Com Op Γ ty) := match reg.ops with | [] => throw <| .generic "Ill-formed region (empty)" @@ -201,53 +203,6 @@ def mkCom (reg : Region φ) : ExceptM Op (Σ (Γ : Context Ty) (ty : Ty), Com Op let icom ← mkComHelper Γ coms return ⟨Γ, icom⟩ -/-! - ## Instantiation - Finally, we show how to instantiate a family of programs to a concrete program --/ --- def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : Ty → InstCombine.Ty --- | .bitvec w => .bitvec <| .concrete <| w.instantiate vals - --- def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op --- | .and w => .and (w.instantiate vals) --- | .or w => .or (w.instantiate vals) --- | .not w => .not (w.instantiate vals) --- | .xor w => .xor (w.instantiate vals) --- | .shl w => .shl (w.instantiate vals) --- | .lshr w => .lshr (w.instantiate vals) --- | .ashr w => .ashr (w.instantiate vals) --- | .urem w => .urem (w.instantiate vals) --- | .srem w => .srem (w.instantiate vals) --- | .select w => .select (w.instantiate vals) --- | .add w => .add (w.instantiate vals) --- | .mul w => .mul (w.instantiate vals) --- | .sub w => .sub (w.instantiate vals) --- | .neg w => .neg (w.instantiate vals) --- | .copy w => .copy (w.instantiate vals) --- | .sdiv w => .sdiv (w.instantiate vals) --- | .udiv w => .udiv (w.instantiate vals) --- | .icmp c w => .icmp c (w.instantiate vals) --- | .const w val => .const (w.instantiate vals) val - --- def Context.instantiate (vals : Vector Nat φ) (Γ : Context Ty) : Ctxt InstCombine.Ty := --- Γ.map (MTy.instantiate vals) - --- def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where --- mapOp := MOp.instantiate vals --- mapTy := MTy.instantiate vals --- preserves_signature op := by --- simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, --- InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, --- true_and] --- cases op <;> simp only [List.map, and_self, List.cons.injEq] - - --- open InstCombine (Op Ty) in --- def mkComInstantiate (reg : Region φ) : --- ExceptM (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom InstCombine.Op Γ ty) := do --- let ⟨Γ, ty, icom⟩ ← mkCom reg --- return fun vals => --- ⟨Γ.instantiate vals, ty.instantiate vals, icom.map (MOp.instantiateCom vals)⟩ end MLIR.AST diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean new file mode 100644 index 000000000..bda2fa332 --- /dev/null +++ b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean @@ -0,0 +1,63 @@ +import SSA.Core.Framework +import SSA.Projects.InstCombine.LLVM.Transform + +namespace MLIR + +/-! + ## Instantiation + Finally, we show how to instantiate a family of programs to a concrete program +-/ + +class TransformDialectInstantiate (Op : Type) (φ : Nat) (Ty MOp MTy : outParam Type) + [OpSignature Op Ty] [AST.TransformDialect MOp MTy φ] where + morphism : Vector Nat φ → DialectMorphism MOp Op + + +set_option linter.unusedVariables false -- linter gives a false positive for `φ` in `[∀ φ, ...]` + +variable (Op) (φ) {Ty} {MOp MTy} [OpSignature Op Ty] + [AST.TransformDialect MOp MTy φ] + [inst : TransformDialectInstantiate Op φ Ty MOp MTy] + [DecidableEq MTy] + +-- def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : Ty → InstCombine.Ty +-- | .bitvec w => .bitvec <| .concrete <| w.instantiate vals + +-- def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op +-- | .and w => .and (w.instantiate vals) +-- | .or w => .or (w.instantiate vals) +-- | .not w => .not (w.instantiate vals) +-- | .xor w => .xor (w.instantiate vals) +-- | .shl w => .shl (w.instantiate vals) +-- | .lshr w => .lshr (w.instantiate vals) +-- | .ashr w => .ashr (w.instantiate vals) +-- | .urem w => .urem (w.instantiate vals) +-- | .srem w => .srem (w.instantiate vals) +-- | .select w => .select (w.instantiate vals) +-- | .add w => .add (w.instantiate vals) +-- | .mul w => .mul (w.instantiate vals) +-- | .sub w => .sub (w.instantiate vals) +-- | .neg w => .neg (w.instantiate vals) +-- | .copy w => .copy (w.instantiate vals) +-- | .sdiv w => .sdiv (w.instantiate vals) +-- | .udiv w => .udiv (w.instantiate vals) +-- | .icmp c w => .icmp c (w.instantiate vals) +-- | .const w val => .const (w.instantiate vals) val + +-- def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where +-- mapOp := MOp.instantiate vals +-- mapTy := MTy.instantiate vals +-- preserves_signature op := by +-- simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, +-- InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, +-- true_and] +-- cases op <;> simp only [List.map, and_self, List.cons.injEq] + +def AST.mkComInstantiate (reg : AST.Region φ) : + AST.ExceptM MOp (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom Op Γ ty) := do + let ⟨Γ, ty, icom⟩ ← AST.mkCom MOp reg + return fun vals => + let f := inst.morphism vals + ⟨Γ.map f.mapTy, f.mapTy ty, icom.map f⟩ + +end MLIR \ No newline at end of file From 5b39a75baaf84ef2fc1223fccc10411d5c26bae4 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Tue, 10 Oct 2023 18:10:37 +0100 Subject: [PATCH 03/28] Implement TransformDialectInstantiate for InstCombine --- .../LLVM/Transform/Dialects/InstCombine.lean | 42 ++++++++++++++++++- .../LLVM/Transform/Instantiate.lean | 33 --------------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean index d7e0a9a80..d4da4077d 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean @@ -1,4 +1,5 @@ import SSA.Projects.InstCombine.LLVM.Transform +import SSA.Projects.InstCombine.LLVM.Transform.Instantiate namespace InstCombine @@ -275,9 +276,48 @@ def mkReturn (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {args.length})" else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" +/-! ## Instantiation -/ +def MTy.instantiate (vals : Vector Nat φ) : (MTy φ) → Ty + | .bitvec w => .bitvec <| .concrete <| w.instantiate vals + +def MOp.instantiate (vals : Vector Nat φ) : MOp φ → Op + | .and w => .and (w.instantiate vals) + | .or w => .or (w.instantiate vals) + | .not w => .not (w.instantiate vals) + | .xor w => .xor (w.instantiate vals) + | .shl w => .shl (w.instantiate vals) + | .lshr w => .lshr (w.instantiate vals) + | .ashr w => .ashr (w.instantiate vals) + | .urem w => .urem (w.instantiate vals) + | .srem w => .srem (w.instantiate vals) + | .select w => .select (w.instantiate vals) + | .add w => .add (w.instantiate vals) + | .mul w => .mul (w.instantiate vals) + | .sub w => .sub (w.instantiate vals) + | .neg w => .neg (w.instantiate vals) + | .copy w => .copy (w.instantiate vals) + | .sdiv w => .sdiv (w.instantiate vals) + | .udiv w => .udiv (w.instantiate vals) + | .icmp c w => .icmp c (w.instantiate vals) + | .const w val => .const (w.instantiate vals) val + +/-! ## Instances -/ + instance : AST.TransformDialect (MOp φ) (MTy φ) φ where mkType := mkType mkReturn := mkReturn - mkExpr := mkExpr \ No newline at end of file + mkExpr := mkExpr + +instance : TransformDialectInstantiate Op φ Ty (MOp φ) (MTy φ) where + morphism vals := { + mapOp := MOp.instantiate vals, + mapTy := MTy.instantiate vals, + preserves_signature := by + intro op + simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, + InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, + true_and] + cases op <;> simp only [List.map, and_self, List.cons.injEq] + } \ No newline at end of file diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean index bda2fa332..878d70051 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean @@ -20,39 +20,6 @@ variable (Op) (φ) {Ty} {MOp MTy} [OpSignature Op Ty] [inst : TransformDialectInstantiate Op φ Ty MOp MTy] [DecidableEq MTy] --- def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : Ty → InstCombine.Ty --- | .bitvec w => .bitvec <| .concrete <| w.instantiate vals - --- def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op --- | .and w => .and (w.instantiate vals) --- | .or w => .or (w.instantiate vals) --- | .not w => .not (w.instantiate vals) --- | .xor w => .xor (w.instantiate vals) --- | .shl w => .shl (w.instantiate vals) --- | .lshr w => .lshr (w.instantiate vals) --- | .ashr w => .ashr (w.instantiate vals) --- | .urem w => .urem (w.instantiate vals) --- | .srem w => .srem (w.instantiate vals) --- | .select w => .select (w.instantiate vals) --- | .add w => .add (w.instantiate vals) --- | .mul w => .mul (w.instantiate vals) --- | .sub w => .sub (w.instantiate vals) --- | .neg w => .neg (w.instantiate vals) --- | .copy w => .copy (w.instantiate vals) --- | .sdiv w => .sdiv (w.instantiate vals) --- | .udiv w => .udiv (w.instantiate vals) --- | .icmp c w => .icmp c (w.instantiate vals) --- | .const w val => .const (w.instantiate vals) val - --- def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where --- mapOp := MOp.instantiate vals --- mapTy := MTy.instantiate vals --- preserves_signature op := by --- simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, --- InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, --- true_and] --- cases op <;> simp only [List.map, and_self, List.cons.injEq] - def AST.mkComInstantiate (reg : AST.Region φ) : AST.ExceptM MOp (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom Op Γ ty) := do let ⟨Γ, ty, icom⟩ ← AST.mkCom MOp reg From f5db8f1f605d2916d5e8fceb2e7c8b8734543b6d Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Tue, 10 Oct 2023 18:37:04 +0100 Subject: [PATCH 04/28] Rename mlir_icom -> alive_icom --- SSA/Projects/InstCombine/LLVM/EDSL.lean | 47 ++++++++++++++++--- .../LLVM/Transform/Instantiate.lean | 4 +- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 7356482aa..c2a3c6854 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -1,21 +1,56 @@ import Qq import SSA.Projects.MLIRSyntax.EDSL import SSA.Projects.InstCombine.LLVM.Transform +import SSA.Projects.InstCombine.LLVM.Transform.Instantiate import SSA.Projects.InstCombine.LLVM.Transform.Dialects.InstCombine -open Qq Lean Meta Elab.Term +open Qq Lean Meta Elab.Term Elab Command + +open MLIR + +-- def elabToMlirICom (Op : Q(Type)) (mvars : Syntax.TSepArray `term ",") (reg : TSyntax `mlir_region) : +-- TermElabM Unit := do +-- let φ : Nat := mvars.getElems.size +-- let Ty ← mkFreshExprMVarQ q(Type) +-- let MOp ← mkFreshExprMVarQ q(Type) +-- let MTy ← mkFreshExprMVarQ q(Type) +-- let _ ← synthInstanceQ q(OpSignature $Op $Ty) +-- let _ ← synthInstanceQ q(AST.TransformDialect $MOp $MTy $φ) +-- let instInst ← synthInstanceQ q(TransformDialectInstantiate $Op $φ $Ty $MOp $MTy) + +-- let ast_stx ← `([mlir_region| $reg]) +-- let ast ← elabTermEnsuringTypeQ ast_stx q(AST.Region $φ) + +-- let mvalues ← `(⟨[$mvars,*], by rfl⟩) +-- let mvalues : Q(Vector Nat $φ) ← elabTermEnsuringType mvalues q(Vector Nat $φ) + +-- let com := q(AST.mkComInstantiate (instInst:=$instInst) $ast) +-- synthesizeSyntheticMVarsNoPostponing +-- -- let com : Q(ExceptM (Σ (Γ' : Ctxt Ty) (ty : InstCombine.Ty), Com Γ' ty)) ← +-- -- withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do +-- -- withTransparency (mode := TransparencyMode.all) <| +-- -- return ←reduce com +-- -- trace[Meta] com +-- -- match com with +-- -- | ~q(Except.ok $comOk) => return q(($comOk).snd.snd) +-- -- | ~q(Except.error $err) => do +-- -- let err ← unsafe evalExpr TransformError q(TransformError) err +-- -- throwError "Translation failed with error:\n\t{repr err}" +-- -- | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t{e}" + +-- sorry open MLIR.AST InstCombine in -elab "[mlir_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do +elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let ast_stx ← `([mlir_region| $reg]) let φ : Nat := mvars.getElems.size let ast ← elabTermEnsuringTypeQ ast_stx q(Region $φ) let mvalues ← `(⟨[$mvars,*], by rfl⟩) let mvalues : Q(Vector Nat $φ) ← elabTermEnsuringType mvalues q(Vector Nat $φ) - let com := q(mkComInstantiate $ast |>.map (· $mvalues)) + let com := q(mkComInstantiate Op $φ $ast |>.map (· $mvalues)) synthesizeSyntheticMVarsNoPostponing - let com : Q(ExceptM (Σ (Γ' : Ctxt Ty) (ty : InstCombine.Ty), Com Γ' ty)) ← + let com : Q(ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), ICom Op Γ' ty)) ← withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do withTransparency (mode := TransparencyMode.all) <| return ←reduce com @@ -23,8 +58,8 @@ elab "[mlir_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do match com with | ~q(Except.ok $comOk) => return q(($comOk).snd.snd) | ~q(Except.error $err) => do - let err ← unsafe evalExpr TransformError q(TransformError) err + let err ← unsafe evalExpr (TransformError Op) q(TransformError Op) err throwError "Translation failed with error:\n\t{repr err}" | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t{e}" -macro "[mlir_icom| " reg:mlir_region "]" : term => `([mlir_icom ()| $reg]) \ No newline at end of file +macro "[alive_icom| " reg:mlir_region "]" : term => `([alive_icom ()| $reg]) \ No newline at end of file diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean index 878d70051..48d278dce 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean @@ -17,14 +17,14 @@ set_option linter.unusedVariables false -- linter gives a false positive for `φ variable (Op) (φ) {Ty} {MOp MTy} [OpSignature Op Ty] [AST.TransformDialect MOp MTy φ] - [inst : TransformDialectInstantiate Op φ Ty MOp MTy] + [instInst : TransformDialectInstantiate Op φ Ty MOp MTy] [DecidableEq MTy] def AST.mkComInstantiate (reg : AST.Region φ) : AST.ExceptM MOp (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom Op Γ ty) := do let ⟨Γ, ty, icom⟩ ← AST.mkCom MOp reg return fun vals => - let f := inst.morphism vals + let f := instInst.morphism vals ⟨Γ.map f.mapTy, f.mapTy ty, icom.map f⟩ end MLIR \ No newline at end of file From 2736a725a49dec442b12bc96d051adb7adfac958 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 01:08:44 +0000 Subject: [PATCH 05/28] wip: fixup translation This does not yet work, in the file `AliveAutoGenerated.lean`, we see the error: Translation failed to reduce, possibly too generic syntax ``` . 'error Except.ok { fst := [InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w)], snd := { fst := InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), snd := Com.lete (Expr.mk (InstCombine.MOp.not (ConcreteOrMVar.concrete w)) _ (HVector.cons { val := 2, property := _ } HVector.nil) HVector.nil) ... } ``` which is very bizarre, because it sees that the head term is `Except.ok`! --- SSA/Projects/InstCombine/Alive.lean | 4 +- .../InstCombine/AliveAutoGenerated.lean | 12 +- SSA/Projects/InstCombine/LLVM/EDSL.lean | 10 +- SSA/Projects/InstCombine/LLVM/Transform.lean | 183 +++--- .../LLVM/Transform/Dialects/InstCombine.lean | 524 +++++++++--------- .../LLVM/Transform/Instantiate.lean | 11 +- .../LLVM/Transform/TransformError.lean | 12 +- SSA/Projects/InstCombine/Refinement.lean | 3 +- 8 files changed, 357 insertions(+), 402 deletions(-) diff --git a/SSA/Projects/InstCombine/Alive.lean b/SSA/Projects/InstCombine/Alive.lean index ff5b4ce28..d60c7a456 100644 --- a/SSA/Projects/InstCombine/Alive.lean +++ b/SSA/Projects/InstCombine/Alive.lean @@ -2,7 +2,9 @@ -- Auto generated alive statements -- Do not import these into the default tree, because they take a long -- time to elaborate and typecheck. --- import SSA.Projects.InstCombine.AliveAutoGenerated +-- Import this by default, since the file has a `#exit` that checks the first two +-- theorems. This ensures that the file (and transitively, its dependencies) are not broken. +import SSA.Projects.InstCombine.AliveAutoGenerated -- Pure math statements needed to prove alive statements. -- Include these, as they are reasonably fast to typecheck. diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index a89a157fc..88e59ff74 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -12,8 +12,6 @@ namespace AliveAutoGenerated set_option pp.proofs false set_option pp.proofs.withType false - - -- Name:AddSub:1043 -- precondition: true /- @@ -31,7 +29,7 @@ set_option pp.proofs.withType false -/ def alive_AddSub_1043_src (w : Nat) := -[mlir_icom ( w )| { +[alive_icom ( w )| { ^bb0(%C1 : _, %Z : _, %RHS : _): %v1 = "llvm.and" (%Z,%C1) : (_, _) -> (_) %v2 = "llvm.xor" (%v1,%C1) : (_, _) -> (_) @@ -42,7 +40,7 @@ def alive_AddSub_1043_src (w : Nat) := }] def alive_AddSub_1043_tgt (w : Nat) := -[mlir_icom ( w )| { +[alive_icom ( w )| { ^bb0(%C1 : _, %Z : _, %RHS : _): %v1 = "llvm.not" (%C1) : (_) -> (_) %v2 = "llvm.or" (%Z,%v1) : (_, _) -> (_) @@ -59,6 +57,12 @@ theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddS apply bitvec_AddSub_1043 +/-# Early Exit +delete this to check the rest of the file. The `#exit` + is an early exit to allow our `lake` builds to complete in sensible amounts of time. +-/ +#exit + -- Name:AddSub:1152 -- precondition: true /- diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 559e7c0c0..8088d4fcf 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -48,18 +48,18 @@ elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let ast ← elabTermEnsuringTypeQ ast_stx q(Region $φ) let mvalues ← `(⟨[$mvars,*], by rfl⟩) let mvalues : Q(Vector Nat $φ) ← elabTermEnsuringType mvalues q(Vector Nat $φ) - let com := q(mkComInstantiate Op $φ $ast |>.map (· $mvalues)) + let com := q(mkComInstantiate (φ := $φ) $ast |>.map (· $mvalues)) synthesizeSyntheticMVarsNoPostponing - let com : Q(ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← + let com : Q(MLIR.AST.ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do withTransparency (mode := TransparencyMode.all) <| - return ←reduce com + return ←reduceAll com trace[Meta] com match com with | ~q(Except.ok $comOk) => return q(($comOk).snd.snd) | ~q(Except.error $err) => do let err ← unsafe evalExpr (TransformError Op) q(TransformError Op) err - throwError "Translation failed with error:\n\t{repr err}" - | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t{e}" + throwError "Translation failed with error:\n\t'com: {repr com}',\n\t'error {repr err}:" + | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t.'error:\n\t {e}'" macro "[alive_icom| " reg:mlir_region "]" : term => `([alive_icom ()| $reg]) diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index 271d0f9fb..f13a0876a 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -14,67 +14,20 @@ namespace MLIR.AST open InstCombine (MOp MTy Width) open Std (BitVec) -abbrev Context (φ) := List (MTy φ) +-- abbrev Context (φ) := List (MTy φ) -abbrev Expr (Γ : Context φ) (ty : MTy φ) := _root_.Expr (MOp φ) Γ ty -abbrev Com (Γ : Context φ) (ty : MTy φ) := _root_.Com (MOp φ) Γ ty -abbrev Var (Γ : Context φ) (ty : MTy φ) := _root_.Ctxt.Var Γ ty +-- abbrev Expr (Γ : Context φ) (ty : MTy φ) := _root_.Expr (MOp φ) Γ ty +-- abbrev Com (Γ : Context φ) (ty : MTy φ) := _root_.Com (MOp φ) Γ ty +-- abbrev Var (Γ : Context φ) (ty : MTy φ) := _root_.Ctxt.Var Γ ty -abbrev Com.lete (body : Expr Γ ty₁) (rest : Com (ty₁::Γ) ty₂) : Com Γ ty₂ := - _root_.Com.lete body rest +-- abbrev Com.lete (body : Expr Γ ty₁) (rest : Com (ty₁::Γ) ty₂) : Com Γ ty₂ := +-- _root_.Com.lete body rest -inductive TransformError - | nameAlreadyDeclared (var : String) - | undeclaredName (var : String) - | indexOutOfBounds (name : String) (index len : Nat) - | typeError {φ} (expected got : MTy φ) - | widthError {φ} (expected got : Width φ) - | unsupportedUnaryOp - | unsupportedBinaryOp (name : String) - | unsupportedOp (error : String) - | unsupportedType - | generic (error : String) - -namespace TransformError +-- namespace TransformError instance : Lean.ToFormat (MTy φ) where format := repr -instance : Repr TransformError where - reprPrec err _ := match err with - | nameAlreadyDeclared var => f!"Already declared {var}, shadowing is not allowed" - | undeclaredName name => f!"Undeclared name '{name}'" - | indexOutOfBounds name index len => - f!"Index of '{name}' out of bounds of the given context (index was {index}, but context has length {len})" - | typeError expected got => f!"Type mismatch: expected '{expected}', but 'name' has type '{got}'" - | widthError expected got => f!"Type mismatch: {expected} ≠ {got}" - | unsupportedUnaryOp => f!"Unsuported unary operation" - | unsupportedBinaryOp name => f!"Unsuported binary operation {name}" - | unsupportedOp err => f!"Unsuported operation '{err}'" - | unsupportedType => f!"Unsuported type" - | generic err => err - -end TransformError - -/-- -Store the names of the raw SSA variables (as strings). -The order in the list should match the order in which they appear in the code. --/ -abbrev NameMapping := List String - -def NameMapping.lookup (nm : NameMapping) (name : String) : Option Nat := - nm.indexOf? name - -/-- - Add a new name to the mapping, assuming the name is not present in the list yet. - If the name is already present, return `none` --/ -def NameMapping.add (nm : NameMapping) (name : String) : Option NameMapping := - match nm.lookup name with - | none => some <| name::nm - | some _ => none - -instance : MonadLift ReaderM BuilderM where section Monads /-! @@ -84,19 +37,21 @@ section Monads errors. -/ -abbrev ExceptM (Op) {Ty} [OpSignature Op Ty] := Except (TransformError Ty) -abbrev BuilderM (Op) {Ty} [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) -abbrev ReaderM (Op) {Ty} [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) +abbrev ExceptM (Op) [OpSignature Op Ty] := Except (TransformError Ty) +abbrev BuilderM (Op) [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) +abbrev ReaderM (Op) [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) + +-- instance : MonadLift ReaderM BuilderM where -variable {Op Ty} [OpSignature Op Ty] +-- variable {Op Ty} [OpSignature Op Ty] -instance : MonadLift (ReaderM Op) (BuilderM Op) where +instance {Op : Type} [OpSignature Op Ty] : MonadLift (ReaderM Op) (BuilderM Op) where monadLift x := do (ReaderT.run x (←get) : ExceptM ..) -instance : MonadLift (ExceptM Op) (ReaderM Op) where +instance {Op : Type} [OpSignature Op Ty] : MonadLift (ExceptM Op) (ReaderM Op) where monadLift x := do return ←x -def BuilderM.runWithNewMapping (k : BuilderM Op α) : ExceptM Op α := +def BuilderM.runWithNewMapping {Op : Type} [OpSignature Op Ty] (k : BuilderM Op α) : ExceptM Op α := Prod.fst <$> StateT.run k [] end Monads @@ -104,24 +59,23 @@ end Monads class TransformDialect (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) extends OpSignature Op Ty where mkType : MLIRType φ → ExceptM Op Ty mkReturn : (Γ : List Ty) → (opStx : AST.Op φ) → (args : List (Σ (ty : Ty), Ctxt.Var Γ ty)) - → ReaderM Op (Σ ty, ICom Op Γ ty) + → ReaderM Op (Σ ty, _root_.Com Op Γ ty) mkExpr : (Γ : List Ty) → (opStx : AST.Op φ) → (args : List (Σ (ty : Ty), Ctxt.Var Γ ty)) - → ReaderM Op (Σ ty, IExpr Op Γ ty) + → ReaderM Op (Σ ty, _root_.Expr Op Γ ty) -variable (Op) {Ty φ} [d : TransformDialect Op Ty φ] -abbrev Context (Ty) := List (Ty) - -abbrev Expr (Γ : Context Ty) (ty : Ty) := IExpr Op Γ ty -abbrev Com (Γ : Context Ty) (ty : Ty) := ICom Op Γ ty -abbrev Var (Γ : Context Ty) (ty : Ty) := Ctxt.Var Γ ty +variable {Op Ty φ} [OpSignature Op Ty] +abbrev Context (Ty) := List (Ty) -variable {Op} [d : TransformDialect Op Ty φ] [DecidableEq Ty] +-- abbrev Expr (Γ : Context Ty) (ty : Ty) := _root_.Expr Op Γ ty +-- abbrev Com (Γ : Context Ty) (ty : Ty) := _root_.Com Op Γ ty +-- abbrev Var (Γ : Context Ty) (ty : Ty) := Ctxt.Var Γ ty +-- variable [d : TransformDialect Op Ty φ] [DecidableEq Ty] -abbrev Com.lete (body : Expr Op Γ ty₁) (rest : Com Op (ty₁::Γ) ty₂) : Com Op Γ ty₂ := - ICom.lete body rest +-- abbrev Com.lete {Γ : Context Ty} (body : Expr Γ ty₁) (rest : Com Op (ty₁::Γ) ty₂) : Com Op Γ ty₂ := +-- _root_.Com.lete body rest @@ -150,10 +104,10 @@ instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : CoeHead (DerivedContext (Γ' : Context Ty)) (DerivedContext Γ) where coe := fun ⟨Γ'', diff⟩ => ⟨Γ'', Γ'.diff + diff⟩ -instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : Coe (Expr Op Γ t) (Expr Op Γ' t) where +instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where coe e := e.changeVars Γ'.diff.toHom -instance {Γ' : DerivedContext Γ} : Coe (Var Γ t) (Var (Γ' : Context Ty) t) where +instance {Γ' : DerivedContext Γ} : Coe (Ctxt.Var Γ t) (Ctxt.Var (Γ' : Context Ty) t) where coe v := Γ'.diff.toHom v end DerivedContext @@ -164,8 +118,8 @@ end DerivedContext Throws an error if the variable name already exists in the mapping, essentially disallowing shadowing -/ -def addValToMapping (Γ : Context φ) (name : String) (ty : MTy φ) : - BuilderM (Σ (Γ' : DerivedContext Γ), Var Γ' ty) := do +def addValToMapping (Γ : Context (MTy φ)) (name : String) (ty : MTy φ) : + BuilderM (MOp φ) (Σ (Γ' : DerivedContext Γ), Ctxt.Var Γ'.ctxt ty) := do let some nm := (←get).add name | throw <| .nameAlreadyDeclared name set nm @@ -177,8 +131,8 @@ def addValToMapping (Γ : Context φ) (name : String) (ty : MTy φ) : Throws an error if the name is not present in the mapping (this indicates the name may be free), or if the type of the variable in the context is different from `expectedType` -/ -def getValFromContext (Γ : Context φ) (name : String) (expectedType : MTy φ) : - ReaderM (Ctxt.Var Γ expectedType) := do +def getValFromContext (Γ : Context (MTy φ)) (name : String) (expectedType : MTy φ) : + ReaderM (MOp φ) (Ctxt.Var Γ expectedType) := do let index := (←read).lookup name let some index := index | throw <| .undeclaredName name let n := Γ.length @@ -193,18 +147,19 @@ def getValFromContext (Γ : Context φ) (name : String) (expectedType : MTy φ) else throw <| .typeError expectedType t -def BuilderM.isOk {α : Type} (x : BuilderM α) : Bool := +def BuilderM.isOk {α : Type} (x : BuilderM Op α) : Bool := match x.run [] with | Except.ok _ => true | Except.error _ => false -def BuilderM.isErr {α : Type} (x : BuilderM α) : Bool := +def BuilderM.isErr {α : Type} (x : BuilderM Op α) : Bool := match x.run [] with | Except.ok _ => true | Except.error _ => false -def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (e : Var Γ ty) : ExceptM <| Expr Γ ty := +#check Ctxt.Var +def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) + (e : Ctxt.Var Γ ty) : ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := match ty with | .bitvec w => match op with @@ -235,8 +190,8 @@ def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedUnaryOp -def mkBinOp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM <| Expr Γ ty := +def mkBinOp {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := match ty with | .bitvec w => match op with @@ -348,7 +303,7 @@ def mkBinOp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) | op => throw <| .unsupportedBinaryOp s!"unsupported binary operation {op}" def mkIcmp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM <| Expr Γ (.bitvec 1) := + (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM (MOp φ) <| Expr (MOp φ) Γ (.bitvec 1) := match ty with | .bitvec w => match op with @@ -362,9 +317,9 @@ def mkIcmp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedOp "unsupported icmp operation" -def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : - ExceptM <| Expr Γ ty := +def mkSelect {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) + (c : Ctxt.Var Γ (.bitvec 1)) (e₁ e₂ : Ctxt.Var Γ ty) : + ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := match ty with | .bitvec w => match op with @@ -378,9 +333,9 @@ def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedOp "Unsupported select operation" -def mkOpExpr {Γ : Context φ} (op : MOp φ) +def mkOpExpr {Γ : Context (MTy φ)} (op : MOp φ) (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : - ExceptM <| Expr Γ (OpSignature.outTy op) := + ExceptM (MOp φ) <| Expr (MOp φ) Γ (OpSignature.outTy op) := match op with | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ @@ -397,36 +352,36 @@ def mkOpExpr {Γ : Context φ} (op : MOp φ) mkSelect op c e₁ e₂ | .const .. => throw <| .unsupportedOp "Tried to build Op expression from constant" -def MLIRType.mkTy : MLIRType φ → ExceptM (MTy φ) +def MLIRType.mkTy : MLIRType φ → ExceptM (MOp φ) (MTy φ) | MLIRType.int Signedness.Signless w => do return .bitvec w | _ => throw .unsupportedType -- "Unsupported type" -def TypedSSAVal.mkTy : TypedSSAVal φ → ExceptM (MTy φ) - | (.SSAVal _, ty) => ty.mkTy +def TypedSSAVal.mkTy : TypedSSAVal φ → ExceptM (MOp φ) (MTy φ) + | (.SSAVal _, ty) => MLIRType.mkTy ty def mkVal (ty : InstCombine.Ty) : Int → BitVec ty.width | val => BitVec.ofInt ty.width val /-- Translate a `TypedSSAVal` (a name with an expected type), to a variable in the context. This expects the name to have already been declared before -/ -def TypedSSAVal.mkVal (Γ : Context φ) : TypedSSAVal φ → - ReaderM (Σ (ty : MTy φ), Var Γ ty) +def TypedSSAVal.mkVal (Γ : Context (MTy φ)) : TypedSSAVal φ → + ReaderM (MOp φ) (Σ (ty : MTy φ), Ctxt.Var Γ ty) | (.SSAVal valStx, tyStx) => do - let ty ← (tyStx.mkTy : ExceptM Op ..) + let ty ← MLIRType.mkTy tyStx let var ← getValFromContext Γ valStx ty return ⟨ty, var⟩ /-- Declare a new variable, by adding the passed name to the name mapping stored in the monad state -/ -def TypedSSAVal.newVal (Γ : Context φ) : TypedSSAVal φ → - BuilderM (Σ (Γ' : DerivedContext Γ) (ty : MTy φ), Var Γ' ty) +def TypedSSAVal.newVal (Γ : Context (MTy φ)) : TypedSSAVal φ → + BuilderM (MOp φ) (Σ (Γ' : DerivedContext Γ) (ty : MTy φ), Ctxt.Var Γ'.ctxt ty) | (.SSAVal valStx, tyStx) => do - let ty ← (tyStx.mkTy : ExceptM Op ..) + let ty ← MLIRType.mkTy tyStx let ⟨Γ, var⟩ ← addValToMapping Γ valStx ty return ⟨Γ, ty, var⟩ -def mkExpr (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Expr Γ ty) := do +def mkExpr (Γ : Context (MTy φ)) (opStx : MLIR.AST.Op φ) : ReaderM (MOp φ) (Σ ty, Expr (MOp φ) Γ ty) := do match opStx.args with | v₁Stx::v₂Stx::v₃Stx::[] => let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx @@ -474,13 +429,13 @@ def mkExpr (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Expr Γ ty) := do match op with | .icmp .. => if hty : w₁ = w₂ then - let icmpOp ← (mkIcmp op v₁ (hty ▸ v₂) : ExceptM _) + let icmpOp ← mkIcmp op v₁ (hty ▸ v₂) return ⟨.bitvec 1, icmpOp⟩ else throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" | _ => if hty : w₁ = w₂ then - let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM _) + let binOp ← mkBinOp op v₁ (hty ▸ v₂) return ⟨.bitvec w₁, binOp⟩ else throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" @@ -500,7 +455,7 @@ def mkExpr (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Expr Γ ty) := do | throw <| .generic "tried to resolve constant without 'value' attribute" match att with | .int val ty => - let opTy@(MTy.bitvec w) ← ty.mkTy + let opTy@(MTy.bitvec w) ← MLIRType.mkTy ty -- ty.mkTy return ⟨opTy, ⟨ MOp.const w val, by simp [OpSignature.outTy, signature, *], @@ -512,29 +467,29 @@ def mkExpr (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Expr Γ ty) := do throw <| .generic s!"invalid (0-ary) expression {opStx.name}" | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" -def mkReturn (Γ : Context φ) (opStx : Op φ) : ReaderM (Σ ty, Com Γ ty) := +def mkReturn (Γ : Context (MTy φ)) (opStx : MLIR.AST.Op φ) : ReaderM (MOp φ) (Σ ty, Com (MOp φ) Γ ty) := if opStx.name == "llvm.return" then match opStx.args with | vStx::[] => do - let ⟨ty, v⟩ ← vStx.mkVal Γ + let ⟨ty, v⟩ ← TypedSSAVal.mkVal Γ vStx return ⟨ty, _root_.Com.ret v⟩ | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" /-- Given a list of `TypedSSAVal`s, treat each as a binder and declare a new variable with the given name and type -/ -private def declareBindings (Γ : Context φ) (vals : List (TypedSSAVal φ)) : - BuilderM (DerivedContext Γ) := do +private def declareBindings (Γ : Context (MTy φ)) (vals : List (TypedSSAVal φ)) : + BuilderM (MOp φ) (DerivedContext Γ) := do vals.foldlM (fun Γ' ssaVal => do - let ⟨Γ'', _⟩ ← TypedSSAVal.newVal Γ' ssaVal + let ⟨Γ'', _⟩ ← TypedSSAVal.newVal Γ'.ctxt ssaVal return Γ'' ) (.ofContext Γ) -private def mkComHelper (Γ : Context φ) : - List (Op φ) → BuilderM (Σ (ty : _), Com Γ ty) +private def mkComHelper (Γ : Context (MTy φ)) : + List (MLIR.AST.Op φ) → BuilderM (MOp φ) (Σ (ty : _), Com (MOp φ) Γ ty) | [retStx] => mkReturn Γ retStx | lete::rest => do - let ⟨ty₁, expr⟩ ← (mkExpr Γ lete : ReaderM Op ..) + let ⟨ty₁, expr⟩ ← (mkExpr Γ lete) if h : lete.res.length != 1 then throw <| .generic s!"Each let-binding must have exactly one name on the left-hand side. Operations with multiple, or no, results are not yet supported.\n\tExpected a list of length one, found `{repr lete}`" else @@ -543,7 +498,7 @@ private def mkComHelper (Γ : Context φ) : return ⟨ty₂, Com.lete expr body⟩ | [] => throw <| .generic "Ill-formed (empty) block" -def mkCom (reg : Region φ) : ExceptM (Σ (Γ : Context φ) (ty : MTy φ), Com Γ ty) := +def mkCom (reg : MLIR.AST.Region φ) : ExceptM (MOp φ) (Σ (Γ : Context (MTy φ)) (ty : MTy φ), Com (MOp φ) Γ ty) := match reg.ops with | [] => throw <| .generic "Ill-formed region (empty)" | coms => BuilderM.runWithNewMapping <| do @@ -580,7 +535,7 @@ def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstC | .icmp c w => .icmp c (w.instantiate vals) | .const w val => .const (w.instantiate vals) val -def Context.instantiate (vals : Vector Nat φ) (Γ : Context φ) : Ctxt InstCombine.Ty := +def Context.instantiate (vals : Vector Nat φ) (Γ : Context (MTy φ)) : Ctxt InstCombine.Ty := Γ.map (MTy.instantiate vals) def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where @@ -595,7 +550,7 @@ def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCo open InstCombine (Op Ty) in def mkComInstantiate (reg : Region φ) : - ExceptM (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), _root_.Com InstCombine.Op Γ ty) := do + ExceptM (MOp φ) (Vector Nat φ → Σ (Γ : Ctxt InstCombine.Ty) (ty : InstCombine.Ty), Com InstCombine.Op Γ ty) := do let ⟨Γ, ty, com⟩ ← mkCom reg return fun vals => ⟨Γ.instantiate vals, ty.instantiate vals, com.map (MOp.instantiateCom vals)⟩ diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean index d4da4077d..65aa180d5 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean @@ -12,7 +12,7 @@ protected abbrev ReaderM (φ) := AST.ReaderM (MOp φ) protected abbrev BuilderM (φ) := AST.BuilderM (MOp φ) protected abbrev Context (φ) := List (MTy φ) -protected abbrev Expr {φ} := IExpr (MOp φ) +protected abbrev Expr {φ} := Expr (MOp φ) open InstCombine (ExceptM ReaderM BuilderM Context Expr) @@ -26,7 +26,7 @@ def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) | .bitvec w => match op with -- Can't use a single arm, Lean won't write the rhs accordingly - | .neg w' => if h : w = w' + | .neg w' => if h : w = w' then return ⟨ .neg w', by simp [OpSignature.outTy, signature, h], @@ -34,15 +34,15 @@ def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) .nil ⟩ else throw <| .widthError w w' - | .not w' => if h : w = w' + | .not w' => if h : w = w' then return ⟨ .not w', by simp [OpSignature.outTy, signature, h], .cons (h ▸ e) .nil, .nil - ⟩ + ⟩ else throw <| .widthError w w' - | .copy w' => if h : w = w' + | .copy w' => if h : w = w' then return ⟨ .copy w', by simp [OpSignature.outTy, signature, h], @@ -52,272 +52,272 @@ def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedUnaryOp -def mkBinOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .add w' => if h : w = w' - then return ⟨ - .add w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .and w' => if h : w = w' - then return ⟨ - .and w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .or w' => if h : w = w' - then return ⟨ - .or w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .xor w' => if h : w = w' - then return ⟨ - .xor w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .shl w' => if h : w = w' - then return ⟨ - .shl w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .lshr w' => if h : w = w' - then return ⟨ - .lshr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .ashr w' => if h : w = w' - then return ⟨ - .ashr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .urem w' => if h : w = w' - then return ⟨ - .urem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .srem w' => if h : w = w' - then return ⟨ - .srem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .mul w' => if h : w = w' - then return ⟨ - .mul w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sub w' => if h : w = w' - then return ⟨ - .sub w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sdiv w' => if h : w = w' - then return ⟨ - .sdiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .udiv w' => if h : w = w' - then return ⟨ - .udiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedBinaryOp +-- def mkBinOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) +-- (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ ty := +-- match ty with +-- | .bitvec w => +-- match op with +-- -- Can't use a single arm, Lean won't write the rhs accordingly +-- | .add w' => if h : w = w' +-- then return ⟨ +-- .add w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .and w' => if h : w = w' +-- then return ⟨ +-- .and w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .or w' => if h : w = w' +-- then return ⟨ +-- .or w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .xor w' => if h : w = w' +-- then return ⟨ +-- .xor w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .shl w' => if h : w = w' +-- then return ⟨ +-- .shl w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .lshr w' => if h : w = w' +-- then return ⟨ +-- .lshr w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .ashr w' => if h : w = w' +-- then return ⟨ +-- .ashr w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .urem w' => if h : w = w' +-- then return ⟨ +-- .urem w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .srem w' => if h : w = w' +-- then return ⟨ +-- .srem w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .mul w' => if h : w = w' +-- then return ⟨ +-- .mul w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .sub w' => if h : w = w' +-- then return ⟨ +-- .sub w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .sdiv w' => if h : w = w' +-- then return ⟨ +-- .sdiv w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | .udiv w' => if h : w = w' +-- then return ⟨ +-- .udiv w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | _ => throw <| .unsupportedBinaryOp -def mkIcmp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ (.bitvec 1) := - match ty with - | .bitvec w => - match op with - | .icmp p w' => if h : w = w' - then return ⟨ - .icmp p w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil -⟩ - else throw <| .widthError w w' - | _ => throw .unsupportedOp -- unsupported icmp operation +-- def mkIcmp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) +-- (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ (.bitvec 1) := +-- match ty with +-- | .bitvec w => +-- match op with +-- | .icmp p w' => if h : w = w' +-- then return ⟨ +-- .icmp p w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | _ => throw .unsupportedOp -- unsupported icmp operation -def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : - ExceptM φ <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - | .select w' => if h : w = w' - then return ⟨ - .select w', - by simp [OpSignature.outTy, signature, h], - .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw .unsupportedOp -- "Unsupported select operation" +-- def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) +-- (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : +-- ExceptM φ <| Expr Γ ty := +-- match ty with +-- | .bitvec w => +-- match op with +-- | .select w' => if h : w = w' +-- then return ⟨ +-- .select w', +-- by simp [OpSignature.outTy, signature, h], +-- .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , +-- .nil +-- ⟩ +-- else throw <| .widthError w w' +-- | _ => throw .unsupportedOp -- "Unsupported select operation" -def mkOpExpr {Γ : Context φ} (op : MOp φ) - (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : - ExceptM φ <| Expr Γ (OpSignature.outTy op) := - match op with - | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ - | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ - | .srem _ | .urem _ => - let (e₁, e₂) := arg.toTuple - mkBinOp op e₁ e₂ - | .icmp _ _ => - let (e₁, e₂) := arg.toTuple - mkIcmp op e₁ e₂ - | .not _ | .neg _ | .copy _ => - mkUnaryOp op arg.head - | .select _ => - let (c, e₁, e₂) := arg.toTuple - mkSelect op c e₁ e₂ - | .const .. => throw .unsupportedOp -- "Tried to build Op expression from constant" +-- def mkOpExpr {Γ : Context φ} (op : MOp φ) +-- (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : +-- ExceptM φ <| Expr Γ (OpSignature.outTy op) := +-- match op with +-- | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ +-- | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ +-- | .srem _ | .urem _ => +-- let (e₁, e₂) := arg.toTuple +-- mkBinOp op e₁ e₂ +-- | .icmp _ _ => +-- let (e₁, e₂) := arg.toTuple +-- mkIcmp op e₁ e₂ +-- | .not _ | .neg _ | .copy _ => +-- mkUnaryOp op arg.head +-- | .select _ => +-- let (c, e₁, e₂) := arg.toTuple +-- mkSelect op c e₁ e₂ +-- | .const .. => throw .unsupportedOp -- "Tried to build Op expression from constant" -def mkExpr (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : - ReaderM φ (Σ ty, Expr Γ ty) := do - match args with - | ⟨.bitvec w₁, v₁⟩::⟨.bitvec w₂, v₂⟩::[] => - -- let ty₁ := ty₁.instantiave - let op ← match opStx.name with - | "llvm.and" => pure (MOp.and w₁) - | "llvm.or" => pure (MOp.or w₁) - | "llvm.xor" => pure (MOp.xor w₁) - | "llvm.shl" => pure (MOp.shl w₁) - | "llvm.lshr" => pure (MOp.lshr w₁) - | "llvm.ashr" => pure (MOp.ashr w₁) - | "llvm.urem" => pure (MOp.urem w₁) - | "llvm.srem" => pure (MOp.srem w₁) - | "llvm.select" => pure (MOp.select w₁) - | "llvm.add" => pure (MOp.add w₁) - | "llvm.mul" => pure (MOp.mul w₁) - | "llvm.sub" => pure (MOp.sub w₁) - | "llvm.sdiv" => pure (MOp.sdiv w₁) - | "llvm.udiv" => pure (MOp.udiv w₁) - --| "llvm.icmp" => return InstCombine.Op.icmp v₁.width - | _ => throw .unsupportedOp -- "Unsuported operation or invalid arguments" - if hty : w₁ = w₂ then - let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM ..) - return ⟨.bitvec w₁, binOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - | ⟨.bitvec w, v⟩::[] => - let op ← match opStx.name with - | "llvm.not" => pure <| MOp.not w - | "llvm.neg" => pure <| MOp.neg w - | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" - let op ← mkUnaryOp op v - return ⟨.bitvec w, op⟩ - | [] => - if opStx.name == "llvm.mlir.constant" - then do - let some att := opStx.attrs.getAttr "value" - | throw <| .generic "tried to resolve constant without 'value' attribute" - match att with - | .int val ty => - let opTy@(MTy.bitvec w) ← mkType ty - return ⟨opTy, ⟨ - MOp.const w val, - by simp [OpSignature.outTy, signature, *], - HVector.nil, - HVector.nil - ⟩⟩ - | _ => throw <| .generic "invalid constant attribute" - else - throw <| .generic s!"invalid (0-ary) expression {opStx.name}" - | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" +-- def mkExpr (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : +-- ReaderM φ (Σ ty, Expr Γ ty) := do +-- match args with +-- | ⟨.bitvec w₁, v₁⟩::⟨.bitvec w₂, v₂⟩::[] => +-- -- let ty₁ := ty₁.instantiave +-- let op ← match opStx.name with +-- | "llvm.and" => pure (MOp.and w₁) +-- | "llvm.or" => pure (MOp.or w₁) +-- | "llvm.xor" => pure (MOp.xor w₁) +-- | "llvm.shl" => pure (MOp.shl w₁) +-- | "llvm.lshr" => pure (MOp.lshr w₁) +-- | "llvm.ashr" => pure (MOp.ashr w₁) +-- | "llvm.urem" => pure (MOp.urem w₁) +-- | "llvm.srem" => pure (MOp.srem w₁) +-- | "llvm.select" => pure (MOp.select w₁) +-- | "llvm.add" => pure (MOp.add w₁) +-- | "llvm.mul" => pure (MOp.mul w₁) +-- | "llvm.sub" => pure (MOp.sub w₁) +-- | "llvm.sdiv" => pure (MOp.sdiv w₁) +-- | "llvm.udiv" => pure (MOp.udiv w₁) +-- --| "llvm.icmp" => return InstCombine.Op.icmp v₁.width +-- | _ => throw .unsupportedOp -- "Unsuported operation or invalid arguments" +-- if hty : w₁ = w₂ then +-- let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM ..) +-- return ⟨.bitvec w₁, binOp⟩ +-- else +-- throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" +-- | ⟨.bitvec w, v⟩::[] => +-- let op ← match opStx.name with +-- | "llvm.not" => pure <| MOp.not w +-- | "llvm.neg" => pure <| MOp.neg w +-- | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" +-- let op ← mkUnaryOp op v +-- return ⟨.bitvec w, op⟩ +-- | [] => +-- if opStx.name == "llvm.mlir.constant" +-- then do +-- let some att := opStx.attrs.getAttr "value" +-- | throw <| .generic "tried to resolve constant without 'value' attribute" +-- match att with +-- | .int val ty => +-- let opTy@(MTy.bitvec w) ← mkType ty +-- return ⟨opTy, ⟨ +-- MOp.const w val, +-- by simp [OpSignature.outTy, signature, *], +-- HVector.nil, +-- HVector.nil +-- ⟩⟩ +-- | _ => throw <| .generic "invalid constant attribute" +-- else +-- throw <| .generic s!"invalid (0-ary) expression {opStx.name}" +-- | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" -def mkReturn (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : - ReaderM φ (Σ ty, ICom (MOp φ) Γ ty) := - if opStx.name == "llvm.return" - then match args with - | ⟨ty, v⟩::[] => do - return ⟨ty, ICom.ret v⟩ - | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {args.length})" - else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" +-- def mkReturn (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : +-- ReaderM φ (Σ ty, ICom (MOp φ) Γ ty) := +-- if opStx.name == "llvm.return" +-- then match args with +-- | ⟨ty, v⟩::[] => do +-- return ⟨ty, ICom.ret v⟩ +-- | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {args.length})" +-- else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" -/-! ## Instantiation -/ +-- /-! ## Instantiation -/ -def MTy.instantiate (vals : Vector Nat φ) : (MTy φ) → Ty - | .bitvec w => .bitvec <| .concrete <| w.instantiate vals +-- def MTy.instantiate (vals : Vector Nat φ) : (MTy φ) → Ty +-- | .bitvec w => .bitvec <| .concrete <| w.instantiate vals -def MOp.instantiate (vals : Vector Nat φ) : MOp φ → Op - | .and w => .and (w.instantiate vals) - | .or w => .or (w.instantiate vals) - | .not w => .not (w.instantiate vals) - | .xor w => .xor (w.instantiate vals) - | .shl w => .shl (w.instantiate vals) - | .lshr w => .lshr (w.instantiate vals) - | .ashr w => .ashr (w.instantiate vals) - | .urem w => .urem (w.instantiate vals) - | .srem w => .srem (w.instantiate vals) - | .select w => .select (w.instantiate vals) - | .add w => .add (w.instantiate vals) - | .mul w => .mul (w.instantiate vals) - | .sub w => .sub (w.instantiate vals) - | .neg w => .neg (w.instantiate vals) - | .copy w => .copy (w.instantiate vals) - | .sdiv w => .sdiv (w.instantiate vals) - | .udiv w => .udiv (w.instantiate vals) - | .icmp c w => .icmp c (w.instantiate vals) - | .const w val => .const (w.instantiate vals) val +-- def MOp.instantiate (vals : Vector Nat φ) : MOp φ → Op +-- | .and w => .and (w.instantiate vals) +-- | .or w => .or (w.instantiate vals) +-- | .not w => .not (w.instantiate vals) +-- | .xor w => .xor (w.instantiate vals) +-- | .shl w => .shl (w.instantiate vals) +-- | .lshr w => .lshr (w.instantiate vals) +-- | .ashr w => .ashr (w.instantiate vals) +-- | .urem w => .urem (w.instantiate vals) +-- | .srem w => .srem (w.instantiate vals) +-- | .select w => .select (w.instantiate vals) +-- | .add w => .add (w.instantiate vals) +-- | .mul w => .mul (w.instantiate vals) +-- | .sub w => .sub (w.instantiate vals) +-- | .neg w => .neg (w.instantiate vals) +-- | .copy w => .copy (w.instantiate vals) +-- | .sdiv w => .sdiv (w.instantiate vals) +-- | .udiv w => .udiv (w.instantiate vals) +-- | .icmp c w => .icmp c (w.instantiate vals) +-- | .const w val => .const (w.instantiate vals) val -/-! ## Instances -/ +-- /-! ## Instances -/ -instance : AST.TransformDialect (MOp φ) (MTy φ) φ where - mkType := mkType - mkReturn := mkReturn - mkExpr := mkExpr +-- instance : AST.TransformDialect (MOp φ) (MTy φ) φ where +-- mkType := mkType +-- mkReturn := mkReturn +-- mkExpr := mkExpr -instance : TransformDialectInstantiate Op φ Ty (MOp φ) (MTy φ) where - morphism vals := { - mapOp := MOp.instantiate vals, - mapTy := MTy.instantiate vals, - preserves_signature := by - intro op - simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, - InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, - true_and] - cases op <;> simp only [List.map, and_self, List.cons.injEq] - } \ No newline at end of file +-- instance : TransformDialectInstantiate Op φ Ty (MOp φ) (MTy φ) where +-- morphism vals := { +-- mapOp := MOp.instantiate vals, +-- mapTy := MTy.instantiate vals, +-- preserves_signature := by +-- intro op +-- simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, +-- InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, +-- true_and] +-- cases op <;> simp only [List.map, and_self, List.cons.injEq] +-- } diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean index 48d278dce..adff71241 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean @@ -17,14 +17,7 @@ set_option linter.unusedVariables false -- linter gives a false positive for `φ variable (Op) (φ) {Ty} {MOp MTy} [OpSignature Op Ty] [AST.TransformDialect MOp MTy φ] - [instInst : TransformDialectInstantiate Op φ Ty MOp MTy] + [instInst : TransformDialectInstantiate Op φ Ty MOp MTy] [DecidableEq MTy] -def AST.mkComInstantiate (reg : AST.Region φ) : - AST.ExceptM MOp (Vector Nat φ → Σ (Γ : Ctxt Ty) (ty : Ty), ICom Op Γ ty) := do - let ⟨Γ, ty, icom⟩ ← AST.mkCom MOp reg - return fun vals => - let f := instInst.morphism vals - ⟨Γ.map f.mapTy, f.mapTy ty, icom.map f⟩ - -end MLIR \ No newline at end of file +end MLIR diff --git a/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean b/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean index 010e788c7..eadaed68b 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform/TransformError.lean @@ -9,8 +9,8 @@ inductive TransformError (Ty : Type) | typeError (expected got : Ty) | widthError {φ} (expected got : Width φ) | unsupportedUnaryOp - | unsupportedBinaryOp - | unsupportedOp + | unsupportedBinaryOp (error : String) + | unsupportedOp (error : String) | unsupportedType | generic (error : String) @@ -20,16 +20,16 @@ instance [Repr Ty] : Repr (TransformError Ty) where reprPrec err _ := match err with | nameAlreadyDeclared var => f!"Already declared {var}, shadowing is not allowed" | undeclaredName name => f!"Undeclared name '{name}'" - | indexOutOfBounds name index len => + | indexOutOfBounds name index len => f!"Index of '{name}' out of bounds of the given context (index was {index}, but context has length {len})" | typeError expected got => f!"Type mismatch: expected '{repr expected}', but 'name' has type '{repr got}'" | widthError expected got => f!"Type mismatch: {expected} ≠ {got}" | unsupportedUnaryOp => f!"Unsuported unary operation" - | unsupportedBinaryOp => f!"Unsuported binary operation" - | unsupportedOp => f!"Unsuported operation" + | unsupportedBinaryOp err => f!"Unsuported binary operation 's!{err}'" + | unsupportedOp err => f!"Unsuported operation 's!{err}'" | unsupportedType => f!"Unsuported type" | generic err => err end TransformError -end MLIR.AST \ No newline at end of file +end MLIR.AST diff --git a/SSA/Projects/InstCombine/Refinement.lean b/SSA/Projects/InstCombine/Refinement.lean index 91f5cd2fa..7f213c1fc 100644 --- a/SSA/Projects/InstCombine/Refinement.lean +++ b/SSA/Projects/InstCombine/Refinement.lean @@ -1,9 +1,10 @@ +import SSA.Projects.InstCombine.Base import SSA.Projects.InstCombine.LLVM.EDSL import SSA.Projects.InstCombine.AliveStatements open MLIR AST -abbrev Com.Refinement (src tgt : Com (φ:=0) Γ t) (h : Goedel.toType t = Option α := by rfl) : Prop := +abbrev Com.Refinement (src tgt : Com (InstCombine.MOp 0) Γ t) (h : Goedel.toType t = Option α := by rfl) : Prop := ∀ Γv, (h ▸ src.denote Γv) ⊑ (h ▸ tgt.denote Γv) infixr:90 " ⊑ " => Com.Refinement From 829b6da78e93da4a4286ead3179b99c948a0b8f6 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 05:02:59 +0000 Subject: [PATCH 06/28] make some progress by replacing Qq match with Expr match MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We now get the error: ``` ./././SSA/Projects/InstCombine/AliveAutoGenerated.lean:31:4: error: code generator does not support recursor 'InstCombine.MOp.rec' yet, consider using 'match ... with' and/or structural recursion ./././SSA/Projects/InstCombine/AliveAutoGenerated.lean:42:4: error: code generator does not support recursor 'InstCombine.MOp.rec' yet, consider using 'match ... with' and/or structural recursion ./././SSA/Projects/InstCombine/AliveAutoGenerated.lean:54:41: error: type mismatch _ has type HEq ?m.273585 ?m.273585 : Prop but is expected to have type ⟦?m.270983⟧ = Option ?m.270984 : Prop ./././SSA/Projects/InstCombine/AliveAutoGenerated.lean:54:68: error: application type mismatch ?m.271475 ⊑ alive_AddSub_1043_tgt w argument alive_AddSub_1043_tgt w has type @Com (InstCombine.MOp 0) (InstCombine.MTy 0) { signature := fun op => { sig := (fun {φ} motive x h_1 h_2 h_3 h_4 h_5 h_6 h_7 h_8 h_9 h_10 h_11 h_12 h_13 h_14 h_15 h_16 h_17 h_18 h_19 => InstCombine.MOp.rec (fun w => h_1 w) (fun w => h_2 w) (fun w => h_15 w) (fun w => h_3 w) (fun w => h_4 w) (fun w => h_5 w) (fun w => h_6 w) (fun w => h_13 w) (fun w => h_12 w) (fun w => h_18 w) (fun w => h_7 w) (fun w => h_8 w) (fun w => h_9 w) (fun w => h_16 w) (fun w => h_17 w) (fun w => h_11 w) (fun w => h_10 w) (fun c w => h_14 c w) (fun w val => h_19 w val) x) (fun x => List (InstCombine.MTy 0)) op (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun c w => [InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec w]) (fun w => [InstCombine.MTy.bitvec (ConcreteOrMVar.concrete 1), InstCombine.MTy.bitvec w, InstCombine.MTy.bitvec w]) fun w val => [], regSig := [], outTy := (fun {φ} motive x h_1 h_2 h_3 h_4 h_5 h_6 h_7 h_8 h_9 h_10 h_11 h_12 h_13 h_14 h_15 h_16 h_17 h_18 h_19 => InstCombine.MOp.rec (fun w => h_1 w) (fun w => h_2 w) (fun w => h_3 w) (fun w => h_4 w) (fun w => h_5 w) (fun w => h_6 w) (fun w => h_7 w) (fun w => h_17 w) (fun w => h_16 w) (fun w => h_9 w) (fun w => h_12 w) (fun w => h_13 w) (fun w => h_8 w) (fun w => h_10 w) (fun w => h_11 w) (fun w => h_14 w) (fun w => h_15 w) (fun c w => h_18 c w) (fun w val => h_19 w val) x) (fun x => InstCombine.MTy 0) op (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun w => InstCombine.MTy.bitvec w) (fun c w => InstCombine.MTy.bitvec (ConcreteOrMVar.concrete 1)) fun width val => InstCombine.MTy.bitvec width } } [InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w)] (InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w)) : Type but is expected to have type @Com (InstCombine.MOp 0) (InstCombine.MTy 0) InstCombine.instOpSignatureMOpMTy [InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w), InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w)] (InstCombine.MTy.bitvec (ConcreteOrMVar.concrete w)) : Type ``` --- SSA/Projects/InstCombine/LLVM/EDSL.lean | 36 ++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 8088d4fcf..0e761e6c4 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -39,8 +39,21 @@ open MLIR -- -- throwError "Translation failed with error:\n\t{repr err}" -- -- | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t{e}" --- sorry +#check Except.ok +#check Sigma.mk +@[inline] def _root_.Lean.Expr.app5? (e : Expr) (fName : Name) : Option (Expr × Expr × Expr × Expr × Expr) := + if e.isAppOfArity fName 4 then + some (e.appFn!.appFn!.appFn!.appFn!.appArg!, e.appFn!.appFn!.appFn!.appArg!, e.appFn!.appFn!.appArg!, e.appFn!.appArg!, e.appArg!) + else + none +/- +https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/Cannot.20Find.20.60Real.2Eadd.60/near/402089561 +> I would recommend avoiding Qq for pattern matching. +> That part of the Qq implementation is spicy. + +Therefore, we choose to match on raw `Expr`. +-/ open MLIR.AST InstCombine in elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let ast_stx ← `([mlir_region| $reg]) @@ -51,15 +64,20 @@ elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let com := q(mkComInstantiate (φ := $φ) $ast |>.map (· $mvalues)) synthesizeSyntheticMVarsNoPostponing let com : Q(MLIR.AST.ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← - withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do + withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding true }) do withTransparency (mode := TransparencyMode.all) <| return ←reduceAll com + let comExpr : Expr := com trace[Meta] com - match com with - | ~q(Except.ok $comOk) => return q(($comOk).snd.snd) - | ~q(Except.error $err) => do - let err ← unsafe evalExpr (TransformError Op) q(TransformError Op) err - throwError "Translation failed with error:\n\t'com: {repr com}',\n\t'error {repr err}:" - | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t.'error:\n\t {e}'" - + trace[Meta] comExpr + match comExpr.app3? ``Except.ok with + | .some (_εexpr, _αexpr, aexpr) => + match aexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + match sndexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + return sndexpr + | .none => throwError "Found `Except.ok (Sigma.mk _ WRONG)`, Expected (Except.ok (Sigma.mk _ (Sigma.mk _ _))" + | .none => throwError "Found `Except.ok WRONG`, Expected (Except.ok (Sigma.mk _ _))" + | .none => throwError "Expected `Except.ok`, found {comExpr}" macro "[alive_icom| " reg:mlir_region "]" : term => `([alive_icom ()| $reg]) From dfbe90e3ad2e1833aa5f68ea51a69f8bed2ca940 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 05:06:54 +0000 Subject: [PATCH 07/28] fix: change reduceAll -> reduce --- SSA/Projects/InstCombine/LLVM/EDSL.lean | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 0e761e6c4..83bd58717 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -64,9 +64,9 @@ elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let com := q(mkComInstantiate (φ := $φ) $ast |>.map (· $mvalues)) synthesizeSyntheticMVarsNoPostponing let com : Q(MLIR.AST.ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← - withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding true }) do + withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do withTransparency (mode := TransparencyMode.all) <| - return ←reduceAll com + return ←reduce com let comExpr : Expr := com trace[Meta] com trace[Meta] comExpr From d45a5c40caf503242a379cc72f998bde333ce253 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 06:06:09 +0000 Subject: [PATCH 08/28] feat: split up code, and use to transform from MLIR AST to ICom. however, we generate _gigantic_ terms $ lake build 2>&1 | wc -l 429 --- SSA/Projects/InstCombine/LLVM/EDSL.lean | 406 ++++++++++++++++-- SSA/Projects/InstCombine/LLVM/Transform.lean | 180 ++++---- .../LLVM/Transform/Dialects/InstCombine.lean | 323 -------------- .../LLVM/Transform/Instantiate.lean | 23 - SSA/Projects/InstCombine/Tactic.lean | 3 +- 5 files changed, 457 insertions(+), 478 deletions(-) delete mode 100644 SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean delete mode 100644 SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 83bd58717..70cf4990c 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -1,51 +1,373 @@ import Qq +import SSA.Projects.InstCombine.Base +import Std.Data.BitVec +import SSA.Projects.MLIRSyntax.AST import SSA.Projects.MLIRSyntax.EDSL import SSA.Projects.InstCombine.LLVM.Transform -import SSA.Projects.InstCombine.LLVM.Transform.Instantiate - -import SSA.Projects.InstCombine.LLVM.Transform.Dialects.InstCombine +-- import SSA.Projects.InstCombine.LLVM.Transform.Dialects.InstCombine open Qq Lean Meta Elab.Term Elab Command +open InstCombine (MOp MTy Width) open MLIR --- def elabToMlirICom (Op : Q(Type)) (mvars : Syntax.TSepArray `term ",") (reg : TSyntax `mlir_region) : --- TermElabM Unit := do --- let φ : Nat := mvars.getElems.size --- let Ty ← mkFreshExprMVarQ q(Type) --- let MOp ← mkFreshExprMVarQ q(Type) --- let MTy ← mkFreshExprMVarQ q(Type) --- let _ ← synthInstanceQ q(OpSignature $Op $Ty) --- let _ ← synthInstanceQ q(AST.TransformDialect $MOp $MTy $φ) --- let instInst ← synthInstanceQ q(TransformDialectInstantiate $Op $φ $Ty $MOp $MTy) - --- let ast_stx ← `([mlir_region| $reg]) --- let ast ← elabTermEnsuringTypeQ ast_stx q(AST.Region $φ) - --- let mvalues ← `(⟨[$mvars,*], by rfl⟩) --- let mvalues : Q(Vector Nat $φ) ← elabTermEnsuringType mvalues q(Vector Nat $φ) - --- let com := q(AST.mkComInstantiate (instInst:=$instInst) $ast) --- synthesizeSyntheticMVarsNoPostponing --- -- let com : Q(ExceptM (Σ (Γ' : Ctxt Ty) (ty : InstCombine.Ty), Com Γ' ty)) ← --- -- withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do --- -- withTransparency (mode := TransparencyMode.all) <| --- -- return ←reduce com --- -- trace[Meta] com --- -- match com with --- -- | ~q(Except.ok $comOk) => return q(($comOk).snd.snd) --- -- | ~q(Except.error $err) => do --- -- let err ← unsafe evalExpr TransformError q(TransformError) err --- -- throwError "Translation failed with error:\n\t{repr err}" --- -- | e => throwError "Translation failed to reduce, possibly too generic syntax\n\t{e}" - -#check Except.ok -#check Sigma.mk -@[inline] def _root_.Lean.Expr.app5? (e : Expr) (fName : Name) : Option (Expr × Expr × Expr × Expr × Expr) := - if e.isAppOfArity fName 4 then - some (e.appFn!.appFn!.appFn!.appFn!.appArg!, e.appFn!.appFn!.appFn!.appArg!, e.appFn!.appFn!.appArg!, e.appFn!.appArg!, e.appArg!) - else - none +namespace InstcombineTransformDialect + +def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) + (e : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly + | .neg w' => if h : w = w' + then return ⟨ + .neg w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | .not w' => if h : w = w' + then return ⟨ + .not w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | .copy w' => if h : w = w' + then return ⟨ + .copy w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e) .nil, + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedUnaryOp + +def mkBinOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly + | .add w' => if h : w = w' + then return ⟨ + .add w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .and w' => if h : w = w' + then return ⟨ + .and w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .or w' => if h : w = w' + then return ⟨ + .or w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .xor w' => if h : w = w' + then return ⟨ + .xor w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .shl w' => if h : w = w' + then return ⟨ + .shl w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .lshr w' => if h : w = w' + then return ⟨ + .lshr w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .ashr w' => if h : w = w' + then return ⟨ + .ashr w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .urem w' => if h : w = w' + then return ⟨ + .urem w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .srem w' => if h : w = w' + then return ⟨ + .srem w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .mul w' => if h : w = w' + then return ⟨ + .mul w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .sub w' => if h : w = w' + then return ⟨ + .sub w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .sdiv w' => if h : w = w' + then return ⟨ + .sdiv w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | .udiv w' => if h : w = w' + then return ⟨ + .udiv w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | op => throw <| .unsupportedBinaryOp s!"unsupported binary operation {op}" + +def mkIcmp {Γ : Ctxt _} {ty : (MTy φ)} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ (.bitvec 1) := + match ty with + | .bitvec w => + match op with + | .icmp p w' => if h : w = w' + then return ⟨ + .icmp p w', + by simp [OpSignature.outTy, signature, h], + .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedOp "unsupported icmp operation" + +def mkSelect {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) + (c : Ctxt.Var Γ (.bitvec 1)) (e₁ e₂ : Ctxt.Var Γ ty) : + MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + match ty with + | .bitvec w => + match op with + | .select w' => if h : w = w' + then return ⟨ + .select w', + by simp [OpSignature.outTy, signature, h], + .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , + .nil + ⟩ + else throw <| .widthError w w' + | _ => throw <| .unsupportedOp "Unsupported select operation" + +def mkOpExpr {Γ : Ctxt (MTy φ)} (op : MOp φ) + (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : + MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ (OpSignature.outTy op) := + match op with + | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ + | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ + | .srem _ | .urem _ => + let (e₁, e₂) := arg.toTuple + mkBinOp op e₁ e₂ + | .icmp _ _ => + let (e₁, e₂) := arg.toTuple + mkIcmp op e₁ e₂ + | .not _ | .neg _ | .copy _ => + mkUnaryOp op arg.head + | .select _ => + let (c, e₁, e₂) := arg.toTuple + mkSelect op c e₁ e₂ + | .const .. => throw <| .unsupportedOp "Tried to build (MOp φ) expression from constant" + +def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM (MOp φ) (MTy φ) + | MLIR.AST.MLIRType.int MLIR.AST.Signedness.Signless w => do + return .bitvec w + | _ => throw .unsupportedType -- "Unsupported type" + +instance instTransformTy : MLIR.AST.TransformTy (MOp φ) (MTy φ) φ where + mkTy := mkTy + +def mkExpr (Γ : Ctxt (MTy φ)) (opStx : MLIR.AST.Op φ) : AST.ReaderM (MOp φ) (Σ ty, Expr (MOp φ) Γ ty) := do + match opStx.args with + | v₁Stx::v₂Stx::v₃Stx::[] => + let ⟨.bitvec w₁, v₁⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₁Stx + let ⟨.bitvec w₂, v₂⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₂Stx + let ⟨.bitvec w₃, v₃⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₃Stx + match opStx.name with + | "llvm.select" => + if hw1 : w₁ = 1 then + if hw23 : w₂ = w₃ then + let selectOp ← mkSelect (MOp.select w₂) (hw1 ▸ v₁) v₂ (hw23 ▸ v₃) + return ⟨.bitvec w₂, selectOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + else throw <| .unsupportedOp s!"expected select condtion to have width 1, found width '{w₁}'" + | op => throw <| .unsupportedOp s!"Unsuported ternary operation or invalid arguments '{op}'" + | v₁Stx::v₂Stx::[] => + let ⟨.bitvec w₁, v₁⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₁Stx + let ⟨.bitvec w₂, v₂⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₂Stx + -- let ty₁ := ty₁.instantiave + let op ← match opStx.name with + | "llvm.and" => pure (MOp.and w₁) + | "llvm.or" => pure (MOp.or w₁) + | "llvm.xor" => pure (MOp.xor w₁) + | "llvm.shl" => pure (MOp.shl w₁) + | "llvm.lshr" => pure (MOp.lshr w₁) + | "llvm.ashr" => pure (MOp.ashr w₁) + | "llvm.urem" => pure (MOp.urem w₁) + | "llvm.srem" => pure (MOp.srem w₁) + | "llvm.add" => pure (MOp.add w₁) + | "llvm.mul" => pure (MOp.mul w₁) + | "llvm.sub" => pure (MOp.sub w₁) + | "llvm.sdiv" => pure (MOp.sdiv w₁) + | "llvm.udiv" => pure (MOp.udiv w₁) + | "llvm.icmp.eq" => pure (MOp.icmp LLVM.IntPredicate.eq w₁) + | "llvm.icmp.ne" => pure (MOp.icmp LLVM.IntPredicate.ne w₁) + | "llvm.icmp.ugt" => pure (MOp.icmp LLVM.IntPredicate.ugt w₁) + | "llvm.icmp.uge" => pure (MOp.icmp LLVM.IntPredicate.uge w₁) + | "llvm.icmp.ult" => pure (MOp.icmp LLVM.IntPredicate.ult w₁) + | "llvm.icmp.ule" => pure (MOp.icmp LLVM.IntPredicate.ule w₁) + | "llvm.icmp.sgt" => pure (MOp.icmp LLVM.IntPredicate.sgt w₁) + | "llvm.icmp.sge" => pure (MOp.icmp LLVM.IntPredicate.sge w₁) + | "llvm.icmp.slt" => pure (MOp.icmp LLVM.IntPredicate.slt w₁) + | "llvm.icmp.sle" => pure (MOp.icmp LLVM.IntPredicate.sle w₁) + | opstr => throw <| .unsupportedOp s!"Unsuported binary operation or invalid arguments '{opstr}'" + match op with + | .icmp .. => + if hty : w₁ = w₂ then + let icmpOp ← mkIcmp op v₁ (hty ▸ v₂) + return ⟨.bitvec 1, icmpOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + | _ => + if hty : w₁ = w₂ then + let binOp ← mkBinOp op v₁ (hty ▸ v₂) + return ⟨.bitvec w₁, binOp⟩ + else + throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" + | vStx::[] => + let ⟨.bitvec w, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx + let op ← match opStx.name with + | "llvm.not" => pure <| MOp.not w + | "llvm.neg" => pure <| MOp.neg w + | "llvm.copy" => pure <| MOp.copy w + | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" + let op ← mkUnaryOp op v + return ⟨.bitvec w, op⟩ + | [] => + if opStx.name == "llvm.mlir.constant" + then do + let some att := opStx.attrs.getAttr "value" + | throw <| .generic "tried to resolve constant without 'value' attribute" + match att with + | .int val ty => + let opTy@(MTy.bitvec w) ← mkTy ty -- ty.mkTy + return ⟨opTy, ⟨ + MOp.const w val, + by simp [OpSignature.outTy, signature, *], + HVector.nil, + HVector.nil + ⟩⟩ + | _ => throw <| .generic "invalid constant attribute" + else + throw <| .generic s!"invalid (0-ary) expression {opStx.name}" + | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" + +instance : AST.TransformExpr (MOp φ) (MTy φ) φ where + mkExpr := mkExpr + +def mkReturn (Γ : Ctxt (MTy φ)) (opStx : MLIR.AST.Op φ) : MLIR.AST.ReaderM (MOp φ) (Σ ty, Com (MOp φ) Γ ty) := + if opStx.name == "llvm.return" + then match opStx.args with + | vStx::[] => do + let ⟨ty, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx + return ⟨ty, _root_.Com.ret v⟩ + | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" + else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" + +instance : AST.TransformReturn (MOp φ) (MTy φ) φ where + mkReturn := mkReturn + + +/-! + ## Instantiation + Finally, we show how to instantiate a family of programs to a concrete program +-/ + +def instantiateMTy (vals : Vector Nat φ) : (MTy φ) → InstCombine.Ty + | .bitvec w => .bitvec <| .concrete <| w.instantiate vals + +def instantiateMOp (vals : Vector Nat φ) : MOp φ → InstCombine.Op + | .and w => .and (w.instantiate vals) + | .or w => .or (w.instantiate vals) + | .not w => .not (w.instantiate vals) + | .xor w => .xor (w.instantiate vals) + | .shl w => .shl (w.instantiate vals) + | .lshr w => .lshr (w.instantiate vals) + | .ashr w => .ashr (w.instantiate vals) + | .urem w => .urem (w.instantiate vals) + | .srem w => .srem (w.instantiate vals) + | .select w => .select (w.instantiate vals) + | .add w => .add (w.instantiate vals) + | .mul w => .mul (w.instantiate vals) + | .sub w => .sub (w.instantiate vals) + | .neg w => .neg (w.instantiate vals) + | .copy w => .copy (w.instantiate vals) + | .sdiv w => .sdiv (w.instantiate vals) + | .udiv w => .udiv (w.instantiate vals) + | .icmp c w => .icmp c (w.instantiate vals) + | .const w val => .const (w.instantiate vals) val + +def instantiateCtxt (vals : Vector Nat φ) (Γ : Ctxt (MTy φ)) : Ctxt InstCombine.Ty := + Γ.map (instantiateMTy vals) + +def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where + mapOp := instantiateMOp vals + mapTy := instantiateMTy vals + preserves_signature op := by + simp only [instantiateMTy, instantiateMOp, ConcreteOrMVar.instantiate, (· <$> ·), signature, + InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, + true_and] + cases op <;> simp only [List.map, and_self, List.cons.injEq] + + +open InstCombine (Op Ty) in +def mkComInstantiate (reg : MLIR.AST.Region φ) : + MLIR.AST.ExceptM (MOp φ) (Vector Nat φ → Σ (Γ : Ctxt InstCombine.Ty) (ty : InstCombine.Ty), Com InstCombine.Op Γ ty) := do + let ⟨Γ, ty, com⟩ ← MLIR.AST.mkCom reg + return fun vals => + ⟨instantiateCtxt vals Γ, instantiateMTy vals ty, com.map (MOp.instantiateCom vals)⟩ + +end InstcombineTransformDialect + /- https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/Cannot.20Find.20.60Real.2Eadd.60/near/402089561 @@ -61,9 +383,9 @@ elab "[alive_icom (" mvars:term,* ")| " reg:mlir_region "]" : term => do let ast ← elabTermEnsuringTypeQ ast_stx q(Region $φ) let mvalues ← `(⟨[$mvars,*], by rfl⟩) let mvalues : Q(Vector Nat $φ) ← elabTermEnsuringType mvalues q(Vector Nat $φ) - let com := q(mkComInstantiate (φ := $φ) $ast |>.map (· $mvalues)) + let com := q(InstcombineTransformDialect.mkComInstantiate (φ := $φ) $ast |>.map (· $mvalues)) synthesizeSyntheticMVarsNoPostponing - let com : Q(MLIR.AST.ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← + let com : Q(MLIR.AST.ExceptM (MOp $φ) (Σ (Γ' : Ctxt (MTy $φ)) (ty : (MTy $φ)), Com (MOp $φ) Γ' ty)) ← withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do withTransparency (mode := TransparencyMode.all) <| return ←reduce com diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index f13a0876a..57128964c 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -14,17 +14,6 @@ namespace MLIR.AST open InstCombine (MOp MTy Width) open Std (BitVec) --- abbrev Context (φ) := List (MTy φ) - --- abbrev Expr (Γ : Context φ) (ty : MTy φ) := _root_.Expr (MOp φ) Γ ty --- abbrev Com (Γ : Context φ) (ty : MTy φ) := _root_.Com (MOp φ) Γ ty --- abbrev Var (Γ : Context φ) (ty : MTy φ) := _root_.Ctxt.Var Γ ty - --- abbrev Com.lete (body : Expr Γ ty₁) (rest : Com (ty₁::Γ) ty₂) : Com Γ ty₂ := --- _root_.Com.lete body rest - --- namespace TransformError - instance : Lean.ToFormat (MTy φ) where format := repr @@ -41,10 +30,6 @@ abbrev ExceptM (Op) [OpSignature Op Ty] := Except (TransformError Ty) abbrev BuilderM (Op) [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) abbrev ReaderM (Op) [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) --- instance : MonadLift ReaderM BuilderM where - --- variable {Op Ty} [OpSignature Op Ty] - instance {Op : Type} [OpSignature Op Ty] : MonadLift (ReaderM Op) (BuilderM Op) where monadLift x := do (ReaderT.run x (←get) : ExceptM ..) @@ -56,61 +41,56 @@ def BuilderM.runWithNewMapping {Op : Type} [OpSignature Op Ty] (k : BuilderM Op end Monads -class TransformDialect (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) extends OpSignature Op Ty where - mkType : MLIRType φ → ExceptM Op Ty - mkReturn : (Γ : List Ty) → (opStx : AST.Op φ) → (args : List (Σ (ty : Ty), Ctxt.Var Γ ty)) - → ReaderM Op (Σ ty, _root_.Com Op Γ ty) - mkExpr : (Γ : List Ty) → (opStx : AST.Op φ) → (args : List (Σ (ty : Ty), Ctxt.Var Γ ty)) - → ReaderM Op (Σ ty, _root_.Expr Op Γ ty) - - - - -variable {Op Ty φ} [OpSignature Op Ty] -abbrev Context (Ty) := List (Ty) - --- abbrev Expr (Γ : Context Ty) (ty : Ty) := _root_.Expr Op Γ ty --- abbrev Com (Γ : Context Ty) (ty : Ty) := _root_.Com Op Γ ty --- abbrev Var (Γ : Context Ty) (ty : Ty) := Ctxt.Var Γ ty --- variable [d : TransformDialect Op Ty φ] [DecidableEq Ty] - --- abbrev Com.lete {Γ : Context Ty} (body : Expr Γ ty₁) (rest : Com Op (ty₁::Γ) ty₂) : Com Op Γ ty₂ := --- _root_.Com.lete body rest - +/-! + These typeclasses provide a natural flow to how users should implement `TransformDialect`. + - First declare how to transform types with `TransformTy`. + - Second, using `TransformTy`, declare how to transform expressions with `TransformExpr`. + - Third, using both type and expression conversion, declare how to transform returns with `TransformReturn`. + - These three automatically give an instance of `TransformDialect`. +-/ +class TransformTy (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) [OpSignature Op Ty] where + mkTy : MLIRType φ → ExceptM Op Ty +class TransformExpr (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) [OpSignature Op Ty] [TransformTy Op Ty φ] where + mkExpr : (Γ : List Ty) → (opStx : AST.Op φ) → ReaderM Op (Σ ty, Expr Op Γ ty) +class TransformReturn (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) + [OpSignature Op Ty] [TransformTy Op Ty φ] where + mkReturn : (Γ : List Ty) → (opStx : AST.Op φ) → ReaderM Op (Σ ty, Com Op Γ ty) +/- instance of the transform dialect, plus data needed about `Op` and `Ty`. -/ +variable {Op Ty φ} [OpSignature Op Ty] [DecidableEq Ty] [DecidableEq Op] -structure DerivedContext (Γ : Context Ty) where - ctxt : Context Ty +structure DerivedCtxt (Γ : Ctxt Ty) where + ctxt : Ctxt Ty diff : Ctxt.Diff Γ ctxt -namespace DerivedContext +namespace DerivedCtxt /-- Every context is trivially derived from itself -/ -abbrev ofContext (Γ : Context Ty) : DerivedContext Γ := ⟨Γ, .zero _⟩ +abbrev ofCtxt (Γ : Ctxt Ty) : DerivedCtxt Γ := ⟨Γ, .zero _⟩ /-- `snoc` of a derived context applies `snoc` to the underlying context, and updates the diff -/ -def snoc {Γ : Context Ty} : DerivedContext Γ → Ty → DerivedContext Γ +def snoc {Γ : Ctxt Ty} : DerivedCtxt Γ → Ty → DerivedCtxt Γ | ⟨ctxt, diff⟩, ty => ⟨ty::ctxt, diff.toSnoc⟩ -instance {Γ : Context Ty} : CoeHead (DerivedContext Γ) (Context Ty) where +instance {Γ : Ctxt Ty} : CoeHead (DerivedCtxt Γ) (Ctxt Ty) where coe := fun ⟨Γ', _⟩ => Γ' -instance {Γ : Context Ty} : CoeDep (Context Ty) Γ (DerivedContext Γ) where +instance {Γ : Ctxt Ty} : CoeDep (Ctxt Ty) Γ (DerivedCtxt Γ) where coe := ⟨Γ, .zero _⟩ -instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : - CoeHead (DerivedContext (Γ' : Context Ty)) (DerivedContext Γ) where +instance {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : + CoeHead (DerivedCtxt (Γ' : Ctxt Ty)) (DerivedCtxt Γ) where coe := fun ⟨Γ'', diff⟩ => ⟨Γ'', Γ'.diff + diff⟩ -instance {Γ : Context Ty} {Γ' : DerivedContext Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where +instance {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where coe e := e.changeVars Γ'.diff.toHom -instance {Γ' : DerivedContext Γ} : Coe (Ctxt.Var Γ t) (Ctxt.Var (Γ' : Context Ty) t) where +instance {Γ' : DerivedCtxt Γ} : Coe (Ctxt.Var Γ t) (Ctxt.Var (Γ' : Ctxt Ty) t) where coe v := Γ'.diff.toHom v -end DerivedContext +end DerivedCtxt /-- Add a new variable to the context, and record it's (absolute) index in the name mapping @@ -118,12 +98,12 @@ end DerivedContext Throws an error if the variable name already exists in the mapping, essentially disallowing shadowing -/ -def addValToMapping (Γ : Context (MTy φ)) (name : String) (ty : MTy φ) : - BuilderM (MOp φ) (Σ (Γ' : DerivedContext Γ), Ctxt.Var Γ'.ctxt ty) := do +def addValToMapping (Γ : Ctxt Ty) (name : String) (ty : Ty) : + BuilderM Op (Σ (Γ' : DerivedCtxt Γ), Ctxt.Var Γ'.ctxt ty) := do let some nm := (←get).add name | throw <| .nameAlreadyDeclared name set nm - return ⟨DerivedContext.ofContext Γ |>.snoc ty, Ctxt.Var.last ..⟩ + return ⟨DerivedCtxt.ofCtxt Γ |>.snoc ty, Ctxt.Var.last ..⟩ /-- Look up a name from the name mapping, and return the corresponding variable in the given context. @@ -131,8 +111,8 @@ def addValToMapping (Γ : Context (MTy φ)) (name : String) (ty : MTy φ) : Throws an error if the name is not present in the mapping (this indicates the name may be free), or if the type of the variable in the context is different from `expectedType` -/ -def getValFromContext (Γ : Context (MTy φ)) (name : String) (expectedType : MTy φ) : - ReaderM (MOp φ) (Ctxt.Var Γ expectedType) := do +def getValFromCtxt (Γ : Ctxt Ty) (name : String) (expectedType : Ty) : + ReaderM Op (Ctxt.Var Γ expectedType) := do let index := (←read).lookup name let some index := index | throw <| .undeclaredName name let n := Γ.length @@ -158,8 +138,9 @@ def BuilderM.isErr {α : Type} (x : BuilderM Op α) : Bool := | Except.error _ => false #check Ctxt.Var -def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) - (e : Ctxt.Var Γ ty) : ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := +/- +def mkUnaryOp {Γ : Ctxt Ty} {ty : Ty} (op : MOp φ) + (e : Ctxt.Var Γ ty) : ExceptM Op <| Expr Op Γ ty := match ty with | .bitvec w => match op with @@ -190,8 +171,8 @@ def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedUnaryOp -def mkBinOp {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := +def mkBinOp {Γ : Ctxt Ty} {ty : Ty} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM Op <| Expr Op Γ ty := match ty with | .bitvec w => match op with @@ -302,8 +283,8 @@ def mkBinOp {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | op => throw <| .unsupportedBinaryOp s!"unsupported binary operation {op}" -def mkIcmp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM (MOp φ) <| Expr (MOp φ) Γ (.bitvec 1) := +def mkIcmp {Γ : Ctxt _} {ty : Ty} (op : MOp φ) + (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM Op <| Expr Op Γ (.bitvec 1) := match ty with | .bitvec w => match op with @@ -317,9 +298,9 @@ def mkIcmp {Γ : Ctxt _} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedOp "unsupported icmp operation" -def mkSelect {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) +def mkSelect {Γ : Ctxt Ty} {ty : Ty} (op : MOp φ) (c : Ctxt.Var Γ (.bitvec 1)) (e₁ e₂ : Ctxt.Var Γ ty) : - ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := + ExceptM Op <| Expr Op Γ ty := match ty with | .bitvec w => match op with @@ -333,9 +314,9 @@ def mkSelect {Γ : Ctxt (MTy φ)} {ty : MTy φ} (op : MOp φ) else throw <| .widthError w w' | _ => throw <| .unsupportedOp "Unsupported select operation" -def mkOpExpr {Γ : Context (MTy φ)} (op : MOp φ) +def mkOpExpr {Γ : Ctxt Ty} (op : MOp φ) (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : - ExceptM (MOp φ) <| Expr (MOp φ) Γ (OpSignature.outTy op) := + ExceptM Op <| Expr Op Γ (OpSignature.outTy op) := match op with | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ @@ -352,36 +333,49 @@ def mkOpExpr {Γ : Context (MTy φ)} (op : MOp φ) mkSelect op c e₁ e₂ | .const .. => throw <| .unsupportedOp "Tried to build Op expression from constant" -def MLIRType.mkTy : MLIRType φ → ExceptM (MOp φ) (MTy φ) +def MLIRType.mkTy : MLIRType φ → ExceptM Op Ty | MLIRType.int Signedness.Signless w => do return .bitvec w | _ => throw .unsupportedType -- "Unsupported type" +-/ -def TypedSSAVal.mkTy : TypedSSAVal φ → ExceptM (MOp φ) (MTy φ) - | (.SSAVal _, ty) => MLIRType.mkTy ty +def TypedSSAVal.mkTy [TransformTy Op Ty φ] : TypedSSAVal φ → ExceptM Op Ty + | (.SSAVal _, ty) => TransformTy.mkTy ty def mkVal (ty : InstCombine.Ty) : Int → BitVec ty.width | val => BitVec.ofInt ty.width val /-- Translate a `TypedSSAVal` (a name with an expected type), to a variable in the context. This expects the name to have already been declared before -/ -def TypedSSAVal.mkVal (Γ : Context (MTy φ)) : TypedSSAVal φ → - ReaderM (MOp φ) (Σ (ty : MTy φ), Ctxt.Var Γ ty) +def TypedSSAVal.mkVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → + ReaderM Op (Σ (ty : Ty), Ctxt.Var Γ ty) +| (.SSAVal valStx, tyStx) => do + let ty ← instTransformTy.mkTy tyStx + let var ← getValFromCtxt Γ valStx ty + return ⟨ty, var⟩ + +/-- A variant of `TypedSSAVal.mkVal` that takes the function `mkTy` as an argument + instead of using the typeclass `TransformDialect`. + This is useful when trying to implement an instance of `TransformDialect` itself, + to cut infinite regress. -/ +def TypedSSAVal.mkVal' [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → + ReaderM Op (Σ (ty : Ty), Ctxt.Var Γ ty) | (.SSAVal valStx, tyStx) => do - let ty ← MLIRType.mkTy tyStx - let var ← getValFromContext Γ valStx ty + let ty ← instTransformTy.mkTy tyStx + let var ← getValFromCtxt Γ valStx ty return ⟨ty, var⟩ /-- Declare a new variable, by adding the passed name to the name mapping stored in the monad state -/ -def TypedSSAVal.newVal (Γ : Context (MTy φ)) : TypedSSAVal φ → - BuilderM (MOp φ) (Σ (Γ' : DerivedContext Γ) (ty : MTy φ), Ctxt.Var Γ'.ctxt ty) +def TypedSSAVal.newVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → + BuilderM Op (Σ (Γ' : DerivedCtxt Γ) (ty : Ty), Ctxt.Var Γ'.ctxt ty) | (.SSAVal valStx, tyStx) => do - let ty ← MLIRType.mkTy tyStx + let ty ← instTransformTy.mkTy tyStx let ⟨Γ, var⟩ ← addValToMapping Γ valStx ty return ⟨Γ, ty, var⟩ -def mkExpr (Γ : Context (MTy φ)) (opStx : MLIR.AST.Op φ) : ReaderM (MOp φ) (Σ ty, Expr (MOp φ) Γ ty) := do +/- +def mkExpr (Γ : Ctxt Ty) (opStx : MLIR.AST.Op φ) : ReaderM Op (Σ ty, Expr Op Γ ty) := do match opStx.args with | v₁Stx::v₂Stx::v₃Stx::[] => let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx @@ -466,8 +460,10 @@ def mkExpr (Γ : Context (MTy φ)) (opStx : MLIR.AST.Op φ) : ReaderM (MOp φ) ( else throw <| .generic s!"invalid (0-ary) expression {opStx.name}" | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" +-/ -def mkReturn (Γ : Context (MTy φ)) (opStx : MLIR.AST.Op φ) : ReaderM (MOp φ) (Σ ty, Com (MOp φ) Γ ty) := +/- +def mkReturn (Γ : Ctxt Ty) (opStx : MLIR.AST.Op φ) : ReaderM Op (Σ ty, Com Op Γ ty) := if opStx.name == "llvm.return" then match opStx.args with | vStx::[] => do @@ -475,21 +471,25 @@ def mkReturn (Γ : Context (MTy φ)) (opStx : MLIR.AST.Op φ) : ReaderM (MOp φ) return ⟨ty, _root_.Com.ret v⟩ | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" +-/ /-- Given a list of `TypedSSAVal`s, treat each as a binder and declare a new variable with the given name and type -/ -private def declareBindings (Γ : Context (MTy φ)) (vals : List (TypedSSAVal φ)) : - BuilderM (MOp φ) (DerivedContext Γ) := do +private def declareBindings [TransformTy Op Ty φ] (Γ : Ctxt Ty) (vals : List (TypedSSAVal φ)) : + BuilderM Op (DerivedCtxt Γ) := do vals.foldlM (fun Γ' ssaVal => do let ⟨Γ'', _⟩ ← TypedSSAVal.newVal Γ'.ctxt ssaVal return Γ'' - ) (.ofContext Γ) - -private def mkComHelper (Γ : Context (MTy φ)) : - List (MLIR.AST.Op φ) → BuilderM (MOp φ) (Σ (ty : _), Com (MOp φ) Γ ty) - | [retStx] => mkReturn Γ retStx + ) (.ofCtxt Γ) + +private def mkComHelper + [TransformTy Op Ty φ] [instTransformExpr : TransformExpr Op Ty φ] [instTransformReturn : TransformReturn Op Ty φ] + (Γ : Ctxt Ty) : + List (MLIR.AST.Op φ) → BuilderM Op (Σ (ty : _), Com Op Γ ty) + | [retStx] => do + instTransformReturn.mkReturn Γ retStx | lete::rest => do - let ⟨ty₁, expr⟩ ← (mkExpr Γ lete) + let ⟨ty₁, expr⟩ ← (instTransformExpr.mkExpr Γ lete) if h : lete.res.length != 1 then throw <| .generic s!"Each let-binding must have exactly one name on the left-hand side. Operations with multiple, or no, results are not yet supported.\n\tExpected a list of length one, found `{repr lete}`" else @@ -498,7 +498,8 @@ private def mkComHelper (Γ : Context (MTy φ)) : return ⟨ty₂, Com.lete expr body⟩ | [] => throw <| .generic "Ill-formed (empty) block" -def mkCom (reg : MLIR.AST.Region φ) : ExceptM (MOp φ) (Σ (Γ : Context (MTy φ)) (ty : MTy φ), Com (MOp φ) Γ ty) := +def mkCom [TransformTy Op Ty φ] [TransformExpr Op Ty φ] [TransformReturn Op Ty φ] + (reg : MLIR.AST.Region φ) : ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := match reg.ops with | [] => throw <| .generic "Ill-formed region (empty)" | coms => BuilderM.runWithNewMapping <| do @@ -511,7 +512,8 @@ def mkCom (reg : MLIR.AST.Region φ) : ExceptM (MOp φ) (Σ (Γ : Context (MTy Finally, we show how to instantiate a family of programs to a concrete program -/ -def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : MTy φ → InstCombine.Ty +/- +def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : Ty → InstCombine.Ty | .bitvec w => .bitvec <| .concrete <| w.instantiate vals def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op @@ -535,10 +537,10 @@ def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstC | .icmp c w => .icmp c (w.instantiate vals) | .const w val => .const (w.instantiate vals) val -def Context.instantiate (vals : Vector Nat φ) (Γ : Context (MTy φ)) : Ctxt InstCombine.Ty := +def Ctxt.instantiate (vals : Vector Nat φ) (Γ : Ctxt Ty) : Ctxt InstCombine.Ty := Γ.map (MTy.instantiate vals) -def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCombine.Op) where +def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism Op (InstCombine.Op) where mapOp := MOp.instantiate vals mapTy := MTy.instantiate vals preserves_signature op := by @@ -550,9 +552,9 @@ def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism (MOp φ) (InstCo open InstCombine (Op Ty) in def mkComInstantiate (reg : Region φ) : - ExceptM (MOp φ) (Vector Nat φ → Σ (Γ : Ctxt InstCombine.Ty) (ty : InstCombine.Ty), Com InstCombine.Op Γ ty) := do + ExceptM Op (Vector Nat φ → Σ (Γ : Ctxt InstCombine.Ty) (ty : InstCombine.Ty), Com InstCombine.Op Γ ty) := do let ⟨Γ, ty, com⟩ ← mkCom reg return fun vals => ⟨Γ.instantiate vals, ty.instantiate vals, com.map (MOp.instantiateCom vals)⟩ - +-/ end MLIR.AST diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean b/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean deleted file mode 100644 index 65aa180d5..000000000 --- a/SSA/Projects/InstCombine/LLVM/Transform/Dialects/InstCombine.lean +++ /dev/null @@ -1,323 +0,0 @@ -import SSA.Projects.InstCombine.LLVM.Transform -import SSA.Projects.InstCombine.LLVM.Transform.Instantiate - -namespace InstCombine - -open MLIR -open AST (TypedSSAVal) -open Ctxt (Var) - -protected abbrev ExceptM (φ) := AST.ExceptM (MOp φ) -protected abbrev ReaderM (φ) := AST.ReaderM (MOp φ) -protected abbrev BuilderM (φ) := AST.BuilderM (MOp φ) - -protected abbrev Context (φ) := List (MTy φ) -protected abbrev Expr {φ} := Expr (MOp φ) - -open InstCombine (ExceptM ReaderM BuilderM Context Expr) - -def mkType : AST.MLIRType φ → ExceptM φ (MTy φ) - | .int .Signless w => return .bitvec w - | _ => throw .unsupportedType -- "Unsupported type" - -def mkUnaryOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) - (e : Var Γ ty) : ExceptM φ <| Expr Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .neg w' => if h : w = w' - then return ⟨ - .neg w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .not w' => if h : w = w' - then return ⟨ - .not w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .copy w' => if h : w = w' - then return ⟨ - .copy w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedUnaryOp - --- def mkBinOp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) --- (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ ty := --- match ty with --- | .bitvec w => --- match op with --- -- Can't use a single arm, Lean won't write the rhs accordingly --- | .add w' => if h : w = w' --- then return ⟨ --- .add w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .and w' => if h : w = w' --- then return ⟨ --- .and w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .or w' => if h : w = w' --- then return ⟨ --- .or w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .xor w' => if h : w = w' --- then return ⟨ --- .xor w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .shl w' => if h : w = w' --- then return ⟨ --- .shl w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .lshr w' => if h : w = w' --- then return ⟨ --- .lshr w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .ashr w' => if h : w = w' --- then return ⟨ --- .ashr w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .urem w' => if h : w = w' --- then return ⟨ --- .urem w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .srem w' => if h : w = w' --- then return ⟨ --- .srem w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .mul w' => if h : w = w' --- then return ⟨ --- .mul w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .sub w' => if h : w = w' --- then return ⟨ --- .sub w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .sdiv w' => if h : w = w' --- then return ⟨ --- .sdiv w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | .udiv w' => if h : w = w' --- then return ⟨ --- .udiv w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | _ => throw <| .unsupportedBinaryOp - --- def mkIcmp {Γ : Context φ} {ty : MTy φ} (op : MOp φ) --- (e₁ e₂ : Var Γ ty) : ExceptM φ <| Expr Γ (.bitvec 1) := --- match ty with --- | .bitvec w => --- match op with --- | .icmp p w' => if h : w = w' --- then return ⟨ --- .icmp p w', --- by simp [OpSignature.outTy, signature, h], --- .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | _ => throw .unsupportedOp -- unsupported icmp operation - --- def mkSelect {Γ : Context φ} {ty : MTy φ} (op : MOp φ) --- (c : Var Γ (.bitvec 1)) (e₁ e₂ : Var Γ ty) : --- ExceptM φ <| Expr Γ ty := --- match ty with --- | .bitvec w => --- match op with --- | .select w' => if h : w = w' --- then return ⟨ --- .select w', --- by simp [OpSignature.outTy, signature, h], --- .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , --- .nil --- ⟩ --- else throw <| .widthError w w' --- | _ => throw .unsupportedOp -- "Unsupported select operation" - --- def mkOpExpr {Γ : Context φ} (op : MOp φ) --- (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : --- ExceptM φ <| Expr Γ (OpSignature.outTy op) := --- match op with --- | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ --- | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ --- | .srem _ | .urem _ => --- let (e₁, e₂) := arg.toTuple --- mkBinOp op e₁ e₂ --- | .icmp _ _ => --- let (e₁, e₂) := arg.toTuple --- mkIcmp op e₁ e₂ --- | .not _ | .neg _ | .copy _ => --- mkUnaryOp op arg.head --- | .select _ => --- let (c, e₁, e₂) := arg.toTuple --- mkSelect op c e₁ e₂ --- | .const .. => throw .unsupportedOp -- "Tried to build Op expression from constant" - --- def mkExpr (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : --- ReaderM φ (Σ ty, Expr Γ ty) := do --- match args with --- | ⟨.bitvec w₁, v₁⟩::⟨.bitvec w₂, v₂⟩::[] => --- -- let ty₁ := ty₁.instantiave --- let op ← match opStx.name with --- | "llvm.and" => pure (MOp.and w₁) --- | "llvm.or" => pure (MOp.or w₁) --- | "llvm.xor" => pure (MOp.xor w₁) --- | "llvm.shl" => pure (MOp.shl w₁) --- | "llvm.lshr" => pure (MOp.lshr w₁) --- | "llvm.ashr" => pure (MOp.ashr w₁) --- | "llvm.urem" => pure (MOp.urem w₁) --- | "llvm.srem" => pure (MOp.srem w₁) --- | "llvm.select" => pure (MOp.select w₁) --- | "llvm.add" => pure (MOp.add w₁) --- | "llvm.mul" => pure (MOp.mul w₁) --- | "llvm.sub" => pure (MOp.sub w₁) --- | "llvm.sdiv" => pure (MOp.sdiv w₁) --- | "llvm.udiv" => pure (MOp.udiv w₁) --- --| "llvm.icmp" => return InstCombine.Op.icmp v₁.width --- | _ => throw .unsupportedOp -- "Unsuported operation or invalid arguments" --- if hty : w₁ = w₂ then --- let binOp ← (mkBinOp op v₁ (hty ▸ v₂) : ExceptM ..) --- return ⟨.bitvec w₁, binOp⟩ --- else --- throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" --- | ⟨.bitvec w, v⟩::[] => --- let op ← match opStx.name with --- | "llvm.not" => pure <| MOp.not w --- | "llvm.neg" => pure <| MOp.neg w --- | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" --- let op ← mkUnaryOp op v --- return ⟨.bitvec w, op⟩ --- | [] => --- if opStx.name == "llvm.mlir.constant" --- then do --- let some att := opStx.attrs.getAttr "value" --- | throw <| .generic "tried to resolve constant without 'value' attribute" --- match att with --- | .int val ty => --- let opTy@(MTy.bitvec w) ← mkType ty --- return ⟨opTy, ⟨ --- MOp.const w val, --- by simp [OpSignature.outTy, signature, *], --- HVector.nil, --- HVector.nil --- ⟩⟩ --- | _ => throw <| .generic "invalid constant attribute" --- else --- throw <| .generic s!"invalid (0-ary) expression {opStx.name}" --- | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" - --- def mkReturn (Γ : Context φ) (opStx : AST.Op φ) (args : List (Σ (ty : MTy φ), Ctxt.Var Γ ty)) : --- ReaderM φ (Σ ty, ICom (MOp φ) Γ ty) := --- if opStx.name == "llvm.return" --- then match args with --- | ⟨ty, v⟩::[] => do --- return ⟨ty, ICom.ret v⟩ --- | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {args.length})" --- else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" - --- /-! ## Instantiation -/ - - --- def MTy.instantiate (vals : Vector Nat φ) : (MTy φ) → Ty --- | .bitvec w => .bitvec <| .concrete <| w.instantiate vals - --- def MOp.instantiate (vals : Vector Nat φ) : MOp φ → Op --- | .and w => .and (w.instantiate vals) --- | .or w => .or (w.instantiate vals) --- | .not w => .not (w.instantiate vals) --- | .xor w => .xor (w.instantiate vals) --- | .shl w => .shl (w.instantiate vals) --- | .lshr w => .lshr (w.instantiate vals) --- | .ashr w => .ashr (w.instantiate vals) --- | .urem w => .urem (w.instantiate vals) --- | .srem w => .srem (w.instantiate vals) --- | .select w => .select (w.instantiate vals) --- | .add w => .add (w.instantiate vals) --- | .mul w => .mul (w.instantiate vals) --- | .sub w => .sub (w.instantiate vals) --- | .neg w => .neg (w.instantiate vals) --- | .copy w => .copy (w.instantiate vals) --- | .sdiv w => .sdiv (w.instantiate vals) --- | .udiv w => .udiv (w.instantiate vals) --- | .icmp c w => .icmp c (w.instantiate vals) --- | .const w val => .const (w.instantiate vals) val - --- /-! ## Instances -/ - --- instance : AST.TransformDialect (MOp φ) (MTy φ) φ where --- mkType := mkType --- mkReturn := mkReturn --- mkExpr := mkExpr - --- instance : TransformDialectInstantiate Op φ Ty (MOp φ) (MTy φ) where --- morphism vals := { --- mapOp := MOp.instantiate vals, --- mapTy := MTy.instantiate vals, --- preserves_signature := by --- intro op --- simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, --- InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, --- true_and] --- cases op <;> simp only [List.map, and_self, List.cons.injEq] --- } diff --git a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean b/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean deleted file mode 100644 index adff71241..000000000 --- a/SSA/Projects/InstCombine/LLVM/Transform/Instantiate.lean +++ /dev/null @@ -1,23 +0,0 @@ -import SSA.Core.Framework -import SSA.Projects.InstCombine.LLVM.Transform - -namespace MLIR - -/-! - ## Instantiation - Finally, we show how to instantiate a family of programs to a concrete program --/ - -class TransformDialectInstantiate (Op : Type) (φ : Nat) (Ty MOp MTy : outParam Type) - [OpSignature Op Ty] [AST.TransformDialect MOp MTy φ] where - morphism : Vector Nat φ → DialectMorphism MOp Op - - -set_option linter.unusedVariables false -- linter gives a false positive for `φ` in `[∀ φ, ...]` - -variable (Op) (φ) {Ty} {MOp MTy} [OpSignature Op Ty] - [AST.TransformDialect MOp MTy φ] - [instInst : TransformDialectInstantiate Op φ Ty MOp MTy] - [DecidableEq MTy] - -end MLIR diff --git a/SSA/Projects/InstCombine/Tactic.lean b/SSA/Projects/InstCombine/Tactic.lean index 53093456b..6c8df331b 100644 --- a/SSA/Projects/InstCombine/Tactic.lean +++ b/SSA/Projects/InstCombine/Tactic.lean @@ -44,7 +44,8 @@ macro "simp_alive_peephole" : tactic => LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, LLVM.const?, LLVM.icmp?, - HVector.toTuple, List.nthLe, bitvec_minus_one] + HVector.toTuple, List.nthLe, bitvec_minus_one, + InstcombineTransformDialect.MOp.instantiateCom] try intros v0 try intros v1 try intros v2 From ad32b0646a8f353cf973db0d380cf681f53823f6 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 06:44:30 +0000 Subject: [PATCH 09/28] wip: try to reduce simp set --- .../InstCombine/AliveAutoGenerated.lean | 90 ++++++++++++++++++- SSA/Projects/InstCombine/Tactic.lean | 15 +++- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index 88e59ff74..4d3d6c2c2 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -12,6 +12,46 @@ namespace AliveAutoGenerated set_option pp.proofs false set_option pp.proofs.withType false +def eg0 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + "llvm.return" (%C0) : (_) -> () +}] + +def eg1 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + "llvm.return" (%C0) : (_) -> () +}] + +theorem eg0_eq_eg1 : eg0 w ⊑ eg1 w := by + unfold eg0 + unfold eg1 + simp[DialectMorphism.mapTy, + InstcombineTransformDialect.MOp.instantiateCom, + InstcombineTransformDialect.instantiateMTy, + ConcreteOrMVar.instantiate, Com.Refinement] + +def eg2 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + %v1 = "llvm.add" (%C0, %C0) : (_, _) -> (_) + "llvm.return" (%C0) : (_) -> () +}] + +def eg3 (w : Nat) := +[alive_icom ( w )| { +^bb0(%C0 : _): + %v1 = "llvm.mlir.constant" () { value = 2 : _ } :() -> (_) + %v2 = "llvm.mul" (%C0, %v1) : (_, _) -> (_) + "llvm.return" (%C0) : (_) -> () +}] + +theorem eg2_eq_eg3 : eg2 w ⊑ eg3 w := by + unfold eg2 + unfold eg3 + simp_alive_peephole + -- Name:AddSub:1043 -- precondition: true /- @@ -53,9 +93,55 @@ def alive_AddSub_1043_tgt (w : Nat) := }] theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddSub_1043_tgt w := by unfold alive_AddSub_1043_src alive_AddSub_1043_tgt - simp_alive_peephole - apply bitvec_AddSub_1043 + dsimp only [Com.Refinement] + intros Γv + simp[OpDenote.denote, + InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, + bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, + Ctxt.snoc, ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, + LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, + LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, + LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, + LLVM.const?, LLVM.icmp?, + DerivedCtxt.ofCtxt, InstcombineTransformDialect.MOp.instantiateCom, InstcombineTransformDialect.instantiateMTy, + List.map, + HVector.toTuple, List.nthLe, bitvec_minus_one, + DialectMorphism.mapTy, + InstcombineTransformDialect.instantiateMTy, + InstcombineTransformDialect.instantiateMOp, + InstcombineTransformDialect.MOp.instantiateCom, + InstcombineTransformDialect.instantiateCtxt, + ConcreteOrMVar.instantiate, Com.Refinement, + InstCombine.MOp.add, + InstCombine.MOp.const, + InstCombine.MOp.xor, + InstCombine.MOp.and] at Γv + try simp [OpDenote.denote, + InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, + bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, + Ctxt.snoc, + ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, + LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, + LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, + LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, + LLVM.const?, LLVM.icmp?, + HVector.toTuple, List.nthLe, bitvec_minus_one, + DialectMorphism.mapTy, + InstcombineTransformDialect.instantiateMTy, + InstcombineTransformDialect.instantiateMOp, + InstcombineTransformDialect.MOp.instantiateCom, + InstcombineTransformDialect.instantiateCtxt, + ConcreteOrMVar.instantiate, Com.Refinement, + Com.denote, Expr.denote, + Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, + Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, + Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, + HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk] + -- simp only [Var.last, Var.toSnoc] + + -- apply bitvec_AddSub_1043 +#exit /-# Early Exit delete this to check the rest of the file. The `#exit` diff --git a/SSA/Projects/InstCombine/Tactic.lean b/SSA/Projects/InstCombine/Tactic.lean index 6c8df331b..fe699ab0f 100644 --- a/SSA/Projects/InstCombine/Tactic.lean +++ b/SSA/Projects/InstCombine/Tactic.lean @@ -32,20 +32,29 @@ macro "simp_alive_peephole" : tactic => ( dsimp only [Com.Refinement] intros Γv + simp [InstcombineTransformDialect.MOp.instantiateCom, InstcombineTransformDialect.instantiateMOp, + ConcreteOrMVar.instantiate, Vector.get, List.nthLe, List.length_singleton, Fin.coe_fin_one, Fin.zero_eta, + List.get_cons_zero, Function.comp_apply, InstcombineTransformDialect.instantiateMTy, Ctxt.empty_eq, DerivedCtxt.snoc, + DerivedCtxt.ofCtxt, List.map_eq_map, List.map] at Γv simp_peephole at Γv /- note that we need the `HVector.toPair`, `HVector.toSingle`, `HVector.toTriple` lemmas since it's used in `InstCombine.Op.denote` We need `HVector.toTuple` since it's used in `MLIR.AST.mkOpExpr`. -/ try simp (config := {decide := false}) only [OpDenote.denote, InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, - bind, Option.bind, pure, DerivedContext.ofContext, DerivedContext.snoc, - Ctxt.snoc, MOp.instantiateCom, InstCombine.MTy.instantiate, + bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, + Ctxt.snoc, ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, LLVM.const?, LLVM.icmp?, HVector.toTuple, List.nthLe, bitvec_minus_one, - InstcombineTransformDialect.MOp.instantiateCom] + DialectMorphism.mapTy, + InstcombineTransformDialect.instantiateMTy, + InstcombineTransformDialect.instantiateMOp, + InstcombineTransformDialect.MOp.instantiateCom, + InstcombineTransformDialect.instantiateCtxt, + ConcreteOrMVar.instantiate, Com.Refinement] try intros v0 try intros v1 try intros v2 From 96c6c31e70003b37458a7b5aa5e0b48658f0ac1d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 07:21:30 +0000 Subject: [PATCH 10/28] feat: add MLIR syntax for PaperExamples IR --- .../InstCombine/AliveAutoGenerated.lean | 3 +- SSA/Projects/PaperExamples/PaperExamples.lean | 92 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index 4d3d6c2c2..56274ac97 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -50,7 +50,7 @@ def eg3 (w : Nat) := theorem eg2_eq_eg3 : eg2 w ⊑ eg3 w := by unfold eg2 unfold eg3 - simp_alive_peephole + sorry /- We should get this example simplifying first. -/ -- Name:AddSub:1043 -- precondition: true @@ -138,6 +138,7 @@ theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddS Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk] -- simp only [Var.last, Var.toSnoc] + sorry -- apply bitvec_AddSub_1043 diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index ca9694912..f11eaeb4c 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -1,6 +1,11 @@ +import Qq +import Lean import Mathlib.Logic.Function.Iterate import SSA.Core.Framework import SSA.Core.Util +import SSA.Projects.InstCombine.LLVM.Transform +import SSA.Projects.MLIRSyntax.AST +import SSA.Projects.MLIRSyntax.EDSL set_option pp.proofs false set_option pp.proofs.withType false @@ -51,6 +56,79 @@ def add {Γ : Ctxt _} (e₁ e₂ : Var Γ .int) : Expr Op Γ .int := attribute [local simp] Ctxt.snoc +namespace MLIR2Simple + +def mkTy : MLIR.AST.MLIRType φ → MLIR.AST.ExceptM Op Ty + | MLIR.AST.MLIRType.int MLIR.AST.Signedness.Signless w => do + return .int + | _ => throw .unsupportedType + +instance instTransformTy : MLIR.AST.TransformTy Op Ty 0 where + mkTy := mkTy + +def mkExpr (Γ : Ctxt Ty) (opStx : MLIR.AST.Op 0) : MLIR.AST.ReaderM Op (Σ ty, Expr Op Γ ty) := do + match opStx.name with + | "const" => + match opStx.attrs.find_int "value" with + | .some (v, _ty) => return ⟨.int, cst v⟩ + | .none => throw <| .generic s!"expected 'const' to have int attr 'value', found: {repr opStx}" + | "add" => + match opStx.args with + | v₁Stx::v₂Stx::[] => + let ⟨.int, v₁⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₁Stx + let ⟨.int, v₂⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ v₂Stx + return ⟨.int, add v₁ v₂⟩ + | _ => throw <| .generic s!"expected two operands for `add`, found #'{opStx.args.length}' in '{repr opStx.args}'" + | _ => throw <| .unsupportedOp s!"unsupported operation {repr opStx}" + +instance : MLIR.AST.TransformExpr Op Ty 0 where + mkExpr := mkExpr + +def mkReturn (Γ : Ctxt Ty) (opStx : MLIR.AST.Op 0) : MLIR.AST.ReaderM Op (Σ ty, Com Op Γ ty) := + if opStx.name == "return" + then match opStx.args with + | vStx::[] => do + let ⟨ty, v⟩ ← MLIR.AST.TypedSSAVal.mkVal Γ vStx + return ⟨ty, Com.ret v⟩ + | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" + else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" + +instance : MLIR.AST.TransformReturn Op Ty 0 where + mkReturn := mkReturn + +open InstCombine (Op Ty) in + +def mlir2simple (reg : MLIR.AST.Region 0) : + MLIR.AST.ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := MLIR.AST.mkCom reg + +open Qq MLIR AST Lean Elab Term Meta in +elab "[toy_icom| " reg:mlir_region "]" : term => do + let ast_stx ← `([mlir_region| $reg]) + let ast ← elabTermEnsuringTypeQ ast_stx q(Region 0) + let mvalues ← `(⟨[], by rfl⟩) + -- let mvalues : Q(Vector Nat 0) ← elabTermEnsuringType mvalues q(Vector Nat 0) + let com := q(ToyNoRegion.MLIR2Simple.mlir2simple $ast) + synthesizeSyntheticMVarsNoPostponing + let com : Q(MLIR.AST.ExceptM Op (Σ (Γ' : Ctxt Ty) (ty : Ty), Com Op Γ' ty)) ← + withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do + withTransparency (mode := TransparencyMode.all) <| + return ←reduce com + let comExpr : Expr := com + trace[Meta] com + trace[Meta] comExpr + match comExpr.app3? ``Except.ok with + | .some (_εexpr, _αexpr, aexpr) => + match aexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + match sndexpr.app4? ``Sigma.mk with + | .some (_αexpr, _βexpr, _fstexpr, sndexpr) => + return sndexpr + | .none => throwError "Found `Except.ok (Sigma.mk _ WRONG)`, Expected (Except.ok (Sigma.mk _ (Sigma.mk _ _))" + | .none => throwError "Found `Except.ok WRONG`, Expected (Except.ok (Sigma.mk _ _))" + | .none => throwError "Expected `Except.ok`, found {comExpr}" + +end MLIR2Simple + /-- x + 0 -/ def lhs : Com Op (Ctxt.ofList [.int]) .int := -- %c0 = 0 @@ -60,6 +138,20 @@ def lhs : Com Op (Ctxt.ofList [.int]) .int := -- return %out Com.ret ⟨0, by simp [Ctxt.snoc]⟩ +open MLIR AST MLIR2Simple in +/-- Same code, written in MLIR syntax. -/ +def lhs_stx : Com Op (Ctxt.ofList [.int]) .int := + [toy_icom| { + ^bb0(%x : i32): + %c0 = "const" () { value = 0 : i32 } : () -> i32 + %out = "add" (%x, %c0) : (i32, i32) -> i32 + "return" (%out) : (i32) -> (i32) + }] + +/-- Our MLIR syntax elaboration produces what we expect. -/ +theorem hlhs_stx : lhs = lhs_stx := rfl + + /-- x -/ def rhs : Com Op (Ctxt.ofList [.int]) .int := Com.ret ⟨0, by simp⟩ From e81ab5c4b6b284ab82edd54d3199f6f93b53d375 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 07:44:45 +0000 Subject: [PATCH 11/28] feat: get paper examples into MLIR syntax --- SSA/Core/Framework.lean | 6 +++++ .../InstCombine/AliveAutoGenerated.lean | 6 +++-- SSA/Projects/InstCombine/LLVM/Transform.lean | 8 +++++++ SSA/Projects/PaperExamples/PaperExamples.lean | 24 +++++++------------ 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index bdd2c2a06..d5bcbff37 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1203,12 +1203,18 @@ leaving behind a bare Lean level proposition to be proven. macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : tactic => `(tactic| ( + try simp (config := {decide := false}) only [ + MLIR.AST.DerivedCtxt.ofCtxt_empty, MLIR.AST.DerivedCtxt.ofCtxt, MLIR.AST.DerivedCtxt.snoc + ] -- separate `simp` block so it does not fail if MLIR.AST is not open. try simp (config := {decide := false}) only [ Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk, $ts,*] + try simp (config := {decide := false}) only [ + MLIR.AST.DerivedCtxt.ofCtxt_empty, MLIR.AST.DerivedCtxt.ofCtxt, MLIR.AST.DerivedCtxt.snoc + ] -- separate `simp` block so it does not fail if MLIR.AST is not open. generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; generalize $ll { val := 2, property := _ } = c; diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index 56274ac97..7b916ecd3 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -95,6 +95,8 @@ theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddS unfold alive_AddSub_1043_src alive_AddSub_1043_tgt dsimp only [Com.Refinement] intros Γv + sorry + /- simp[OpDenote.denote, InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, @@ -138,9 +140,9 @@ theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddS Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk] -- simp only [Var.last, Var.toSnoc] - sorry - -- apply bitvec_AddSub_1043 + sorry + -/ #exit diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index 57128964c..bb81c44ea 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -68,12 +68,20 @@ structure DerivedCtxt (Γ : Ctxt Ty) where namespace DerivedCtxt /-- Every context is trivially derived from itself -/ +@[simp] abbrev ofCtxt (Γ : Ctxt Ty) : DerivedCtxt Γ := ⟨Γ, .zero _⟩ +/-- value of a dervied context from an empty context, + is the empty context with a zero diff. -/ +@[simp] +theorem ofCtxt_empty : MLIR.AST.DerivedCtxt.ofCtxt ([] : Ctxt Ty) = ⟨[], .zero _⟩ := rfl + /-- `snoc` of a derived context applies `snoc` to the underlying context, and updates the diff -/ +@[simp] def snoc {Γ : Ctxt Ty} : DerivedCtxt Γ → Ty → DerivedCtxt Γ | ⟨ctxt, diff⟩, ty => ⟨ty::ctxt, diff.toSnoc⟩ +@[simp] instance {Γ : Ctxt Ty} : CoeHead (DerivedCtxt Γ) (Ctxt Ty) where coe := fun ⟨Γ', _⟩ => Γ' diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index f11eaeb4c..5269fbdd4 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -129,18 +129,9 @@ elab "[toy_icom| " reg:mlir_region "]" : term => do end MLIR2Simple +open MLIR AST MLIR2Simple in /-- x + 0 -/ def lhs : Com Op (Ctxt.ofList [.int]) .int := - -- %c0 = 0 - Com.lete (cst 0) <| - -- %out = %x + %c0 - Com.lete (add ⟨1, by simp [Ctxt.snoc]⟩ ⟨0, by simp [Ctxt.snoc]⟩ ) <| - -- return %out - Com.ret ⟨0, by simp [Ctxt.snoc]⟩ - -open MLIR AST MLIR2Simple in -/-- Same code, written in MLIR syntax. -/ -def lhs_stx : Com Op (Ctxt.ofList [.int]) .int := [toy_icom| { ^bb0(%x : i32): %c0 = "const" () { value = 0 : i32 } : () -> i32 @@ -148,13 +139,13 @@ def lhs_stx : Com Op (Ctxt.ofList [.int]) .int := "return" (%out) : (i32) -> (i32) }] -/-- Our MLIR syntax elaboration produces what we expect. -/ -theorem hlhs_stx : lhs = lhs_stx := rfl - - +open MLIR AST MLIR2Simple in /-- x -/ def rhs : Com Op (Ctxt.ofList [.int]) .int := - Com.ret ⟨0, by simp⟩ + [toy_icom| { + ^bb0(%x : i32): + "return" (%x) : (i32) -> (i32) + }] def p1 : PeepholeRewrite Op [.int] .int := { lhs := lhs, rhs := rhs, correct := @@ -171,9 +162,10 @@ def p1 : PeepholeRewrite Op [.int] .int := simp_peephole [add, cst] at Γv /- ⊢ ∀ (a : BitVec 32), a + BitVec.ofInt 32 0 = a -/ intros a + simp[MLIR.AST.DerivedCtxt.snoc, MLIR.AST.DerivedCtxt.ofCtxt] ring /- goals accomplished 🎉 -/ - sorry + sorry -- TODO: import ring instance from other file. } def ex1' : Com Op (Ctxt.ofList [.int]) .int := rewritePeepholeAt p1 1 lhs From 1d50332a7904961163261e4ad7b5505700c9522a Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 07:47:55 +0000 Subject: [PATCH 12/28] fix: change tactics to be the right ones --- SSA/Core/Framework.lean | 6 ++++-- SSA/Projects/PaperExamples/PaperExamples.lean | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index d5bcbff37..b1baaf2e4 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1204,7 +1204,8 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : `(tactic| ( try simp (config := {decide := false}) only [ - MLIR.AST.DerivedCtxt.ofCtxt_empty, MLIR.AST.DerivedCtxt.ofCtxt, MLIR.AST.DerivedCtxt.snoc + DerivedCtxt.snoc, DerivedCtxt.ofCtxt, + DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last ] -- separate `simp` block so it does not fail if MLIR.AST is not open. try simp (config := {decide := false}) only [ Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, @@ -1213,7 +1214,8 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk, $ts,*] try simp (config := {decide := false}) only [ - MLIR.AST.DerivedCtxt.ofCtxt_empty, MLIR.AST.DerivedCtxt.ofCtxt, MLIR.AST.DerivedCtxt.snoc + DerivedCtxt.snoc, DerivedCtxt.ofCtxt, + DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last ] -- separate `simp` block so it does not fail if MLIR.AST is not open. generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index 5269fbdd4..9e9b2b668 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -147,6 +147,7 @@ def rhs : Com Op (Ctxt.ofList [.int]) .int := "return" (%x) : (i32) -> (i32) }] +open MLIR AST MLIR2Simple in def p1 : PeepholeRewrite Op [.int] .int := { lhs := lhs, rhs := rhs, correct := by @@ -162,7 +163,6 @@ def p1 : PeepholeRewrite Op [.int] .int := simp_peephole [add, cst] at Γv /- ⊢ ∀ (a : BitVec 32), a + BitVec.ofInt 32 0 = a -/ intros a - simp[MLIR.AST.DerivedCtxt.snoc, MLIR.AST.DerivedCtxt.ofCtxt] ring /- goals accomplished 🎉 -/ sorry -- TODO: import ring instance from other file. From 42d06db204a3bab750364fde8fb12921f17530a7 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Wed, 15 Nov 2023 09:12:17 +0000 Subject: [PATCH 13/28] feat: shuffle framework files to be able to write simp_peeople_mlir tactic that works well for simple dialects --- SSA/Core/ErasedContext.lean | 37 +++++++++++++ SSA/Core/Framework.lean | 26 ++++++++-- .../InstCombine/AliveAutoGenerated.lean | 13 +++-- SSA/Projects/InstCombine/Base.lean | 4 +- SSA/Projects/InstCombine/LLVM/Transform.lean | 52 +++---------------- SSA/Projects/InstCombine/Tactic.lean | 17 +++--- SSA/Projects/MLIRSyntax/AST.lean | 2 +- SSA/Projects/PaperExamples/PaperExamples.lean | 6 +-- 8 files changed, 94 insertions(+), 63 deletions(-) diff --git a/SSA/Core/ErasedContext.lean b/SSA/Core/ErasedContext.lean index 2d0690511..326c7175e 100644 --- a/SSA/Core/ErasedContext.lean +++ b/SSA/Core/ErasedContext.lean @@ -381,5 +381,42 @@ instance : HAdd (Diff Γ₁ Γ₂) (Diff Γ₂ Γ₃) (Diff Γ₁ Γ₃) := ⟨a end Diff +/-## Derived Contexts: Contexts that grow a base context-/ +structure DerivedCtxt (Γ : Ctxt Ty) where + ctxt : Ctxt Ty + diff : Ctxt.Diff Γ ctxt + +namespace DerivedCtxt + +/-- Every context is trivially derived from itself -/ +@[simp] +abbrev ofCtxt (Γ : Ctxt Ty) : DerivedCtxt Γ := ⟨Γ, .zero _⟩ + +/-- value of a dervied context from an empty context, + is the empty context with a zero diff. -/ +@[simp] +theorem ofCtxt_empty : DerivedCtxt.ofCtxt ([] : Ctxt Ty) = ⟨[], .zero _⟩ := rfl + +/-- `snoc` of a derived context applies `snoc` to the underlying context, and updates the diff -/ +@[simp] +def snoc {Γ : Ctxt Ty} : DerivedCtxt Γ → Ty → DerivedCtxt Γ + | ⟨ctxt, diff⟩, ty => ⟨ty::ctxt, diff.toSnoc⟩ + +@[simp] +instance {Γ : Ctxt Ty} : CoeHead (DerivedCtxt Γ) (Ctxt Ty) where + coe := fun ⟨Γ', _⟩ => Γ' + +instance {Γ : Ctxt Ty} : CoeDep (Ctxt Ty) Γ (DerivedCtxt Γ) where + coe := ⟨Γ, .zero _⟩ + +instance {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : + CoeHead (DerivedCtxt (Γ' : Ctxt Ty)) (DerivedCtxt Γ) where + coe := fun ⟨Γ'', diff⟩ => ⟨Γ'', Γ'.diff + diff⟩ + +instance {Γ' : DerivedCtxt Γ} : Coe (Ctxt.Var Γ t) (Ctxt.Var (Γ' : Ctxt Ty) t) where + coe v := Γ'.diff.toHom v + +end DerivedCtxt + end Ctxt diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index b1baaf2e4..39c380988 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -7,6 +7,8 @@ import Mathlib.Data.Finset.Basic import Mathlib.Data.Fintype.Basic import Mathlib.Tactic.Linarith import Mathlib.Tactic.Ring +import SSA.Projects.MLIRSyntax.AST -- TODO post-merge: bring into Core +import SSA.Projects.MLIRSyntax.EDSL -- TODO post-merge: bring into Core open Ctxt (Var VarSet Valuation) open Goedel (toType) @@ -1204,18 +1206,20 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : `(tactic| ( try simp (config := {decide := false}) only [ - DerivedCtxt.snoc, DerivedCtxt.ofCtxt, - DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last + Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, + DialectMorphism.mapOp, DialectMorphism.mapTy ] -- separate `simp` block so it does not fail if MLIR.AST is not open. try simp (config := {decide := false}) only [ Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk, + DialectMorphism.mapOp, DialectMorphism.mapTy, $ts,*] try simp (config := {decide := false}) only [ - DerivedCtxt.snoc, DerivedCtxt.ofCtxt, - DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last + Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last ] -- separate `simp` block so it does not fail if MLIR.AST is not open. generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; @@ -1240,8 +1244,22 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : ) ) +open MLIR AST in +macro "simp_peephole_mlir" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : tactic => + `(tactic| + ( + simp_peephole [$ts,*] at $ll + try simp (config := {decide := false}) only [ + Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + Ctxt.snoc, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, List.map, + DialectMorphism.mapOp, DialectMorphism.mapTy + ] -- separate `simp` block so it does not fail if MLIR.AST is not open. + ) + ) + /-- `simp_peephole` with no extra user defined theorems. -/ macro "simp_peephole" "at" ll:ident : tactic => `(tactic| simp_peephole [] at $ll) +macro "simp_peephole_mlir" "at" ll:ident : tactic => `(tactic| simp_peephole_mlir [] at $ll) end SimpPeephole diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index 7b916ecd3..e8a7ac635 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -39,6 +39,7 @@ def eg2 (w : Nat) := "llvm.return" (%C0) : (_) -> () }] + def eg3 (w : Nat) := [alive_icom ( w )| { ^bb0(%C0 : _): @@ -47,10 +48,12 @@ def eg3 (w : Nat) := "llvm.return" (%C0) : (_) -> () }] +open MLIR AST in theorem eg2_eq_eg3 : eg2 w ⊑ eg3 w := by unfold eg2 unfold eg3 - sorry /- We should get this example simplifying first. -/ + simp_alive_peephole + simp -- Name:AddSub:1043 -- precondition: true @@ -93,8 +96,12 @@ def alive_AddSub_1043_tgt (w : Nat) := }] theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddSub_1043_tgt w := by unfold alive_AddSub_1043_src alive_AddSub_1043_tgt - dsimp only [Com.Refinement] - intros Γv + simp_alive_peephole + -- simp[pairBind, Var.last, bind, Option.bind, Var.toSnoc, Var.last] + sorry + +#exit + sorry /- simp[OpDenote.denote, diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index f6fc341da..9467c7186 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -32,7 +32,6 @@ abbrev Width φ := ConcreteOrMVar Nat φ inductive MTy (φ : Nat) | bitvec (w : Width φ) : MTy φ deriving DecidableEq, Inhabited - abbrev Ty := MTy 0 instance : Repr (MTy φ) where @@ -40,6 +39,9 @@ instance : Repr (MTy φ) where | .bitvec (.concrete w), _ => "i" ++ repr w | .bitvec (.mvar ⟨i, _⟩), _ => f!"i$\{%{i}}" +instance : Lean.ToFormat (MTy φ) where + format := repr + instance : ToString (MTy φ) where toString t := repr t |>.pretty diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index bb81c44ea..026c6c580 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -1,9 +1,9 @@ -- should replace with Lean import once Pure is upstream import SSA.Projects.MLIRSyntax.AST -import SSA.Projects.InstCombine.Base import SSA.Projects.InstCombine.LLVM.Transform.NameMapping import SSA.Projects.InstCombine.LLVM.Transform.TransformError import SSA.Core.Framework +import SSA.Core.ErasedContext import Std.Data.BitVec @@ -11,11 +11,12 @@ universe u namespace MLIR.AST -open InstCombine (MOp MTy Width) open Std (BitVec) +open Ctxt + +instance {Op Ty : Type} [OpSignature Op Ty] {t : Ty} {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where + coe e := e.changeVars Γ'.diff.toHom -instance : Lean.ToFormat (MTy φ) where - format := repr section Monads @@ -61,45 +62,6 @@ class TransformReturn (Op : Type) (Ty : outParam (Type)) (φ : outParam Nat) /- instance of the transform dialect, plus data needed about `Op` and `Ty`. -/ variable {Op Ty φ} [OpSignature Op Ty] [DecidableEq Ty] [DecidableEq Op] -structure DerivedCtxt (Γ : Ctxt Ty) where - ctxt : Ctxt Ty - diff : Ctxt.Diff Γ ctxt - -namespace DerivedCtxt - -/-- Every context is trivially derived from itself -/ -@[simp] -abbrev ofCtxt (Γ : Ctxt Ty) : DerivedCtxt Γ := ⟨Γ, .zero _⟩ - -/-- value of a dervied context from an empty context, - is the empty context with a zero diff. -/ -@[simp] -theorem ofCtxt_empty : MLIR.AST.DerivedCtxt.ofCtxt ([] : Ctxt Ty) = ⟨[], .zero _⟩ := rfl - -/-- `snoc` of a derived context applies `snoc` to the underlying context, and updates the diff -/ -@[simp] -def snoc {Γ : Ctxt Ty} : DerivedCtxt Γ → Ty → DerivedCtxt Γ - | ⟨ctxt, diff⟩, ty => ⟨ty::ctxt, diff.toSnoc⟩ - -@[simp] -instance {Γ : Ctxt Ty} : CoeHead (DerivedCtxt Γ) (Ctxt Ty) where - coe := fun ⟨Γ', _⟩ => Γ' - -instance {Γ : Ctxt Ty} : CoeDep (Ctxt Ty) Γ (DerivedCtxt Γ) where - coe := ⟨Γ, .zero _⟩ - -instance {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : - CoeHead (DerivedCtxt (Γ' : Ctxt Ty)) (DerivedCtxt Γ) where - coe := fun ⟨Γ'', diff⟩ => ⟨Γ'', Γ'.diff + diff⟩ - -instance {Γ : Ctxt Ty} {Γ' : DerivedCtxt Γ} : Coe (Expr Op Γ t) (Expr Op Γ'.ctxt t) where - coe e := e.changeVars Γ'.diff.toHom - -instance {Γ' : DerivedCtxt Γ} : Coe (Ctxt.Var Γ t) (Ctxt.Var (Γ' : Ctxt Ty) t) where - coe v := Γ'.diff.toHom v - -end DerivedCtxt - /-- Add a new variable to the context, and record it's (absolute) index in the name mapping @@ -349,10 +311,10 @@ def MLIRType.mkTy : MLIRType φ → ExceptM Op Ty def TypedSSAVal.mkTy [TransformTy Op Ty φ] : TypedSSAVal φ → ExceptM Op Ty | (.SSAVal _, ty) => TransformTy.mkTy ty - +/- def mkVal (ty : InstCombine.Ty) : Int → BitVec ty.width | val => BitVec.ofInt ty.width val - +-/ /-- Translate a `TypedSSAVal` (a name with an expected type), to a variable in the context. This expects the name to have already been declared before -/ def TypedSSAVal.mkVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → diff --git a/SSA/Projects/InstCombine/Tactic.lean b/SSA/Projects/InstCombine/Tactic.lean index fe699ab0f..80b74c709 100644 --- a/SSA/Projects/InstCombine/Tactic.lean +++ b/SSA/Projects/InstCombine/Tactic.lean @@ -2,9 +2,11 @@ import SSA.Projects.InstCombine.LLVM.EDSL import SSA.Projects.InstCombine.AliveStatements import SSA.Projects.InstCombine.Refinement import Mathlib.Tactic +import SSA.Core.ErasedContext open MLIR AST open Std (BitVec) +open Ctxt theorem bitvec_minus_one : BitVec.ofInt w (Int.negSucc 0) = (-1 : BitVec w) := by simp[BitVec.ofInt, BitVec.ofNat,Neg.neg, @@ -19,6 +21,8 @@ theorem bitvec_minus_one : BitVec.ofInt w (Int.negSucc 0) = (-1 : BitVec w) := b simp rw[ONE] + +open MLIR AST in /-- - We first simplify `Com.refinement` to see the context `Γv`. - We `simp_peephole Γv` to simplify context accesses by variables. @@ -34,14 +38,14 @@ macro "simp_alive_peephole" : tactic => intros Γv simp [InstcombineTransformDialect.MOp.instantiateCom, InstcombineTransformDialect.instantiateMOp, ConcreteOrMVar.instantiate, Vector.get, List.nthLe, List.length_singleton, Fin.coe_fin_one, Fin.zero_eta, - List.get_cons_zero, Function.comp_apply, InstcombineTransformDialect.instantiateMTy, Ctxt.empty_eq, DerivedCtxt.snoc, - DerivedCtxt.ofCtxt, List.map_eq_map, List.map] at Γv - simp_peephole at Γv + List.get_cons_zero, Function.comp_apply, InstcombineTransformDialect.instantiateMTy, Ctxt.empty_eq, Ctxt.DerivedCtxt.snoc, + Ctxt.DerivedCtxt.ofCtxt, List.map_eq_map, List.map, DialectMorphism.mapTy] at Γv + simp_peephole_mlir at Γv /- note that we need the `HVector.toPair`, `HVector.toSingle`, `HVector.toTriple` lemmas since it's used in `InstCombine.Op.denote` We need `HVector.toTuple` since it's used in `MLIR.AST.mkOpExpr`. -/ - try simp (config := {decide := false}) only [OpDenote.denote, + simp (config := {decide := false}) only [OpDenote.denote, InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, - bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, + bind, Option.bind, pure, Ctxt.DerivedCtxt.ofCtxt, Ctxt.DerivedCtxt.snoc, Ctxt.snoc, ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, @@ -54,7 +58,8 @@ macro "simp_alive_peephole" : tactic => InstcombineTransformDialect.instantiateMOp, InstcombineTransformDialect.MOp.instantiateCom, InstcombineTransformDialect.instantiateCtxt, - ConcreteOrMVar.instantiate, Com.Refinement] + ConcreteOrMVar.instantiate, Com.Refinement, + DialectMorphism.mapTy] try intros v0 try intros v1 try intros v2 diff --git a/SSA/Projects/MLIRSyntax/AST.lean b/SSA/Projects/MLIRSyntax/AST.lean index de4038ab8..a5ef7660a 100644 --- a/SSA/Projects/MLIRSyntax/AST.lean +++ b/SSA/Projects/MLIRSyntax/AST.lean @@ -1,5 +1,5 @@ import SSA.Core.Util.ConcreteOrMVar -import SSA.Core.Framework +-- import SSA.Core.Framework open Lean PrettyPrinter namespace MLIR.AST diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index 9e9b2b668..251c23fa5 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -96,7 +96,7 @@ def mkReturn (Γ : Ctxt Ty) (opStx : MLIR.AST.Op 0) : MLIR.AST.ReaderM Op (Σ ty instance : MLIR.AST.TransformReturn Op Ty 0 where mkReturn := mkReturn -open InstCombine (Op Ty) in +-- open InstCombine (Op Ty) in def mlir2simple (reg : MLIR.AST.Region 0) : MLIR.AST.ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := MLIR.AST.mkCom reg @@ -152,7 +152,7 @@ def p1 : PeepholeRewrite Op [.int] .int := { lhs := lhs, rhs := rhs, correct := by rw [lhs, rhs] - /- + /-: Com.denote (Com.lete (cst 0) (Com.lete (add { val := 1, property := _ } { val := 0, property := _ }) @@ -160,7 +160,7 @@ def p1 : PeepholeRewrite Op [.int] .int := Com.denote (Com.ret { val := 0, property := _ }) -/ funext Γv - simp_peephole [add, cst] at Γv + simp_peephole_mlir [add, cst] at Γv /- ⊢ ∀ (a : BitVec 32), a + BitVec.ofInt 32 0 = a -/ intros a ring From 023c5dc4eb82f9ea2399bdf53abad6fa8d3d71d8 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 16 Nov 2023 02:32:13 +0000 Subject: [PATCH 14/28] feat: run p1 to check that it does the right thing --- SSA/Core/Framework.lean | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 39c380988..ddbe79b9e 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1197,6 +1197,8 @@ theorem denote_rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) section SimpPeephole +#check denote_rewritePeepholeAt +#print axioms denote_rewritePeepholeAt -- [propext, Classical.choice, Quot.sound] /-- `simp_peephole [t1, t2, ... tn]` at Γ simplifies the evaluation of the context Γ, @@ -1571,7 +1573,7 @@ instance : Goedel ExTy where inductive ExOp : Type | add : ExOp | runK : ℕ → ExOp - deriving DecidableEq + deriving DecidableEq, Repr instance : OpSignature ExOp ExTy where signature @@ -1622,6 +1624,15 @@ def p1 : PeepholeRewrite ExOp [.nat] .nat:= done } +def p1_run : Com ExOp [.nat] .nat := + rewritePeepholeAt p1 0 ex1_lhs + +/- +RegionExamples.ExOp.runK 0[[%0]] +return %1 +-/ +#eval p1_run + /-- running `f(x) = x + x` 1 times does return `x + x`. -/ def ex2_lhs : Com ExOp [.nat] .nat := Com.lete (rgn (k := 1) ⟨0, by simp[Ctxt.snoc]⟩ ( From 69047068de0ac0a0e5e3a3203f370247a6090d84 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 16 Nov 2023 02:48:55 +0000 Subject: [PATCH 15/28] feat: add rewriter that's based on fuel, which tries applying the rewriter at every 'ix' until it runs out of fuel --- SSA/Core/Framework.lean | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index ddbe79b9e..72b8b2e3e 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1182,6 +1182,7 @@ def rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) | none => target else target + theorem denote_rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) (pos : ℕ) (target : Com Op Γ₂ t₂) : (rewritePeepholeAt pr pos target).denote = target.denote := by @@ -1200,6 +1201,23 @@ section SimpPeephole #check denote_rewritePeepholeAt #print axioms denote_rewritePeepholeAt -- [propext, Classical.choice, Quot.sound] +/- repeatedly apply peephole on program. -/ +section SimpPeepholeApplier + +/-- rewrite with `pr` to `target` program, at location `ix` and later, running at most `fuel` steps. -/ +def rewritePeephole_go (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (ix : ℕ) (target : Com Op Γ₂ t₂) : + (Com Op Γ₂ t₂) := + match fuel with + | 0 => target + | fuel' + 1 => + let target' := rewritePeepholeAt pr ix target + rewritePeephole_go fuel' pr (ix + 1) target' + +/-- rewrite with `pr` to `target` program, running at most `fuel` steps. -/ +def rewritePeephole (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : + (Com Op Γ₂ t₂) := rewritePeephole_go fuel pr 0 target +end SimpPeepholeApplier + /-- `simp_peephole [t1, t2, ... tn]` at Γ simplifies the evaluation of the context Γ, leaving behind a bare Lean level proposition to be proven. From 214f9d3128edceeac94e7e59ab262d84b3bb32b0 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 16 Nov 2023 02:55:37 +0000 Subject: [PATCH 16/28] show example of 'rewritePeephole' with fuel --- SSA/Projects/PaperExamples/PaperExamples.lean | 34 ++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index 251c23fa5..23dd7e19d 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -102,7 +102,7 @@ def mlir2simple (reg : MLIR.AST.Region 0) : MLIR.AST.ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := MLIR.AST.mkCom reg open Qq MLIR AST Lean Elab Term Meta in -elab "[toy_icom| " reg:mlir_region "]" : term => do +elab "[simple_com| " reg:mlir_region "]" : term => do let ast_stx ← `([mlir_region| $reg]) let ast ← elabTermEnsuringTypeQ ast_stx q(Region 0) let mvalues ← `(⟨[], by rfl⟩) @@ -129,10 +129,23 @@ elab "[toy_icom| " reg:mlir_region "]" : term => do end MLIR2Simple +open MLIR AST MLIR2Simple in +def eg₀ : Com Op (Ctxt.ofList []) .int := + [simple_com| { + %c2= "const"() {value = 2} : () -> i32 + %c4 = "const"() {value = 4} : () -> i32 + %c6 = "add"(%c2, %c4) : (i32, i32) -> i32 + %c8 = "add"(%c6, %c2) : (i32, i32) -> i32 + "return"(%c8) : (i32) -> () + }] + +def eg₀val := Com.denote eg₀ Ctxt.Valuation.nil +#eval eg₀val -- 0x00000008#32 + open MLIR AST MLIR2Simple in /-- x + 0 -/ def lhs : Com Op (Ctxt.ofList [.int]) .int := - [toy_icom| { + [simple_com| { ^bb0(%x : i32): %c0 = "const" () { value = 0 : i32 } : () -> i32 %out = "add" (%x, %c0) : (i32, i32) -> i32 @@ -142,7 +155,7 @@ def lhs : Com Op (Ctxt.ofList [.int]) .int := open MLIR AST MLIR2Simple in /-- x -/ def rhs : Com Op (Ctxt.ofList [.int]) .int := - [toy_icom| { + [simple_com| { ^bb0(%x : i32): "return" (%x) : (i32) -> (i32) }] @@ -168,10 +181,20 @@ def p1 : PeepholeRewrite Op [.int] .int := sorry -- TODO: import ring instance from other file. } -def ex1' : Com Op (Ctxt.ofList [.int]) .int := rewritePeepholeAt p1 1 lhs +def ex1_rewritePeepholeAt : Com Op (Ctxt.ofList [.int]) .int := rewritePeepholeAt p1 1 lhs +theorem hex1_rewritePeephole : ex1_rewritePeepholeAt = ( + -- %c0 = 0 + Com.lete (cst 0) <| + -- %out_dead = %x + %c0 + Com.lete (add ⟨1, by simp [Ctxt.snoc]⟩ ⟨0, by simp [Ctxt.snoc]⟩ ) <| -- %out = %x + %c0 + -- ret %c0 + Com.ret ⟨2, by simp [Ctxt.snoc]⟩) + := by rfl -theorem EX1' : ex1' = ( +def ex1_rewritePeephole : Com Op (Ctxt.ofList [.int]) .int := rewritePeephole (fuel := 100) p1 lhs + +theorem Hex1_rewritePeephole : ex1_rewritePeephole = ( -- %c0 = 0 Com.lete (cst 0) <| -- %out_dead = %x + %c0 @@ -180,6 +203,7 @@ theorem EX1' : ex1' = ( Com.ret ⟨2, by simp [Ctxt.snoc]⟩) := by rfl + end ToyNoRegion namespace ToyRegion From 72cc04f9b28c8075abbe921541b9a3fa0b7d93f2 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 16 Nov 2023 03:01:32 +0000 Subject: [PATCH 17/28] feat: prove that rewritePeephole preserves semantics --- SSA/Core/Framework.lean | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 72b8b2e3e..531bf1b1e 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1213,9 +1213,25 @@ def rewritePeephole_go (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (ix : ℕ) (t let target' := rewritePeepholeAt pr ix target rewritePeephole_go fuel' pr (ix + 1) target' +/-- `rewritePeephole_go` preserve semantics -/ +theorem denote_rewritePeephole_go (pr : PeepholeRewrite Op Γ t) + (pos : ℕ) (target : Com Op Γ₂ t₂) : + (rewritePeephole_go fuel pr pos target).denote = target.denote := by + induction fuel generalizing pr pos target + case zero => simp[rewritePeephole_go, denote_rewritePeepholeAt] + case succ fuel' hfuel => + simp[rewritePeephole_go, denote_rewritePeepholeAt, hfuel] + /-- rewrite with `pr` to `target` program, running at most `fuel` steps. -/ def rewritePeephole (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := rewritePeephole_go fuel pr 0 target + + +/-- `rewritePeephole` preserves semantics. -/ +theorem denote_rewritePeephole (fuel : ℕ) + (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : + (rewritePeephole fuel pr target).denote = target.denote := by + simp[rewritePeephole, denote_rewritePeephole_go] end SimpPeepholeApplier /-- From 464e3ec83519dbd1fd71034a6d27da987989a20f Mon Sep 17 00:00:00 2001 From: Siddharth Date: Thu, 7 Dec 2023 19:18:23 +0530 Subject: [PATCH 18/28] Update SSA/Core/Framework.lean Co-authored-by: Alex Keizer --- SSA/Core/Framework.lean | 1 - 1 file changed, 1 deletion(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 531bf1b1e..85d40020f 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1182,7 +1182,6 @@ def rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) | none => target else target - theorem denote_rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) (pos : ℕ) (target : Com Op Γ₂ t₂) : (rewritePeepholeAt pr pos target).denote = target.denote := by From 390d2ca8d450d5963acb646ca927831ebbfdba95 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 8 Dec 2023 10:25:55 +0530 Subject: [PATCH 19/28] feat: simplify simp sets, remove duplicated tactic --- SSA/Core/Framework.lean | 38 ++++++++----------- SSA/Projects/InstCombine/Tactic.lean | 2 +- SSA/Projects/PaperExamples/PaperExamples.lean | 2 +- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 85d40020f..b0feddd4d 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -517,6 +517,7 @@ structure DialectMorphism (Op Op' : Type) {Ty Ty' : Type} [OpSignature Op Ty] [O mapTy : Ty → Ty' preserves_signature : ∀ op, signature (mapOp op) = mapTy <$> (signature op) + variable {Op Op' Ty Ty : Type} [OpSignature Op Ty] [OpSignature Op' Ty'] (f : DialectMorphism Op Op') @@ -1233,6 +1234,7 @@ theorem denote_rewritePeephole (fuel : ℕ) simp[rewritePeephole, denote_rewritePeephole_go] end SimpPeepholeApplier +#check DialectMorphism /-- `simp_peephole [t1, t2, ... tn]` at Γ simplifies the evaluation of the context Γ, leaving behind a bare Lean level proposition to be proven. @@ -1240,22 +1242,28 @@ leaving behind a bare Lean level proposition to be proven. macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : tactic => `(tactic| ( - try simp (config := {decide := false}) only [ + -- try simp (config := {decide := false}) only [ + -- Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + -- Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, + -- DialectMorphism.mapOp, DialectMorphism.mapTy, List.map + -- ] -- separate `simp` block so it does not fail if MLIR.AST is not open. + try simp (config := {decide := false}) only [ Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, - DialectMorphism.mapOp, DialectMorphism.mapTy - ] -- separate `simp` block so it does not fail if MLIR.AST is not open. - try simp (config := {decide := false}) only [ + DialectMorphism.mapOp, DialectMorphism.mapTy, List.map, Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk, + DialectMorphism.mapOp, DialectMorphism.mapTy, List.map, + Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + Ctxt.snoc, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, List.map, DialectMorphism.mapOp, DialectMorphism.mapTy, $ts,*] - try simp (config := {decide := false}) only [ - Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, - Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last - ] -- separate `simp` block so it does not fail if MLIR.AST is not open. + -- try simp (config := {decide := false}) only [ + -- Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, + -- Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last + -- ] -- separate `simp` block so it does not fail if MLIR.AST is not open. generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; generalize $ll { val := 2, property := _ } = c; @@ -1280,22 +1288,8 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : ) open MLIR AST in -macro "simp_peephole_mlir" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : tactic => - `(tactic| - ( - simp_peephole [$ts,*] at $ll - try simp (config := {decide := false}) only [ - Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, - Ctxt.snoc, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, List.map, - DialectMorphism.mapOp, DialectMorphism.mapTy - ] -- separate `simp` block so it does not fail if MLIR.AST is not open. - ) - ) - /-- `simp_peephole` with no extra user defined theorems. -/ macro "simp_peephole" "at" ll:ident : tactic => `(tactic| simp_peephole [] at $ll) -macro "simp_peephole_mlir" "at" ll:ident : tactic => `(tactic| simp_peephole_mlir [] at $ll) - end SimpPeephole diff --git a/SSA/Projects/InstCombine/Tactic.lean b/SSA/Projects/InstCombine/Tactic.lean index 80b74c709..9cb26903c 100644 --- a/SSA/Projects/InstCombine/Tactic.lean +++ b/SSA/Projects/InstCombine/Tactic.lean @@ -40,7 +40,7 @@ macro "simp_alive_peephole" : tactic => ConcreteOrMVar.instantiate, Vector.get, List.nthLe, List.length_singleton, Fin.coe_fin_one, Fin.zero_eta, List.get_cons_zero, Function.comp_apply, InstcombineTransformDialect.instantiateMTy, Ctxt.empty_eq, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, List.map_eq_map, List.map, DialectMorphism.mapTy] at Γv - simp_peephole_mlir at Γv + simp_peephole at Γv /- note that we need the `HVector.toPair`, `HVector.toSingle`, `HVector.toTriple` lemmas since it's used in `InstCombine.Op.denote` We need `HVector.toTuple` since it's used in `MLIR.AST.mkOpExpr`. -/ simp (config := {decide := false}) only [OpDenote.denote, diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index 23dd7e19d..75db08391 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -173,7 +173,7 @@ def p1 : PeepholeRewrite Op [.int] .int := Com.denote (Com.ret { val := 0, property := _ }) -/ funext Γv - simp_peephole_mlir [add, cst] at Γv + simp_peephole [add, cst] at Γv /- ⊢ ∀ (a : BitVec 32), a + BitVec.ofInt 32 0 = a -/ intros a ring From 5c8d0f8a3d13000e1176aceb51122005f54ee395 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 8 Dec 2023 10:32:55 +0530 Subject: [PATCH 20/28] chore: nuke dead code --- SSA/Projects/InstCombine/LLVM/Transform.lean | 366 +------------------ 1 file changed, 6 insertions(+), 360 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/Transform.lean b/SSA/Projects/InstCombine/LLVM/Transform.lean index 026c6c580..b2f621aa3 100644 --- a/SSA/Projects/InstCombine/LLVM/Transform.lean +++ b/SSA/Projects/InstCombine/LLVM/Transform.lean @@ -27,9 +27,9 @@ section Monads errors. -/ -abbrev ExceptM (Op) [OpSignature Op Ty] := Except (TransformError Ty) -abbrev BuilderM (Op) [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) -abbrev ReaderM (Op) [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) +abbrev ExceptM Op [OpSignature Op Ty] := Except (TransformError Ty) +abbrev BuilderM Op [OpSignature Op Ty] := StateT NameMapping (ExceptM Op) +abbrev ReaderM Op [OpSignature Op Ty] := ReaderT NameMapping (ExceptM Op) instance {Op : Type} [OpSignature Op Ty] : MonadLift (ReaderM Op) (BuilderM Op) where monadLift x := do (ReaderT.run x (←get) : ExceptM ..) @@ -37,7 +37,7 @@ instance {Op : Type} [OpSignature Op Ty] : MonadLift (ReaderM Op) (BuilderM Op) instance {Op : Type} [OpSignature Op Ty] : MonadLift (ExceptM Op) (ReaderM Op) where monadLift x := do return ←x -def BuilderM.runWithNewMapping {Op : Type} [OpSignature Op Ty] (k : BuilderM Op α) : ExceptM Op α := +def BuilderM.runWithEmptyMapping {Op : Type} [OpSignature Op Ty] (k : BuilderM Op α) : ExceptM Op α := Prod.fst <$> StateT.run k [] end Monads @@ -107,214 +107,9 @@ def BuilderM.isErr {α : Type} (x : BuilderM Op α) : Bool := | Except.ok _ => true | Except.error _ => false -#check Ctxt.Var -/- -def mkUnaryOp {Γ : Ctxt Ty} {ty : Ty} (op : MOp φ) - (e : Ctxt.Var Γ ty) : ExceptM Op <| Expr Op Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .neg w' => if h : w = w' - then return ⟨ - .neg w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .not w' => if h : w = w' - then return ⟨ - .not w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | .copy w' => if h : w = w' - then return ⟨ - .copy w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e) .nil, - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedUnaryOp - -def mkBinOp {Γ : Ctxt Ty} {ty : Ty} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM Op <| Expr Op Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly - | .add w' => if h : w = w' - then return ⟨ - .add w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .and w' => if h : w = w' - then return ⟨ - .and w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .or w' => if h : w = w' - then return ⟨ - .or w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .xor w' => if h : w = w' - then return ⟨ - .xor w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .shl w' => if h : w = w' - then return ⟨ - .shl w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .lshr w' => if h : w = w' - then return ⟨ - .lshr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .ashr w' => if h : w = w' - then return ⟨ - .ashr w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .urem w' => if h : w = w' - then return ⟨ - .urem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .srem w' => if h : w = w' - then return ⟨ - .srem w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .mul w' => if h : w = w' - then return ⟨ - .mul w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sub w' => if h : w = w' - then return ⟨ - .sub w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .sdiv w' => if h : w = w' - then return ⟨ - .sdiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | .udiv w' => if h : w = w' - then return ⟨ - .udiv w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | op => throw <| .unsupportedBinaryOp s!"unsupported binary operation {op}" - -def mkIcmp {Γ : Ctxt _} {ty : Ty} (op : MOp φ) - (e₁ e₂ : Ctxt.Var Γ ty) : ExceptM Op <| Expr Op Γ (.bitvec 1) := - match ty with - | .bitvec w => - match op with - | .icmp p w' => if h : w = w' - then return ⟨ - .icmp p w', - by simp [OpSignature.outTy, signature, h], - .cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedOp "unsupported icmp operation" - -def mkSelect {Γ : Ctxt Ty} {ty : Ty} (op : MOp φ) - (c : Ctxt.Var Γ (.bitvec 1)) (e₁ e₂ : Ctxt.Var Γ ty) : - ExceptM Op <| Expr Op Γ ty := - match ty with - | .bitvec w => - match op with - | .select w' => if h : w = w' - then return ⟨ - .select w', - by simp [OpSignature.outTy, signature, h], - .cons c <|.cons (h ▸ e₁) <| .cons (h ▸ e₂) .nil , - .nil - ⟩ - else throw <| .widthError w w' - | _ => throw <| .unsupportedOp "Unsupported select operation" - -def mkOpExpr {Γ : Ctxt Ty} (op : MOp φ) - (arg : HVector (fun t => Ctxt.Var Γ t) (OpSignature.sig op)) : - ExceptM Op <| Expr Op Γ (OpSignature.outTy op) := - match op with - | .and _ | .or _ | .xor _ | .shl _ | .lshr _ | .ashr _ - | .add _ | .mul _ | .sub _ | .udiv _ | .sdiv _ - | .srem _ | .urem _ => - let (e₁, e₂) := arg.toTuple - mkBinOp op e₁ e₂ - | .icmp _ _ => - let (e₁, e₂) := arg.toTuple - mkIcmp op e₁ e₂ - | .not _ | .neg _ | .copy _ => - mkUnaryOp op arg.head - | .select _ => - let (c, e₁, e₂) := arg.toTuple - mkSelect op c e₁ e₂ - | .const .. => throw <| .unsupportedOp "Tried to build Op expression from constant" - -def MLIRType.mkTy : MLIRType φ → ExceptM Op Ty - | MLIRType.int Signedness.Signless w => do - return .bitvec w - | _ => throw .unsupportedType -- "Unsupported type" --/ - def TypedSSAVal.mkTy [TransformTy Op Ty φ] : TypedSSAVal φ → ExceptM Op Ty | (.SSAVal _, ty) => TransformTy.mkTy ty -/- -def mkVal (ty : InstCombine.Ty) : Int → BitVec ty.width - | val => BitVec.ofInt ty.width val --/ + /-- Translate a `TypedSSAVal` (a name with an expected type), to a variable in the context. This expects the name to have already been declared before -/ def TypedSSAVal.mkVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : TypedSSAVal φ → @@ -344,105 +139,6 @@ def TypedSSAVal.newVal [instTransformTy : TransformTy Op Ty φ] (Γ : Ctxt Ty) : let ⟨Γ, var⟩ ← addValToMapping Γ valStx ty return ⟨Γ, ty, var⟩ -/- -def mkExpr (Γ : Ctxt Ty) (opStx : MLIR.AST.Op φ) : ReaderM Op (Σ ty, Expr Op Γ ty) := do - match opStx.args with - | v₁Stx::v₂Stx::v₃Stx::[] => - let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx - let ⟨.bitvec w₂, v₂⟩ ← TypedSSAVal.mkVal Γ v₂Stx - let ⟨.bitvec w₃, v₃⟩ ← TypedSSAVal.mkVal Γ v₃Stx - match opStx.name with - | "llvm.select" => - if hw1 : w₁ = 1 then - if hw23 : w₂ = w₃ then - let selectOp ← mkSelect (MOp.select w₂) (hw1 ▸ v₁) v₂ (hw23 ▸ v₃) - return ⟨.bitvec w₂, selectOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - else throw <| .unsupportedOp s!"expected select condtion to have width 1, found width '{w₁}'" - | op => throw <| .unsupportedOp s!"Unsuported ternary operation or invalid arguments '{op}'" - | v₁Stx::v₂Stx::[] => - let ⟨.bitvec w₁, v₁⟩ ← TypedSSAVal.mkVal Γ v₁Stx - let ⟨.bitvec w₂, v₂⟩ ← TypedSSAVal.mkVal Γ v₂Stx - -- let ty₁ := ty₁.instantiave - let op ← match opStx.name with - | "llvm.and" => pure (MOp.and w₁) - | "llvm.or" => pure (MOp.or w₁) - | "llvm.xor" => pure (MOp.xor w₁) - | "llvm.shl" => pure (MOp.shl w₁) - | "llvm.lshr" => pure (MOp.lshr w₁) - | "llvm.ashr" => pure (MOp.ashr w₁) - | "llvm.urem" => pure (MOp.urem w₁) - | "llvm.srem" => pure (MOp.srem w₁) - | "llvm.add" => pure (MOp.add w₁) - | "llvm.mul" => pure (MOp.mul w₁) - | "llvm.sub" => pure (MOp.sub w₁) - | "llvm.sdiv" => pure (MOp.sdiv w₁) - | "llvm.udiv" => pure (MOp.udiv w₁) - | "llvm.icmp.eq" => pure (MOp.icmp LLVM.IntPredicate.eq w₁) - | "llvm.icmp.ne" => pure (MOp.icmp LLVM.IntPredicate.ne w₁) - | "llvm.icmp.ugt" => pure (MOp.icmp LLVM.IntPredicate.ugt w₁) - | "llvm.icmp.uge" => pure (MOp.icmp LLVM.IntPredicate.uge w₁) - | "llvm.icmp.ult" => pure (MOp.icmp LLVM.IntPredicate.ult w₁) - | "llvm.icmp.ule" => pure (MOp.icmp LLVM.IntPredicate.ule w₁) - | "llvm.icmp.sgt" => pure (MOp.icmp LLVM.IntPredicate.sgt w₁) - | "llvm.icmp.sge" => pure (MOp.icmp LLVM.IntPredicate.sge w₁) - | "llvm.icmp.slt" => pure (MOp.icmp LLVM.IntPredicate.slt w₁) - | "llvm.icmp.sle" => pure (MOp.icmp LLVM.IntPredicate.sle w₁) - | opstr => throw <| .unsupportedOp s!"Unsuported binary operation or invalid arguments '{opstr}'" - match op with - | .icmp .. => - if hty : w₁ = w₂ then - let icmpOp ← mkIcmp op v₁ (hty ▸ v₂) - return ⟨.bitvec 1, icmpOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - | _ => - if hty : w₁ = w₂ then - let binOp ← mkBinOp op v₁ (hty ▸ v₂) - return ⟨.bitvec w₁, binOp⟩ - else - throw <| .widthError w₁ w₂ -- s!"mismatched types {ty₁} ≠ {ty₂} in binary op" - | vStx::[] => - let ⟨.bitvec w, v⟩ ← TypedSSAVal.mkVal Γ vStx - let op ← match opStx.name with - | "llvm.not" => pure <| MOp.not w - | "llvm.neg" => pure <| MOp.neg w - | "llvm.copy" => pure <| MOp.copy w - | _ => throw <| .generic s!"Unknown (unary) operation syntax {opStx.name}" - let op ← mkUnaryOp op v - return ⟨.bitvec w, op⟩ - | [] => - if opStx.name == "llvm.mlir.constant" - then do - let some att := opStx.attrs.getAttr "value" - | throw <| .generic "tried to resolve constant without 'value' attribute" - match att with - | .int val ty => - let opTy@(MTy.bitvec w) ← MLIRType.mkTy ty -- ty.mkTy - return ⟨opTy, ⟨ - MOp.const w val, - by simp [OpSignature.outTy, signature, *], - HVector.nil, - HVector.nil - ⟩⟩ - | _ => throw <| .generic "invalid constant attribute" - else - throw <| .generic s!"invalid (0-ary) expression {opStx.name}" - | _ => throw <| .generic s!"unsupported expression (with unsupported arity) {opStx.name}" --/ - -/- -def mkReturn (Γ : Ctxt Ty) (opStx : MLIR.AST.Op φ) : ReaderM Op (Σ ty, Com Op Γ ty) := - if opStx.name == "llvm.return" - then match opStx.args with - | vStx::[] => do - let ⟨ty, v⟩ ← TypedSSAVal.mkVal Γ vStx - return ⟨ty, _root_.Com.ret v⟩ - | _ => throw <| .generic s!"Ill-formed return statement (wrong arity, expected 1, got {opStx.args.length})" - else throw <| .generic s!"Tried to build return out of non-return statement {opStx.name}" --/ - /-- Given a list of `TypedSSAVal`s, treat each as a binder and declare a new variable with the given name and type -/ private def declareBindings [TransformTy Op Ty φ] (Γ : Ctxt Ty) (vals : List (TypedSSAVal φ)) : @@ -472,59 +168,9 @@ def mkCom [TransformTy Op Ty φ] [TransformExpr Op Ty φ] [TransformReturn Op Ty (reg : MLIR.AST.Region φ) : ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := match reg.ops with | [] => throw <| .generic "Ill-formed region (empty)" - | coms => BuilderM.runWithNewMapping <| do + | coms => BuilderM.runWithEmptyMapping <| do let Γ ← declareBindings ∅ reg.args let com ← mkComHelper Γ coms return ⟨Γ, com⟩ -/-! - ## Instantiation - Finally, we show how to instantiate a family of programs to a concrete program --/ - -/- -def _root_.InstCombine.MTy.instantiate (vals : Vector Nat φ) : Ty → InstCombine.Ty - | .bitvec w => .bitvec <| .concrete <| w.instantiate vals - -def _root_.InstCombine.MOp.instantiate (vals : Vector Nat φ) : MOp φ → InstCombine.Op - | .and w => .and (w.instantiate vals) - | .or w => .or (w.instantiate vals) - | .not w => .not (w.instantiate vals) - | .xor w => .xor (w.instantiate vals) - | .shl w => .shl (w.instantiate vals) - | .lshr w => .lshr (w.instantiate vals) - | .ashr w => .ashr (w.instantiate vals) - | .urem w => .urem (w.instantiate vals) - | .srem w => .srem (w.instantiate vals) - | .select w => .select (w.instantiate vals) - | .add w => .add (w.instantiate vals) - | .mul w => .mul (w.instantiate vals) - | .sub w => .sub (w.instantiate vals) - | .neg w => .neg (w.instantiate vals) - | .copy w => .copy (w.instantiate vals) - | .sdiv w => .sdiv (w.instantiate vals) - | .udiv w => .udiv (w.instantiate vals) - | .icmp c w => .icmp c (w.instantiate vals) - | .const w val => .const (w.instantiate vals) val - -def Ctxt.instantiate (vals : Vector Nat φ) (Γ : Ctxt Ty) : Ctxt InstCombine.Ty := - Γ.map (MTy.instantiate vals) - -def MOp.instantiateCom (vals : Vector Nat φ) : DialectMorphism Op (InstCombine.Op) where - mapOp := MOp.instantiate vals - mapTy := MTy.instantiate vals - preserves_signature op := by - simp only [MTy.instantiate, MOp.instantiate, ConcreteOrMVar.instantiate, (· <$> ·), signature, - InstCombine.MOp.sig, InstCombine.MOp.outTy, Function.comp_apply, List.map, Signature.mk.injEq, - true_and] - cases op <;> simp only [List.map, and_self, List.cons.injEq] - - -open InstCombine (Op Ty) in -def mkComInstantiate (reg : Region φ) : - ExceptM Op (Vector Nat φ → Σ (Γ : Ctxt InstCombine.Ty) (ty : InstCombine.Ty), Com InstCombine.Op Γ ty) := do - let ⟨Γ, ty, com⟩ ← mkCom reg - return fun vals => - ⟨Γ.instantiate vals, ty.instantiate vals, com.map (MOp.instantiateCom vals)⟩ --/ end MLIR.AST From faf473c2440f03ddcc12b742b592999ca75efc0d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 8 Dec 2023 10:40:52 +0530 Subject: [PATCH 21/28] chore: more cleanup --- SSA/Core/Framework.lean | 2 +- .../InstCombine/AliveAutoGenerated.lean | 54 ------------------- SSA/Projects/InstCombine/Base.lean | 1 + SSA/Projects/MLIRSyntax/AST.lean | 1 - 4 files changed, 2 insertions(+), 56 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index b0feddd4d..88a83c394 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1658,7 +1658,7 @@ def p1_run : Com ExOp [.nat] .nat := RegionExamples.ExOp.runK 0[[%0]] return %1 -/ -#eval p1_run +-- #eval p1_run /-- running `f(x) = x + x` 1 times does return `x + x`. -/ def ex2_lhs : Com ExOp [.nat] .nat := diff --git a/SSA/Projects/InstCombine/AliveAutoGenerated.lean b/SSA/Projects/InstCombine/AliveAutoGenerated.lean index e8a7ac635..1435b47ba 100644 --- a/SSA/Projects/InstCombine/AliveAutoGenerated.lean +++ b/SSA/Projects/InstCombine/AliveAutoGenerated.lean @@ -97,62 +97,8 @@ def alive_AddSub_1043_tgt (w : Nat) := theorem alive_AddSub_1043 (w : Nat) : alive_AddSub_1043_src w ⊑ alive_AddSub_1043_tgt w := by unfold alive_AddSub_1043_src alive_AddSub_1043_tgt simp_alive_peephole - -- simp[pairBind, Var.last, bind, Option.bind, Var.toSnoc, Var.last] sorry -#exit - - sorry - /- - simp[OpDenote.denote, - InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, - bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, - Ctxt.snoc, ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, - LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, - LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, - LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, - LLVM.const?, LLVM.icmp?, - DerivedCtxt.ofCtxt, InstcombineTransformDialect.MOp.instantiateCom, InstcombineTransformDialect.instantiateMTy, - List.map, - HVector.toTuple, List.nthLe, bitvec_minus_one, - DialectMorphism.mapTy, - InstcombineTransformDialect.instantiateMTy, - InstcombineTransformDialect.instantiateMOp, - InstcombineTransformDialect.MOp.instantiateCom, - InstcombineTransformDialect.instantiateCtxt, - ConcreteOrMVar.instantiate, Com.Refinement, - InstCombine.MOp.add, - InstCombine.MOp.const, - InstCombine.MOp.xor, - InstCombine.MOp.and] at Γv - try simp [OpDenote.denote, - InstCombine.Op.denote, HVector.toPair, HVector.toTriple, pairMapM, BitVec.Refinement, - bind, Option.bind, pure, DerivedCtxt.ofCtxt, DerivedCtxt.snoc, - Ctxt.snoc, - ConcreteOrMVar.instantiate, Vector.get, HVector.toSingle, - LLVM.and?, LLVM.or?, LLVM.xor?, LLVM.add?, LLVM.sub?, - LLVM.mul?, LLVM.udiv?, LLVM.sdiv?, LLVM.urem?, LLVM.srem?, - LLVM.sshr, LLVM.lshr?, LLVM.ashr?, LLVM.shl?, LLVM.select?, - LLVM.const?, LLVM.icmp?, - HVector.toTuple, List.nthLe, bitvec_minus_one, - DialectMorphism.mapTy, - InstcombineTransformDialect.instantiateMTy, - InstcombineTransformDialect.instantiateMOp, - InstcombineTransformDialect.MOp.instantiateCom, - InstcombineTransformDialect.instantiateCtxt, - ConcreteOrMVar.instantiate, Com.Refinement, - Com.denote, Expr.denote, - Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, - Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, - Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, - HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk] - -- simp only [Var.last, Var.toSnoc] - -- apply bitvec_AddSub_1043 - sorry - -/ - -#exit - /-# Early Exit delete this to check the rest of the file. The `#exit` is an early exit to allow our `lake` builds to complete in sensible amounts of time. diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index 9467c7186..d910f28d1 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -32,6 +32,7 @@ abbrev Width φ := ConcreteOrMVar Nat φ inductive MTy (φ : Nat) | bitvec (w : Width φ) : MTy φ deriving DecidableEq, Inhabited + abbrev Ty := MTy 0 instance : Repr (MTy φ) where diff --git a/SSA/Projects/MLIRSyntax/AST.lean b/SSA/Projects/MLIRSyntax/AST.lean index a5ef7660a..31b5894bf 100644 --- a/SSA/Projects/MLIRSyntax/AST.lean +++ b/SSA/Projects/MLIRSyntax/AST.lean @@ -1,5 +1,4 @@ import SSA.Core.Util.ConcreteOrMVar --- import SSA.Core.Framework open Lean PrettyPrinter namespace MLIR.AST From 13ebcf5a02c5de620e22eac34b03bdb1cc9c4b1d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 8 Dec 2023 10:48:20 +0530 Subject: [PATCH 22/28] chore: delete noisy lines from Framework --- SSA/Core/Framework.lean | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 88a83c394..45820fb80 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1198,8 +1198,7 @@ theorem denote_rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) section SimpPeephole -#check denote_rewritePeepholeAt -#print axioms denote_rewritePeepholeAt -- [propext, Classical.choice, Quot.sound] +-- #print axioms denote_rewritePeepholeAt -- [propext, Classical.choice, Quot.sound] /- repeatedly apply peephole on program. -/ section SimpPeepholeApplier From 8b3a28f3ee5a9c54fc81ee3bcedcfd5250c56002 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 8 Dec 2023 10:51:13 +0530 Subject: [PATCH 23/28] chore: more nuking dead code in framework --- SSA/Core/Framework.lean | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 45820fb80..09ddd824b 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -517,7 +517,6 @@ structure DialectMorphism (Op Op' : Type) {Ty Ty' : Type} [OpSignature Op Ty] [O mapTy : Ty → Ty' preserves_signature : ∀ op, signature (mapOp op) = mapTy <$> (signature op) - variable {Op Op' Ty Ty : Type} [OpSignature Op Ty] [OpSignature Op' Ty'] (f : DialectMorphism Op Op') @@ -1233,7 +1232,6 @@ theorem denote_rewritePeephole (fuel : ℕ) simp[rewritePeephole, denote_rewritePeephole_go] end SimpPeepholeApplier -#check DialectMorphism /-- `simp_peephole [t1, t2, ... tn]` at Γ simplifies the evaluation of the context Γ, leaving behind a bare Lean level proposition to be proven. @@ -1241,11 +1239,6 @@ leaving behind a bare Lean level proposition to be proven. macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : tactic => `(tactic| ( - -- try simp (config := {decide := false}) only [ - -- Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, - -- Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, - -- DialectMorphism.mapOp, DialectMorphism.mapTy, List.map - -- ] -- separate `simp` block so it does not fail if MLIR.AST is not open. try simp (config := {decide := false}) only [ Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, @@ -1259,10 +1252,6 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : Ctxt.snoc, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, List.map, DialectMorphism.mapOp, DialectMorphism.mapTy, $ts,*] - -- try simp (config := {decide := false}) only [ - -- Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, - -- Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last - -- ] -- separate `simp` block so it does not fail if MLIR.AST is not open. generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; generalize $ll { val := 2, property := _ } = c; @@ -1286,7 +1275,6 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : ) ) -open MLIR AST in /-- `simp_peephole` with no extra user defined theorems. -/ macro "simp_peephole" "at" ll:ident : tactic => `(tactic| simp_peephole [] at $ll) From 06fe56d76c668a5a09d6977e3bfd6efc32085cae Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 8 Dec 2023 10:51:37 +0530 Subject: [PATCH 24/28] remove dead code from PaperExamples --- SSA/Projects/PaperExamples/PaperExamples.lean | 2 -- 1 file changed, 2 deletions(-) diff --git a/SSA/Projects/PaperExamples/PaperExamples.lean b/SSA/Projects/PaperExamples/PaperExamples.lean index 75db08391..a2ea8841d 100644 --- a/SSA/Projects/PaperExamples/PaperExamples.lean +++ b/SSA/Projects/PaperExamples/PaperExamples.lean @@ -96,8 +96,6 @@ def mkReturn (Γ : Ctxt Ty) (opStx : MLIR.AST.Op 0) : MLIR.AST.ReaderM Op (Σ ty instance : MLIR.AST.TransformReturn Op Ty 0 where mkReturn := mkReturn --- open InstCombine (Op Ty) in - def mlir2simple (reg : MLIR.AST.Region 0) : MLIR.AST.ExceptM Op (Σ (Γ : Ctxt Ty) (ty : Ty), Com Op Γ ty) := MLIR.AST.mkCom reg From 0ace0c1f94339274ef83e65e74bfea68868ce6b7 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Mon, 8 Jan 2024 17:42:48 +0000 Subject: [PATCH 25/28] Fix formatting messed up by merge --- SSA/Core/Framework.lean | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index cb6593ceb..9e0d0435f 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1204,8 +1204,8 @@ section SimpPeephole section SimpPeepholeApplier /-- rewrite with `pr` to `target` program, at location `ix` and later, running at most `fuel` steps. -/ -def rewritePeephole_go (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (ix : ℕ) (target : Com Op Γ₂ t₂) : - (Com Op Γ₂ t₂) := +def rewritePeephole_go (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) + (ix : ℕ) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := match fuel with | 0 => target | fuel' + 1 => @@ -1213,23 +1213,25 @@ def rewritePeephole_go (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (ix : ℕ) (t rewritePeephole_go fuel' pr (ix + 1) target' /-- rewrite with `pr` to `target` program, running at most `fuel` steps. -/ -def rewritePeephole (fuel : ℕ) (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : - (Com Op Γ₂ t₂) := rewritePeephole_go fuel pr 0 target +def rewritePeephole (fuel : ℕ) + (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : (Com Op Γ₂ t₂) := + rewritePeephole_go fuel pr 0 target /-- `rewritePeephole_go` preserve semantics -/ theorem denote_rewritePeephole_go (pr : PeepholeRewrite Op Γ t) (pos : ℕ) (target : Com Op Γ₂ t₂) : (rewritePeephole_go fuel pr pos target).denote = target.denote := by - induction fuel generalizing pr pos target - case zero => simp[rewritePeephole_go, denote_rewritePeepholeAt] - case succ fuel' hfuel => - simp[rewritePeephole_go, denote_rewritePeepholeAt, hfuel] + induction fuel generalizing pr pos target + case zero => + simp[rewritePeephole_go, denote_rewritePeepholeAt] + case succ fuel' hfuel => + simp[rewritePeephole_go, denote_rewritePeepholeAt, hfuel] /-- `rewritePeephole` preserves semantics. -/ theorem denote_rewritePeephole (fuel : ℕ) - (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : + (pr : PeepholeRewrite Op Γ t) (target : Com Op Γ₂ t₂) : (rewritePeephole fuel pr target).denote = target.denote := by - simp[rewritePeephole, denote_rewritePeephole_go] + simp[rewritePeephole, denote_rewritePeephole_go] end SimpPeepholeApplier /-- From f36f8caf25602a49a699ab299f88d104cc8a0a27 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Mon, 8 Jan 2024 18:00:44 +0000 Subject: [PATCH 26/28] Cleanup dead code --- SSA/Core/Framework.lean | 2 -- 1 file changed, 2 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 9e0d0435f..43931c4ee 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1198,8 +1198,6 @@ theorem denote_rewritePeepholeAt (pr : PeepholeRewrite Op Γ t) section SimpPeephole --- #print axioms denote_rewritePeepholeAt -- [propext, Classical.choice, Quot.sound] - /- repeatedly apply peephole on program. -/ section SimpPeepholeApplier From 70bd6e3ce70e31ab6533cd425ca6fe4a50aabc88 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Tue, 9 Jan 2024 16:47:28 +0000 Subject: [PATCH 27/28] Deduplicate simp-list --- SSA/Core/Framework.lean | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/SSA/Core/Framework.lean b/SSA/Core/Framework.lean index 43931c4ee..21155b490 100644 --- a/SSA/Core/Framework.lean +++ b/SSA/Core/Framework.lean @@ -1242,15 +1242,11 @@ macro "simp_peephole" "[" ts: Lean.Parser.Tactic.simpLemma,* "]" "at" ll:ident : try simp (config := {decide := false}) only [ Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, - DialectMorphism.mapOp, DialectMorphism.mapTy, List.map, Com.denote, Expr.denote, HVector.denote, Var.zero_eq_last, Var.succ_eq_toSnoc, Ctxt.empty, Ctxt.empty_eq, Ctxt.snoc, Ctxt.Valuation.nil, Ctxt.Valuation.snoc_last, Ctxt.ofList, Ctxt.Valuation.snoc_toSnoc, HVector.map, HVector.toPair, HVector.toTuple, OpDenote.denote, Expr.op_mk, Expr.args_mk, - DialectMorphism.mapOp, DialectMorphism.mapTy, List.map, - Int.ofNat_eq_coe, Nat.cast_zero, Ctxt.DerivedCtxt.snoc, Ctxt.DerivedCtxt.ofCtxt, - Ctxt.snoc, Ctxt.DerivedCtxt.ofCtxt_empty, Ctxt.Valuation.snoc_last, List.map, - DialectMorphism.mapOp, DialectMorphism.mapTy, + DialectMorphism.mapOp, DialectMorphism.mapTy, List.map, Ctxt.snoc, List.map, $ts,*] generalize $ll { val := 0, property := _ } = a; generalize $ll { val := 1, property := _ } = b; From 94ca67d72d97608a270da158b98597c50874c1c2 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Tue, 9 Jan 2024 16:57:59 +0000 Subject: [PATCH 28/28] Fix indentation --- SSA/Projects/InstCombine/LLVM/EDSL.lean | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 70cf4990c..b2dc08875 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -15,10 +15,10 @@ namespace InstcombineTransformDialect def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) (e : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly | .neg w' => if h : w = w' then return ⟨ .neg w', @@ -47,10 +47,10 @@ def mkUnaryOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) def mkBinOp {Γ : Ctxt (MTy φ)} {ty : (MTy φ)} (op : MOp φ) (e₁ e₂ : Ctxt.Var Γ ty) : MLIR.AST.ExceptM (MOp φ) <| Expr (MOp φ) Γ ty := - match ty with - | .bitvec w => - match op with - -- Can't use a single arm, Lean won't write the rhs accordingly + match ty with + | .bitvec w => + match op with + -- Can't use a single arm, Lean won't write the rhs accordingly | .add w' => if h : w = w' then return ⟨ .add w',