Skip to content
This repository has been archived by the owner on Jun 17, 2024. It is now read-only.

Rewriter #29

Draft
wants to merge 23 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions MLIR/Semantics/Dominance.lean
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def singleBBRegionOpObeySSA (op: Op δ) (ctx: DomContext δ) : Option (DomContex
| Op.mk _ operands [] regions _ (MLIRType.fn (MLIRType.tuple operandsTy) _) => do
let b := operandsDefinitionObeySSA operands operandsTy ctx
match b with
| true => (singleBBRegionRegionsObeySSA regions ctx)
| true => do
let _ ← singleBBRegionRegionsObeySSA regions ctx
ctx
| false => none
| _ => none

Expand All @@ -72,12 +74,12 @@ def singleBBRegionRegionsObeySSA (regions: List (Region δ)) (ctx: DomContext δ
| region::regions' => do
let _ <- (singleBBRegionRegionObeySSA region ctx)
(singleBBRegionRegionsObeySSA regions' ctx)
| [] => none
| [] => some ctx

def singleBBRegionRegionObeySSA (region: Region δ) (ctx: DomContext δ) : Option (DomContext δ) :=
match region with
| .mk [] => ctx
| .mk [bb] => (singleBBRegionBBObeySSA bb ctx)
| .mk [bb] => singleBBRegionBBObeySSA bb ctx
| _ => Option.none

def singleBBRegionBBObeySSA (bb: BasicBlock δ) (ctx: DomContext δ) : Option (DomContext δ) :=
Expand All @@ -93,13 +95,16 @@ def singleBBRegionStmtsObeySSA (stmts: List (BasicBlockStmt δ)) (ctx: DomContex

def singleBBRegionStmtObeySSA (stmt: BasicBlockStmt δ) (ctx: DomContext δ) : Option (DomContext δ) :=
match stmt with
| .StmtOp op => singleBBRegionOpObeySSA op ctx
| .StmtOp op => do
_ ← singleBBRegionOpObeySSA op ctx
ctx
| .StmtAssign res none op => do
-- TODO: replace it with an `as`, when I'll know how to do it
let ctx' <- match op with
| Op.mk _ _ _ _ _ (MLIRType.fn _ (MLIRType.tuple [τ])) => (valDefinitionObeySSA res τ ctx)
| _ => none
singleBBRegionOpObeySSA op ctx
let _ ← singleBBRegionOpObeySSA op ctx
ctx'
| _ => none
end
termination_by
Expand Down
171 changes: 168 additions & 3 deletions MLIR/Semantics/Matching.lean
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,175 @@ private def multiple_example: Op builtin := [mlir_op|

-- Match an MTerm program in some IR, then concretize
-- the MTerm using the resulting matching context.
def multiple_example_result : Option (List (BasicBlockStmt builtin)) := do
let (val, ctx) ←
private def multiple_example_result : Option (List (BasicBlockStmt builtin)) := do
let (_, ctx) ←
matchMProgInOp multiple_example test_addi_multiple_pattern []
let res ← MTerm.concretizeProg test_addi_multiple_pattern ctx
val
res

#eval multiple_example_result

/-
### Exact program matching

This section defines functions to check if an operation, or SSA values
definitions/uses are inside a bigger program.
-/

mutual
variable (mOp: BasicBlockStmt δ)

def isOpInOp (op: Op δ) : Bool :=
match op with
| .mk _ _ _ regions _ _ => isOpInRegions regions

def isOpInRegions (regions: List (Region δ)) : Bool :=
match regions with
| [] => False
| region::regions' => isOpInRegion region || isOpInRegions regions'

def isOpInRegion (region: Region δ) : Bool :=
match region with
| .mk bbs => isOpInBBs bbs

def isOpInBBs (bbs: List (BasicBlock δ)) : Bool :=
match bbs with
| [] => False
| bb::bbs' => isOpInBB bb || isOpInBBs bbs'

def isOpInBB (bb: BasicBlock δ) : Bool :=
match bb with
| .mk _ _ stmts => isOpInBBStmts stmts

def isOpInBBStmts (stmts: List (BasicBlockStmt δ)) : Bool :=
match stmts with
| [] => False
| stmt::stmts' => isOpInBBStmt stmt || isOpInBBStmts stmts'

def isOpInBBStmt (stmt: BasicBlockStmt δ) : Bool :=
match stmt, mOp with
| .StmtOp op, _ => isOpInOp op
| .StmtAssign res ix (Op.mk name operands [] [] (AttrDict.mk []) typ),
.StmtAssign res' ix' (Op.mk name' operands' [] [] (AttrDict.mk []) typ') =>
res == res' && ix == ix' && name == name' && operands == operands' && typ == typ'
| .StmtAssign _ _ op, _ => isOpInOp op
end


mutual
variable (mVar: SSAVal)

def isSSADefInOp (op: Op δ) : Bool :=
match op with
| .mk _ _ _ regions _ _ => isSSADefInRegions regions

def isSSADefInRegions (regions: List (Region δ)) : Bool :=
match regions with
| [] => False
| region::regions' => isSSADefInRegion region || isSSADefInRegions regions'

def isSSADefInRegion (region: Region δ) : Bool :=
match region with
| .mk bbs => isSSADefInBBs bbs

def isSSADefInBBs (bbs: List (BasicBlock δ)) : Bool :=
match bbs with
| [] => False
| bb::bbs' => isSSADefInBB bb || isSSADefInBBs bbs'

def isSSADefInBB (bb: BasicBlock δ) : Bool :=
match bb with
| .mk _ _ stmts => isSSADefInBBStmts stmts

def isSSADefInBBStmts (stmts: List (BasicBlockStmt δ)) : Bool :=
match stmts with
| [] => False
| stmt::stmts' => isSSADefInBBStmt stmt || isSSADefInBBStmts stmts'

def isSSADefInBBStmt (stmt: BasicBlockStmt δ) : Bool :=
match stmt with
| .StmtOp op => isSSADefInOp op
| .StmtAssign res _ op => res == mVar || isSSADefInOp op
end

mutual
variable (mVar: SSAVal)

def isSSAUseInOp (op: Op δ) : Bool :=
match op with
| .mk _ args _ regions _ _ =>
args.contains mVar || isSSAUseInRegions regions

def isSSAUseInRegions (regions: List (Region δ)) : Bool :=
match regions with
| [] => False
| region::regions' => isSSAUseInRegion region || isSSAUseInRegions regions'

def isSSAUseInRegion (region: Region δ) : Bool :=
match region with
| .mk bbs => isSSAUseInBBs bbs

def isSSAUseInBBs (bbs: List (BasicBlock δ)) : Bool :=
match bbs with
| [] => False
| bb::bbs' => isSSAUseInBB bb || isSSAUseInBBs bbs'

def isSSAUseInBB (bb: BasicBlock δ) : Bool :=
match bb with
| .mk _ _ stmts => isSSAUseInBBStmts stmts

def isSSAUseInBBStmts (stmts: List (BasicBlockStmt δ)) : Bool :=
match stmts with
| [] => False
| stmt::stmts' => isSSAUseInBBStmt stmt || isSSAUseInBBStmts stmts'

def isSSAUseInBBStmt (stmt: BasicBlockStmt δ) : Bool :=
match stmt with
| .StmtOp op => isSSAUseInOp op
| .StmtAssign res _ op => isSSAUseInOp op
end

mutual
variable (mVar: SSAVal)

def getDefiningOpInOp (op: Op δ) : Option (BasicBlockStmt δ) :=
match op with
| .mk _ _ _ regions _ _ => getDefiningOpInRegions regions

def getDefiningOpInRegions (regions: List (Region δ)) : Option (BasicBlockStmt δ) :=
match regions with
| [] => none
| region::regions' =>
match getDefiningOpInRegion region with
| some op => some op
| none => getDefiningOpInRegions regions'

def getDefiningOpInRegion (region: Region δ) : Option (BasicBlockStmt δ) :=
match region with
| .mk bbs => getDefiningOpInBBs bbs

def getDefiningOpInBBs (bbs: List (BasicBlock δ)) : Option (BasicBlockStmt δ) :=
match bbs with
| [] => none
| bb::bbs' =>
match getDefiningOpInBB bb with
| some op => some op
| none => getDefiningOpInBBs bbs'

def getDefiningOpInBB (bb: BasicBlock δ) : Option (BasicBlockStmt δ) :=
match bb with
| .mk _ _ stmts => getDefiningOpInBBStmts stmts

def getDefiningOpInBBStmts (stmts: List (BasicBlockStmt δ)) : Option (BasicBlockStmt δ) :=
match stmts with
| [] => none
| stmt::stmts' =>
match getDefiningOpInBBStmt stmt with
| some op => some op
| none => getDefiningOpInBBStmts stmts'

def getDefiningOpInBBStmt (stmt: BasicBlockStmt δ) : Option (BasicBlockStmt δ) :=
match stmt with
| .StmtOp op => getDefiningOpInOp op
| .StmtAssign res _ op => if res == mVar then some stmt else getDefiningOpInOp op
end
37 changes: 37 additions & 0 deletions MLIR/Semantics/Refinement.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/-
## Refinement

This file defines the definition of a refinement.
-/

import MLIR.Semantics.Semantics
import MLIR.Dialects.ArithSemantics
import MLIR.Semantics.UB
open MLIR.AST

def SSAEnv.refinement (env1 env2: SSAEnv δ) : Prop :=
∀ v val, env2.getT v = some val -> env1.getT v = some val

-- One SSA environment is a refinement of the other if all variables
-- defined in the environment are also defined by the other first one.
-- If the SSA environment is none (corresponding to an UB), then we have
-- a refinement for all other possible SSA environment, or UB
def refinement (env1 env2: Option R × SSAEnv δ) : Prop :=
match env1, env2 with
| (none, _), _ => True
| (some r1, env1), (some r2, env2) =>
r1 = r2 ∧ SSAEnv.refinement env1 env2
| _, _ => False

theorem SSAEnv.refinement_set :
refinement env1 env2 ->
refinement (SSAEnv.set name τ val env1) (SSAEnv.set name τ val env2) := by
simp [refinement, SSAEnv.refinement]
intros Href name' val' Hget
byCases H: name = name'
. rw [SSAEnv.getT_set_eq] <;> try assumption
rw [SSAEnv.getT_set_eq] at Hget <;> assumption
. rw [SSAEnv.getT_set_ne] <;> try assumption
rw [SSAEnv.getT_set_ne] at Hget <;> try assumption
apply Href
assumption
Loading