From b52c0673dd0bb572847118472318f03c014fd2ef Mon Sep 17 00:00:00 2001 From: ineol Date: Mon, 14 Oct 2024 18:17:51 +0100 Subject: [PATCH] Use automatic structures for bv_automata (#685) --- SSA/Experimental/Bits/AutoStructs/Basic.lean | 446 ++++++++++ SSA/Experimental/Bits/AutoStructs/Defs.lean | 283 +++++++ .../Bits/AutoStructs/FinEnum.lean | 41 + .../Bits/AutoStructs/FiniteStateMachine.lean | 794 ++++++++++++++++++ .../Bits/AutoStructs/ForLean.lean | 9 + .../Bits/AutoStructs/FormulaToAuto.lean | 324 +++++++ SSA/Experimental/Bits/AutoStructs/Tactic.lean | 311 +++++++ SSA/Experimental/Bits/Fast/Tactic.lean | 35 +- SSA/Experimental/Bits/SafeNativeDecide.lean | 38 + SSA/Projects/InstCombine/TacticAuto.lean | 8 + 10 files changed, 2255 insertions(+), 34 deletions(-) create mode 100644 SSA/Experimental/Bits/AutoStructs/Basic.lean create mode 100644 SSA/Experimental/Bits/AutoStructs/Defs.lean create mode 100644 SSA/Experimental/Bits/AutoStructs/FinEnum.lean create mode 100644 SSA/Experimental/Bits/AutoStructs/FiniteStateMachine.lean create mode 100644 SSA/Experimental/Bits/AutoStructs/ForLean.lean create mode 100644 SSA/Experimental/Bits/AutoStructs/FormulaToAuto.lean create mode 100644 SSA/Experimental/Bits/AutoStructs/Tactic.lean create mode 100644 SSA/Experimental/Bits/SafeNativeDecide.lean diff --git a/SSA/Experimental/Bits/AutoStructs/Basic.lean b/SSA/Experimental/Bits/AutoStructs/Basic.lean new file mode 100644 index 000000000..38f730020 --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/Basic.lean @@ -0,0 +1,446 @@ +/- +Released under Apache 2.0 license as described in the file LICENSE. +-/ + +import Std.Data.HashSet +import Std.Data.HashMap +import Std.Data.HashMap.Lemmas +import Mathlib.Data.FinEnum +import Mathlib.Data.Finset.Basic +import Mathlib.Data.Finset.Card +import Mathlib.Data.List.Infix +import SSA.Experimental.Bits.AutoStructs.ForLean +import SSA.Experimental.Bits.AutoStructs.FinEnum + +abbrev State := Nat + +-- where to add the wellformedness conditions? a typeclass? +structure NFA (A : Type 0) [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] where + stateMax : State + initials : Std.HashSet State + finals : Std.HashSet State + trans : Std.HashMap (State × A) (Std.HashSet State) +deriving Repr + +section basics + +variable {A : Type} [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] + +def NFA.empty : NFA A := { + stateMax := 0 + initials := ∅ + finals := ∅ + trans := ∅ +} + +def NFA.newState (m : NFA A) : State × NFA A := + let old := m.stateMax + let m := { m with stateMax := old + 1 } + (old, m) + +def NFA.addTrans (m : NFA A) (a : A) (s s' : State) : NFA A := + let ns := m.trans.getD (s, a) ∅ + let ns := ns.insert s' + { m with trans := m.trans.insert (s, a) ns } + +def NFA.addManyTrans (m : NFA A) (a : List A) (s s' : State) : NFA A := + a.foldl (init := m) fun m a => m.addTrans a s s' + +def NFA.addInitial (m : NFA A) (s : State) : NFA A := + { m with initials := m.initials.insert s } + +def NFA.addFinal (m : NFA A) (s : State) : NFA A := + { m with finals := m.finals.insert s } + +def NFA.transSet (m : NFA A) (ss : Std.HashSet State) (a : A) : Std.HashSet State := + ss.fold (init := ∅) fun ss' s => + ss'.insertMany $ m.trans.getD (s, a) ∅ + +instance NFA_Inhabited : Inhabited (NFA A) where + default := NFA.empty + +end basics + +-- A generic function to define automata constructions using worklist algorithms +section worklist + +/- Question: +why does `Decidable (s ∈ l)` require `LawfulBEq` if `l` is a list and `DecidableEq` if `l` is an array? +-/ + +variable (A : Type) [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] +variable {S : Type} [BEq S] [LawfulBEq S] [Hashable S] [DecidableEq S] +variable (stateSpace : Finset S) + +abbrev inSS : Type := { sa : S // sa ∈ stateSpace } + +private structure worklist.St where + m : NFA A + map : Std.HashMap (inSS stateSpace) State := ∅ + worklist : Array { sa : inSS stateSpace // sa ∈ map.keys} := ∅ + worklist_nodup : worklist.toList.Nodup + +private def worklist.St.meas (st : worklist.St A stateSpace) : ℕ := + Finset.card $ stateSpace.attach.filter fun x => x ∉ st.map.keys ∨ x ∈ st.worklist.toList.map fun x => x.val + +private def worklist.St.addOrCreateState (st : worklist.St A stateSpace) (final? : Bool) (sa : S) : State × worklist.St A stateSpace := sorry + +theorem List.dropLast_nodup (l : List X) : l.Nodup → l.dropLast.Nodup := by + have hsl := List.dropLast_sublist l + apply List.Nodup.sublist; trivial + +theorem List.dropLasnodup (l : List X) : l.Nodup → l.dropLast.Nodup := by + have hsl := List.dropLast_sublist l + apply List.Nodup.sublist; trivial + +@[simp] +theorem Array.not_elem_back_pop (a : Array X) (x : X) : a.toList.Nodup → a.back? = some x → x ∉ a.pop := by sorry + +@[simp] +theorem Array.not_elem_back_pop_list (a : Array X) (x : X) : a.toList.Nodup → a.back? = some x → x ∉ a.toList.dropLast := by sorry + +@[simp] +theorem Array.back_mem (a : Array X) (x : X) : a.back? = some x → x ∈ a := by sorry + +def worklist.initState (init : inSS stateSpace) : worklist.St A stateSpace := + let m := NFA.empty + let (s, m) := m.newState + let m := m.addInitial s + let map : Std.HashMap _ _ := {(init, s)} + let init' : { sa : inSS stateSpace // sa ∈ map.keys } := ⟨init, by sorry /- TODO: should be trivial -/⟩ + let worklist := Array.singleton init' + { m, map, worklist, worklist_nodup := by simp [worklist] } + +def worklistRunAux (final : S → Bool) (f : S → Array (A × {sa : S | sa ∈ stateSpace })) (init : inSS stateSpace) : NFA A := + let st0 := worklist.initState _ _ init + go st0 +where go (st0 : worklist.St A stateSpace) : NFA A := + if hemp : st0.worklist.isEmpty then st0.m else + let sa? := st0.worklist.back? + have h1 : match sa? with + | none => True + | some _ => { st0 with worklist := st0.worklist.pop, worklist_nodup := by apply List.dropLast_nodup; exact st0.worklist_nodup }.meas < st0.meas := by + have hrem : sa? = st0.worklist.back? := by simp + rcases sa? with ⟨⟩ | ⟨⟨sa, hsa1⟩, hsa2⟩ <;> simp + apply Finset.card_lt_card + simp [worklist.St.meas, Finset.ssubset_iff, Finset.subset_iff] + use sa + use hsa1 + constructor + { constructor; assumption; symm at hrem; apply Array.not_elem_back_pop_list at hrem; intros hc; simp at hrem + rw [← List.map_dropLast] at hc + apply List.exists_of_mem_map at hc + rcases hc with ⟨⟨⟨sa', hsa''⟩, hsa'⟩, hc, heq⟩ + simp at heq + subst heq + aesop + exact st0.worklist_nodup } + constructor + { right; use hsa2; symm at hrem; apply Array.back_mem at hrem; apply Array.Mem.val; trivial } + rintro sa' hsa' hh; rcases hh with hnin | hin + { left; trivial } + right + rw [← List.map_dropLast] at hin + apply List.exists_of_mem_map at hin + rcases hin with ⟨⟨⟨sa', hsa''⟩, hsa'⟩, hc, heq⟩ + simp at heq + rcases heq + use hsa' + apply List.mem_of_mem_dropLast + trivial + + match sa? with + | some sa => + let wl := st0.worklist.pop + let st1 := { st0 with worklist := wl, worklist_nodup := by simp [wl]; apply List.dropLast_nodup; exact st0.worklist_nodup } + if let some s := st1.map.get? sa then + let a := f sa + let st2 := a.foldl (init := st1) fun st (a', sa') => + let (s', st') := st.addOrCreateState _ _ (final sa') sa' + have : st'.meas ≤ st.meas := by sorry -- should follow from the aux function's properties + let m := st'.m.addTrans a' s s' + { st' with m } + have hincl : st1.map.keys ⊆ st2.map.keys := by sorry + have : st1.meas < st0.meas := by simp at h1; simp [st1]; assumption + have : st2.meas ≤ st1.meas := by + apply Finset.card_le_card + simp [worklist.St.meas, Finset.subset_iff] + intros sa' hsa' h -- we need to know that map.keys is ever growing too + rcases h with hnin | ⟨hsa'', hin⟩ + { left; simp [st1] at hincl; intros hc; apply hnin; apply hincl; assumption } + by_cases hnew : ⟨sa', hsa'⟩ ∈ st0.map.keys + all_goals try (left; trivial) + right; use hnew + sorry + have : st2.meas < st0.meas := by omega + go st2 + else + st0.m -- never happens + | none => st0.m -- never happens + termination_by st0.meas + +end worklist + +section sink + +variable {A : Type} [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] + +def NFA.addSink (m : NFA A) : NFA A := + let (sink, m) := m.newState + -- let m := m.addInitial sink -- TODO(leo) is that right? + (List.range m.stateMax).foldl (init := m) fun m s => + (FinEnum.toList (α := A)).foldl (init := m) fun m a => + let stuck := if let some trans := m.trans.get? (s, a) then trans.isEmpty else true + if stuck then m.addTrans a s sink else m + +def NFA.flipFinals (m : NFA A) : NFA A := + let oldFinals := m.finals + let newFinals := (List.range m.stateMax).foldl (init := ∅) fun fins s => + if oldFinals.contains s then fins else fins.insert s + { m with finals := newFinals } + +end sink + +section product + +variable {A : Type} [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] + +private structure product.State where + m : NFA A + map : Std.HashMap (State × State) State := ∅ + worklist : Array (State × State) := ∅ + +private def product.State.measure (st : product.State (A := A)) (m1 m2 : NFA A) := + ((Multiset.range m1.stateMax).product (Multiset.range m2.stateMax)).sub + ((Multiset.ofList st.map.keys).sub (Multiset.ofList st.worklist.toList)) + +def product (final? : Bool → Bool → Bool) (m1 m2 : NFA A) : NFA A := + let map : Std.HashMap (State × State) State := ∅ + let worklist : Array (State × State) := ∅ + let st : product.State (A := A) := { m := NFA.empty } + let st := init st + go st +where init (st : product.State (A := A)) : product.State (A := A) := + m1.initials.fold (init := st) fun st s1 => + m2.initials.fold (init := st) fun st s2 => + let (s, m) := st.m.newState + let m := m.addInitial s + let map := st.map.insert (s1, s2) s + let worklist := st.worklist.push (s1, s2) + {m, map, worklist} + + go (st0 : product.State (A := A)) : NFA A := Id.run do + if hne : st0.worklist.size == 0 then + return st0.m + else + let some (s1, s2) := st0.worklist.get? (st0.worklist.size - 1) | return st0.m + let st := { st0 with worklist := st0.worklist.pop } + let some s := st.map.get? (s1, s2) | return NFA.empty + let st := if final? (m1.finals.contains s1) (m2.finals.contains s2) then + { st with m := st.m.addFinal s } + else + st + let st := (FinEnum.toList (α := A)).foldl (init := st) fun st a => + if let some s1trans := m1.trans.get? (s1, a) then + if let some s2trans := m2.trans.get? (s2, a) then + s1trans.fold (init := st) fun st s1' => + s2trans.fold (init := st) fun st s2' => + if let some s' := st.map.get? (s1', s2') then + let m := st.m.addTrans a s s' + { st with m } + else + let (s', m) := st.m.newState + let worklist := st.worklist.push (s1', s2') + let map := st.map.insert (s1', s2') s' + let m := m.addTrans a s s' + { m, map, worklist } + else + st + else + st + go st + termination_by st0.measure m1 m2 + -- for termination, we need to know that worklist ⊆ map.keys ⊆ [0..max1] × [0..max2] + -- or something like this + decreasing_by { + sorry + -- apply List.foldl_lt (fun st : product.State => sizeOf (st.measure m1 m2)) + -- · simp + -- cases (final? (m1.finals.contains s1) (m2.finals.contains s2)) + -- { simp [product.State.measure, sizeOf] + -- apply Multiset.sizeOf_subset + -- { sorry } + -- { intros contra + -- congr contra + + -- } + -- } + -- { sorry } + -- · sorry + } + + + +def NFA.inter (m1 m2 : NFA A) : NFA A := product (fun b1 b2 => b1 && b2) m1 m2 +def NFA.union (m1 m2 : NFA A) : NFA A := + -- FIXME add a sink state to each automata, or modify product + product (fun b1 b2 => b1 || b2) m1.addSink m2.addSink + +end product + +def HashSet.inter [BEq A] [Hashable A] (m1 m2 : Std.HashSet A) : Std.HashSet A := + m1.fold (init := Std.HashSet.empty) fun mi x => if m2.contains x then mi.insert x else mi + +def Std.HashSet.isDisjoint [BEq A] [Hashable A] (m1 m2 : Std.HashSet A) : Bool := + (HashSet.inter m1 m2).isEmpty + +def HashSet.areIncluded [BEq A] [Hashable A] (m1 m2 : Std.HashSet A) : Bool := + m1.all (fun x => m2.contains x) + +section determinization + +variable {A : Type} [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] + +instance hashableHashSet [Hashable A] : Hashable (Std.HashSet A) where + hash s := s.fold (init := 0) fun h x => mixHash h (hash x) + +private structure NFA.determinize.St where + map : Std.HashMap (Std.HashSet State) State + worklist : List (Std.HashSet State) + m : NFA A + +open NFA.determinize in +def NFA.determinize (mi : NFA A) : NFA A := + let m : NFA A := NFA.empty + let (si, m) := m.newState + let m := m.addInitial si + let map : Std.HashMap _ _ := { (mi.initials, si) } + let st : St := { map, worklist := [mi.initials], m} + go st +where go (st : St) : NFA A := Id.run do + if let some (ss, worklist) := st.worklist.next? then + let st := { st with worklist } + let some s := st.map[ss]? | return NFA.empty + let m := if !ss.isDisjoint mi.finals then st.m.addFinal s else st.m + let st := (FinEnum.toList A).foldl (init := { st with m }) fun st a => + let ss' := mi.transSet ss a + let (s', st) := if let some s' := st.map[ss']? then (s', st) else + let (s', m) := st.m.newState + let map := st.map.insert ss' s' + let worklist := ss' :: st.worklist + (s', { map, worklist, m }) + { st with m := st.m.addTrans a s s'} + go st + else + st.m + decreasing_by sorry + +def NFA.neg (m : NFA A) : NFA A := m.determinize.flipFinals + +end determinization + +section universality + +variable {A : Type} [BEq A] [Hashable A] [DecidableEq A] [FinEnum A] + +private structure isUniversal.State where + visited : List (Std.HashSet State) := ∅ -- TODO: slow + worklist : List (Std.HashSet State) := ∅ + +/-- Returns true when `L(m) = A*` -/ +def NFA.isUniversal (m : NFA A) : Bool := + let st := { visited := [], worklist := [m.initials]} + go st +where go (st : isUniversal.State) : Bool := + if let some (ss, worklist) := st.worklist.next? then + let st := { st with worklist } + if ss.isDisjoint m.finals then + false + else + let st := { st with visited := ss :: st.visited } + let st := (FinEnum.toList (α := A)).foldl (init := st) fun st a => + let ss' := m.transSet ss a + if st.worklist.any (fun ss'' => HashSet.areIncluded ss'' ss') || + st.visited.any (fun ss'' => HashSet.areIncluded ss'' ss') then + st + else + { st with worklist := ss'::st.worklist} + go st + else + true + decreasing_by sorry + +/-- Recognizes the empty word -/ +def NFA.emptyWord : NFA A := + let m := NFA.empty + let (s, m) := m.newState + let m := m.addInitial s + let m := m.addFinal s + m + +/-- Returns true when `L(m) ∪ {ε} = A*`. This is useful because the + bitvector of with width zero has strange properties. + -/ +def NFA.isUniversal' (m : NFA A) : Bool := + m.union NFA.emptyWord |> NFA.isUniversal + +-- TODO: this relies on the fact that all states are reachable! +def NFA.isEmpty (m : NFA A) : Bool := m.finals.isEmpty +def NFA.isNotEmpty (m : NFA A) : Bool := !m.finals.isEmpty + +end universality + +instance: Hashable (BitVec n) where + hash x := Hashable.hash x.toFin + +section lift_proj + +-- Defined as bv'[i] = bv[f i] +def transport (f : Fin n2 → Fin n1) (bv : BitVec n1) : BitVec n2 := + (Fin.list n2).foldl (init := BitVec.zero n2) fun bv' (i : Fin _) => bv' ||| (BitVec.twoPow n2 i * bv[f i].toNat) + +variable {n : Nat} + +-- Morally, n2 >= n1 +def NFA.lift (m1: NFA (BitVec n1)) (f : Fin n1 → Fin n2) : NFA (BitVec n2) := + let m2 : NFA (BitVec n2) := { m1 with trans := Std.HashMap.empty } + (List.range m2.stateMax).foldl (init := m2) fun m2 s => + (FinEnum.toList (BitVec n2)).foldl (init := m2) fun m2 (bv : BitVec n2) => + let newtrans := m1.trans.getD (s, transport f bv) ∅ + let oldtrans := m2.trans.getD (s, bv) ∅ + let trans := newtrans.union oldtrans + if trans.isEmpty then m2 else { m2 with trans := m2.trans.insert (s, bv) trans } + +-- Morally, n1 >= n2 +def NFA.proj (m1: NFA (BitVec n1)) (f : Fin n2 → Fin n1) : NFA (BitVec n2) := + let m2 : NFA (BitVec n2) := { m1 with trans := Std.HashMap.empty } + m1.trans.keys.foldl (init := m2) fun m2 (s, bv) => + let trans := m1.trans.getD (s, bv) ∅ + let bv' := transport f bv + let oldtrans := m2.trans.getD (s, bv') ∅ + { m2 with trans := m2.trans.insert (s, bv') (trans.union oldtrans) } + +end lift_proj + +/- + TODOs: + 1. to detect overflow (eg `AdditionNoOverflows?` in `HackersDelight.lean`), abs, ... + 2. clarify what happens with automata over the alphabet `BitVec 0`... + 3. for abs, do we have to eschew FSMs? + ... + n. Maybe we can deal with some shifts and powers of 2 if they are of the form `k` or `w - k`. + + To deal with the fact that some results only hold for `w > 0`, + we could first try to prove `w > 0` with omega, and then use this + fact automatically? + + Possible improvements for performance: + 0. Implement language equality directly instead of encoding it + 1. change the representation, eg. transitions as + Array (Array (A × State)) + as in LASH, or with BDD as in MONA + 2. Use the interleaving technique to reduce the number of transitions: + instead of having the alphabet `BitVec n`, it's simply `Bool` and the + representation of (a00, a01, .., a0k)(a10, a11, .., a1k)...(an0, an1, .., ank) + is (a00 a10 ... an0 a01 a11 ... an1 ...... ank) +-/ diff --git a/SSA/Experimental/Bits/AutoStructs/Defs.lean b/SSA/Experimental/Bits/AutoStructs/Defs.lean new file mode 100644 index 000000000..13e67e60f --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/Defs.lean @@ -0,0 +1,283 @@ +/- +Released under Apache 2.0 license as described in the file LICENSE. +-/ +import Mathlib.Data.Bool.Basic +import Mathlib.Data.Fin.Basic +import SSA.Projects.InstCombine.ForLean +import SSA.Experimental.Bits.Fast.BitStream + +namespace AutoStructs + +/-! +# Term Language +This file defines the term language the decision procedure operates on, +and the denotation of these terms into operations on bitstreams -/ + +/-- A `Term` is an expression in the language our decision procedure operates on, +it represent an infinite bitstream (with free variables) -/ +inductive Term : Type +| var : Nat → Term +/-- The constant `0` -/ +| zero : Term +/-- The constant `-1` -/ +| negOne : Term +/-- The constant `1` -/ +| one : Term +/-- Bitwise and -/ +| and : Term → Term → Term +/-- Bitwise or -/ +| or : Term → Term → Term +/-- Bitwise xor -/ +| xor : Term → Term → Term +/-- Bitwise complement -/ +| not : Term → Term +-- /-- Append a single bit the start (i.e., least-significant end) of the bitstream -/ +-- | ls (b : Bool) : Term → Term +/-- Addition -/ +| add : Term → Term → Term +/-- Subtraction -/ +| sub : Term → Term → Term +/-- Negation -/ +| neg : Term → Term +/-- Increment (i.e., add one) -/ +| incr : Term → Term +/-- Decrement (i.e., subtract one) -/ +| decr : Term → Term +deriving Repr +-- /-- `repeatBit` is an operation that will repeat the infinitely repeat the +-- least significant `true` bit of the input. + +-- That is `repeatBit t` is all-zeroes iff `t` is all-zeroes. +-- Otherwise, there is some number `k` s.t. `repeatBit t` is all-ones after +-- dropping the least significant `k` bits -/ +-- | repeatBit : Term → Term + +instance : Inhabited Term where + default := .zero + +open Term + +instance : Add Term := ⟨add⟩ +instance : Sub Term := ⟨sub⟩ +instance : One Term := ⟨one⟩ +instance : Zero Term := ⟨zero⟩ +instance : Neg Term := ⟨neg⟩ + +/-- `t.arity` is the max free variable id that occurs in the given term `t`, +and thus is an upper bound on the number of free variables that occur in `t`. + +Note that the upper bound is not perfect: +a term like `var 10` only has a single free variable, but its arity will be `11` -/ +@[simp] def Term.arity : Term → Nat +| (var n) => n+1 +| zero => 0 +| one => 0 +| negOne => 0 +| Term.and t₁ t₂ => max (arity t₁) (arity t₂) +| Term.or t₁ t₂ => max (arity t₁) (arity t₂) +| Term.xor t₁ t₂ => max (arity t₁) (arity t₂) +| Term.not t => arity t +-- | ls _ t => arity t +| add t₁ t₂ => max (arity t₁) (arity t₂) +| sub t₁ t₂ => max (arity t₁) (arity t₂) +| neg t => arity t +| incr t => arity t +| decr t => arity t +-- | repeatBit t => arity t + + +/-- +Evaluate a term `t` to the BitVec it represents. + +This differs from `Term.eval` in that `Term.evalFin` uses `Term.arity` to +determine the number of free variables that occur in the given term, +and only require that many bitstream values to be given in `vars`. +-/ +@[simp] def Term.evalFin (t : Term) (vars : Fin (arity t) → BitVec w) : BitVec w := + match t with + | var n => vars (Fin.last n) + | zero => BitVec.zero w + | one => 1 + | negOne => -1 + | and t₁ t₂ => + let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ &&& x₂ + | or t₁ t₂ => + let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ ||| x₂ + | xor t₁ t₂ => + let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ ^^^ x₂ + | not t => ~~~(t.evalFin vars) + -- | ls b t => (t.evalFin vars).concat b + | add t₁ t₂ => + let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ + x₂ + | sub t₁ t₂ => + let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ - x₂ + | neg t => -(Term.evalFin t vars) + | incr t => Term.evalFin t vars + 1 + | decr t => Term.evalFin t vars - 1 + -- | repeatBit t => BitStream.repeatBit (Term.evalFin t vars) + +@[simp] def Term.evalNat (t : Term) (vars : Nat → BitVec w) : BitVec w := + match t with + | var n => vars (Fin.last n) + | zero => BitVec.zero w + | one => 1 + | negOne => -1 + | and t₁ t₂ => + let x₁ := t₁.evalNat vars + let x₂ := t₂.evalNat vars + x₁ &&& x₂ + | or t₁ t₂ => + let x₁ := t₁.evalNat vars + let x₂ := t₂.evalNat vars + x₁ ||| x₂ + | xor t₁ t₂ => + let x₁ := t₁.evalNat vars + let x₂ := t₂.evalNat vars + x₁ ^^^ x₂ + | not t => ~~~(t.evalNat vars) + -- | ls b t => (t.evalNat vars).concat b + | add t₁ t₂ => + let x₁ := t₁.evalNat vars + let x₂ := t₂.evalNat vars + x₁ + x₂ + | sub t₁ t₂ => + let x₁ := t₁.evalNat vars + let x₂ := t₂.evalNat vars + x₁ - x₂ + | neg t => -(Term.evalNat t vars) + | incr t => Term.evalNat t vars + 1 + | decr t => Term.evalNat t vars - 1 + -- | repeatBit t => BitStream.repeatBit (Term.evalFin t vars) +@[simp] def Term.evalFinStream (t : Term) (vars : Fin (arity t) → BitStream) : BitStream := + match t with + | var n => vars (Fin.last n) + | zero => BitStream.zero + | one => BitStream.one + | negOne => BitStream.negOne + | and t₁ t₂ => + let x₁ := t₁.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ &&& x₂ + | or t₁ t₂ => + let x₁ := t₁.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ ||| x₂ + | xor t₁ t₂ => + let x₁ := t₁.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ ^^^ x₂ + | not t => ~~~(t.evalFinStream vars) + | add t₁ t₂ => + let x₁ := t₁.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ + x₂ + | sub t₁ t₂ => + let x₁ := t₁.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + let x₂ := t₂.evalFinStream (fun i => vars (Fin.castLE (by simp [arity]) i)) + x₁ - x₂ + | neg t => -(Term.evalFinStream t vars) + | incr t => BitStream.incr (Term.evalFinStream t vars) + | decr t => BitStream.decr (Term.evalFinStream t vars) + +inductive RelationOrdering +| lt | le | gt | ge +deriving Repr + +inductive Relation +| eq +| signed (ord : RelationOrdering) +| unsigned (ord : RelationOrdering) +deriving Repr + +@[simp] +def evalRelation {w} (rel : Relation) (bv1 bv2 : BitVec w) : Bool := + match rel with + | .eq => bv1 = bv2 + | .signed .lt => bv1 <ₛ bv2 + | .signed .le => bv1 ≤ₛ bv2 + | .signed .gt => bv1 >ₛ bv2 + | .signed .ge => bv1 ≥ₛ bv2 + | .unsigned .lt => bv1 <ᵤ bv2 + | .unsigned .le => bv1 ≤ᵤ bv2 + | .unsigned .gt => bv1 >ᵤ bv2 + | .unsigned .ge => bv1 ≥ᵤ bv2 + +inductive Binop +| and | or | impl | equiv +deriving Repr + +@[simp] +def evalBinop (op : Binop) (b1 b2 : Bool) : Bool := + match op with + | .and => b1 && b2 + | .or => b1 || b2 + | .impl => b1 -> b2 + | .equiv => b1 <-> b2 + +@[simp] +def evalBinop' (op : Binop) (b1 b2 : Prop) : Prop := + match op with + | .and => b1 ∧ b2 + | .or => b1 ∨ b2 + | .impl => b1 → b2 + | .equiv => b1 ↔ b2 +inductive Unop +| neg +deriving Repr + +inductive Formula : Type +| atom : Relation → Term → Term → Formula +| msbSet : Term → Formula +| unop : Unop → Formula → Formula +| binop : Binop → Formula → Formula → Formula +deriving Repr + +instance : Inhabited Formula := ⟨Formula.msbSet default⟩ + +@[simp] +def Formula.arity : Formula → Nat +| atom _ t1 t2 => max t1.arity t2.arity +| msbSet t => t.arity +| unop _ φ => φ.arity +| binop _ φ1 φ2 => max φ1.arity φ2.arity + +@[simp] +def Formula.sat {w : Nat} (φ : Formula) (ρ : Fin φ.arity → BitVec w) : Bool := + match φ with + | .atom rel t1 t2 => + let bv1 := t1.evalFin (fun n => ρ $ Fin.castLE (by simp [arity]) n) + let bv2 := t2.evalFin (fun n => ρ $ Fin.castLE (by simp [arity]) n) + evalRelation rel bv1 bv2 + | .unop .neg φ => !φ.sat ρ + | .binop op φ1 φ2 => + let b1 := φ1.sat (fun n => ρ $ Fin.castLE (by simp [arity]) n) + let b2 := φ2.sat (fun n => ρ $ Fin.castLE (by simp [arity]) n) + evalBinop op b1 b2 + | .msbSet t => (t.evalFin ρ).msb + +@[simp] +def Formula.sat' {w : Nat} (φ : Formula) (ρ : Nat → BitVec w) : Prop := + match φ with + | .atom rel t1 t2 => + let bv1 := t1.evalNat ρ + let bv2 := t2.evalNat ρ + evalRelation rel bv1 bv2 + | .unop .neg φ => ¬ φ.sat' ρ + | .binop op φ1 φ2 => + let b1 := φ1.sat' ρ + let b2 := φ2.sat' ρ + evalBinop' op b1 b2 + | .msbSet t => (t.evalNat ρ).msb + +@[simp] +def envOfArray {w} (a : Array (BitVec w)) : Nat → BitVec w := fun n => a.getD n 0 diff --git a/SSA/Experimental/Bits/AutoStructs/FinEnum.lean b/SSA/Experimental/Bits/AutoStructs/FinEnum.lean new file mode 100644 index 000000000..0f9be3a0b --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/FinEnum.lean @@ -0,0 +1,41 @@ +/- +Released under Apache 2.0 license as described in the file LICENSE. +-/ +import Mathlib.Data.FinEnum +import Mathlib.Tactic.FinCases + +instance: FinEnum (BitVec w) where + card := 2^w + equiv := { + toFun := fun x => x.toFin + invFun := fun x => BitVec.ofFin x + left_inv := by intros bv; simp + right_inv := by intros n; simp + } + +instance : FinEnum Bool where + card := 2 + equiv := { + toFun := fun x => if x then 1 else 0, + invFun := fun (x : Fin 2 ) => if x == 0 then false else true, + left_inv := by intros _; simp, + right_inv := by intros x; fin_cases x <;> simp + } + +instance finEnumUnit : FinEnum Unit where + card := 1 + equiv := { + toFun := fun _ => 0, + invFun := fun (_ : Fin 1) => (), + left_inv := by intros _; simp, + right_inv := by intros x; fin_cases x; simp + } + +instance finEnumEmpty : FinEnum Empty where + card := 0 + equiv := { + toFun := fun x => Empty.elim x + invFun := fun (x : Fin 0) => Fin.elim0 x + left_inv := by intros x; cases x + right_inv := by intros x; fin_cases x + } diff --git a/SSA/Experimental/Bits/AutoStructs/FiniteStateMachine.lean b/SSA/Experimental/Bits/AutoStructs/FiniteStateMachine.lean new file mode 100644 index 000000000..ae4496fb6 --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/FiniteStateMachine.lean @@ -0,0 +1,794 @@ +import Mathlib.Data.Fintype.Card +import Mathlib.Data.FinEnum +import Mathlib.Data.Fintype.Sum +import Mathlib.Data.Fintype.Sigma +import Mathlib.Data.Fintype.Pi +import Mathlib.Data.Fintype.BigOperators +import Mathlib.Tactic.Zify +import Mathlib.Tactic.Ring +import SSA.Experimental.Bits.AutoStructs.Defs +import SSA.Experimental.Bits.AutoStructs.FinEnum +import SSA.Experimental.Bits.Fast.Circuit +import SSA.Experimental.Bits.Fast.BitStream +namespace AutoStructs + +open Sum + +section FSM +variable {α β α' β' : Type} {γ : β → Type} + +/-- `FSM n` represents a function `BitStream → ⋯ → BitStream → BitStream`, +where `n` is the number of `BitStream` arguments, +as a finite state machine. +-/ +structure FSM (arity : Type) : Type 1 := + /-- + The arity of the (finite) type `α` determines how many bits the internal carry state of this + FSM has -/ + ( α : Type ) + [ i : FinEnum α ] + [ dec_eq : DecidableEq α ] + /-- + `initCarry` is the value of the initial internal carry state. + It maps each `α` to a bit, thus it is morally a bitvector where the width is the arity of `α` + -/ + ( initCarry : α → Bool ) + /-- + `nextBitCirc` is a family of Boolean circuits, + which may refer to the current input bits *and* the current state bits + as free variables in the circuit. + + `nextBitCirc none` computes the current output bit. + `nextBitCirc (some a)`, computes the *one* bit of the new state that corresponds to `a : α`. -/ + ( nextBitCirc : Option α → Circuit (α ⊕ arity) ) + +attribute [instance] FSM.i FSM.dec_eq + +namespace FSM + +variable {arity : Type} (p : FSM arity) + +/-- The state of FSM `p` is given by a function from `p.α` to `Bool`. + +Note that `p.α` is assumed to be a finite type, so `p.State` is morally +a finite bitvector whose width is given by the arity of `p.α` -/ +abbrev State : Type := p.α → Bool + +/-- `p.nextBit state in` computes both the next state bits and the output bit, +where `state` are the *current* state bits, and `in` are the current input bits. -/ +def nextBit : p.State → (arity → Bool) → p.State × Bool := + fun carry inputBits => + let input := Sum.elim carry inputBits + let newState : p.State := fun (a : p.α) => (p.nextBitCirc (some a)).eval input + let outBit : Bool := (p.nextBitCirc none).eval input + (newState, outBit) + +/-- `p.carry in i` computes the internal carry state at step `i`, given input *streams* `in` -/ +def carry (x : arity → BitStream) : ℕ → p.State + | 0 => p.initCarry + | n+1 => (p.nextBit (carry x n) (fun i => x i n)).1 + +/-- `eval p` morally gives the function `BitStream → ... → BitStream` represented by FSM `p` -/ +def eval (x : arity → BitStream) : BitStream := + fun n => (p.nextBit (p.carry x n) (fun i => x i n)).2 + +/-- `eval'` is an alternative definition of `eval` -/ +def eval' (x : arity → BitStream) : BitStream := + BitStream.corec (fun ⟨x, (carry : p.State)⟩ => + let x_head := (x · |>.head) + let next := p.nextBit carry x_head + let x_tail := (x · |>.tail) + ((x_tail, next.fst), next.snd) + ) (x, p.initCarry) + +/-- `p.changeInitCarry c` yields an FSM with `c` as the initial state -/ +def changeInitCarry (p : FSM arity) (c : p.α → Bool) : FSM arity := + { p with initCarry := c } + +theorem carry_changeInitCarry_succ + (p : FSM arity) (c : p.α → Bool) (x : arity → BitStream) : ∀ n, + (p.changeInitCarry c).carry x (n+1) = + (p.changeInitCarry (p.nextBit c (fun a => x a 0)).1).carry + (fun a i => x a (i+1)) n + | 0 => by simp [carry, changeInitCarry, nextBit] + | n+1 => by + rw [carry, carry_changeInitCarry_succ p _ _ n] + simp [nextBit, carry, changeInitCarry] + +theorem eval_changeInitCarry_succ + (p : FSM arity) (c : p.α → Bool) (x : arity → BitStream) (n : ℕ) : + (p.changeInitCarry c).eval x (n+1) = + (p.changeInitCarry (p.nextBit c (fun a => x a 0)).1).eval + (fun a i => x a (i+1)) n := by + rw [eval, carry_changeInitCarry_succ] + simp [eval, changeInitCarry, nextBit] + +/-- unfolds the definition of `eval` -/ +theorem eval_eq_carry (x : arity → BitStream) (n : ℕ) : + p.eval x n = (p.nextBit (p.carry x n) (fun i => x i n)).2 := + rfl + +theorem eval_eq_eval' : + p.eval x = p.eval' x := by + funext i + simp only [eval, eval'] + induction i generalizing p x + case zero => rfl + case succ i ih => + sorry + +/-- `p.changeVars f` changes the arity of an `FSM`. +The function `f` determines how the new input bits map to the input expected by `p` -/ +def changeVars {arity2 : Type} (changeVars : arity → arity2) : FSM arity2 := + { p with nextBitCirc := fun a => (p.nextBitCirc a).map (Sum.map id changeVars) } + +/-- +Given an FSM `p` of arity `n`, +a family of `n` FSMs `qᵢ` of posibly different arities `mᵢ`, +and given yet another arity `m` such that `mᵢ ≤ m` for all `i`, +we can compose `p` with `qᵢ` yielding a single FSM of arity `m`, +such that each FSM `qᵢ` computes the `i`th bit that is fed to the FSM `p`. -/ +def compose [FinEnum arity] [DecidableEq arity] + (new_arity : Type) -- `new_arity` is the resulting arity + (q_arity : arity → Type) -- `q_arityₐ` is the arity of FSM `qₐ` + (vars : ∀ (a : arity), q_arity a → new_arity) + -- ^^ `vars` is the function that tells us, for each FSM `qₐ`, + -- which bits of the final `new_arity` corresponds to the `q_arityₐ` bits expected by `qₐ` + (q : ∀ (a : arity), FSM (q_arity a)) : -- `q` gives the FSMs to be composed with `p` + FSM new_arity := + { α := p.α ⊕ (Σ a, (q a).α), + i := by letI := p.i; infer_instance, + dec_eq := by + letI := p.dec_eq + letI := fun a => (q a).dec_eq + infer_instance, + initCarry := Sum.elim p.initCarry (λ x => (q x.1).initCarry x.2), + nextBitCirc := λ a => + match a with + | none => (p.nextBitCirc none).bind + (Sum.elim + (fun a => Circuit.var true (inl (inl a))) + (fun a => ((q a).nextBitCirc none).map + (Sum.elim (fun d => (inl (inr ⟨a, d⟩))) (fun q => inr (vars a q))))) + | some (inl a) => + (p.nextBitCirc (some a)).bind + (Sum.elim + (fun a => Circuit.var true (inl (inl a))) + (fun a => ((q a).nextBitCirc none).map + (Sum.elim (fun d => (inl (inr ⟨a, d⟩))) (fun q => inr (vars a q))))) + | some (inr ⟨x, y⟩) => + ((q x).nextBitCirc (some y)).map + (Sum.elim + (fun a => inl (inr ⟨_, a⟩)) + (fun a => inr (vars x a))) } + +lemma carry_compose [FinEnum arity] [DecidableEq arity] + (new_arity : Type) + (q_arity : arity → Type) + (vars : ∀ (a : arity), q_arity a → new_arity) + (q : ∀ (a : arity), FSM (q_arity a)) + (x : new_arity → BitStream) : ∀ (n : ℕ), + (p.compose new_arity q_arity vars q).carry x n = + let z := p.carry (λ a => (q a).eval (fun i => x (vars _ i))) n + Sum.elim z (fun a => (q a.1).carry (fun t => x (vars _ t)) n a.2) + | 0 => by simp [carry, compose] + | n+1 => by + rw [carry, carry_compose _ _ _ _ _ n] + ext y + cases y + · simp [carry, nextBit, compose, Circuit.eval_bind, eval] + congr + ext z + cases z + · simp + · simp [Circuit.eval_map, carry] + congr + ext s + cases s + · simp + · simp + · simp [Circuit.eval_map, carry, compose, eval, carry, nextBit] + congr + ext z + cases z + · simp + · simp + +/-- Evaluating a composed fsm is equivalent to composing the evaluations of the constituent FSMs -/ +lemma eval_compose [FinEnum arity] [DecidableEq arity] + (new_arity : Type) + (q_arity : arity → Type) + (vars : ∀ (a : arity), q_arity a → new_arity) + (q : ∀ (a : arity), FSM (q_arity a)) + (x : new_arity → BitStream) : + (p.compose new_arity q_arity vars q).eval x = + p.eval (λ a => (q a).eval (fun i => x (vars _ i))) := by + ext n + rw [eval, carry_compose, eval] + simp [compose, nextBit, Circuit.eval_bind] + congr + ext a + cases a + simp + simp [Circuit.eval_map, eval, nextBit] + congr + ext a + cases a + simp + simp + +def and : FSM Bool := + { α := Empty, + initCarry := Empty.elim, + nextBitCirc := fun a => a.elim + (Circuit.and + (Circuit.var true (inr true)) + (Circuit.var true (inr false))) Empty.elim } + +@[simp] lemma eval_and (x : Bool → BitStream) : and.eval x = (x true) &&& (x false) := by + ext n; cases n <;> simp [and, eval, nextBit] + +def or : FSM Bool := + { α := Empty, + initCarry := Empty.elim, + nextBitCirc := fun a => a.elim + (Circuit.or + (Circuit.var true (inr true)) + (Circuit.var true (inr false))) Empty.elim } + +@[simp] lemma eval_or (x : Bool → BitStream) : or.eval x = (x true) ||| (x false) := by + ext n; cases n <;> simp [and, eval, nextBit] + +def xor : FSM Bool := + { α := Empty, + initCarry := Empty.elim, + nextBitCirc := fun a => a.elim + (Circuit.xor + (Circuit.var true (inr true)) + (Circuit.var true (inr false))) Empty.elim } + +@[simp] lemma eval_xor (x : Bool → BitStream) : xor.eval x = (x true) ^^^ (x false) := by + ext n; cases n <;> simp [and, eval, nextBit] + +def add : FSM Bool := + { α := Unit, + initCarry := λ _ => false, + nextBitCirc := fun a => + match a with + | some () => + (Circuit.var true (inr true) &&& Circuit.var true (inr false)) + ||| (Circuit.var true (inr true) &&& Circuit.var true (inl ())) + ||| (Circuit.var true (inr false) &&& Circuit.var true (inl ())) + | none => Circuit.var true (inr true) ^^^ + Circuit.var true (inr false) ^^^ + Circuit.var true (inl ()) } + +/-- The internal carry state of the `add` FSM agrees with +the carry bit of addition as implemented on bitstreams -/ +theorem carry_add_succ (x : Bool → BitStream) (n : ℕ) : + add.carry x (n+1) = + fun _ => (BitStream.addAux (x true) (x false) n).2 := by + ext a; obtain rfl : a = () := rfl + induction n with + | zero => + simp [carry, BitStream.addAux, nextBit, add, BitVec.adcb] + | succ n ih => + unfold carry + simp [nextBit, ih, Circuit.eval, BitStream.addAux, BitVec.adcb] + +@[simp] theorem carry_zero (x : arity → BitStream) : carry p x 0 = p.initCarry := rfl +@[simp] theorem initCarry_add : add.initCarry = (fun _ => false) := rfl + +@[simp] lemma eval_add (x : Bool → BitStream) : add.eval x = (x true) + (x false) := by + ext n + simp only [eval] + cases n + · show Bool.xor _ _ = Bool.xor _ _; simp + · rw [carry_add_succ] + conv => {rhs; simp only [(· + ·), BitStream.add, Add.add, BitStream.addAux, BitVec.adcb]} + simp [nextBit, eval, add] +/-! +We don't really need subtraction or negation FSMs, +given that we can reduce both those operations to just addition and bitwise complement -/ + +def sub : FSM Bool := + { α := Unit, + initCarry := fun _ => false, + nextBitCirc := fun a => + match a with + | some () => + (Circuit.var false (inr true) &&& Circuit.var true (inr false)) ||| + ((Circuit.var false (inr true) ^^^ Circuit.var true (inr false)) &&& + (Circuit.var true (inl ()))) + | none => Circuit.var true (inr true) ^^^ + Circuit.var true (inr false) ^^^ + Circuit.var true (inl ()) } + +theorem carry_sub (x : Bool → BitStream) : ∀ (n : ℕ), sub.carry x (n+1) = + fun _ => (BitStream.subAux (x true) (x false) n).2 + | 0 => by + simp [carry, nextBit, Function.funext_iff, BitStream.subAux, sub] + | n+1 => by + rw [carry, carry_sub _ n] + simp [nextBit, eval, sub, BitStream.sub, BitStream.subAux, Bool.xor_not_left'] + +@[simp] +theorem eval_sub (x : Bool → BitStream) : sub.eval x = (x true) - (x false) := by + simp only [(· - ·), Sub.sub] + ext n + cases n + · simp [eval, sub, nextBit, BitStream.sub, BitStream.subAux, carry] + · rw [eval, carry_sub] + simp [nextBit, eval, sub, BitStream.sub, BitStream.subAux] + +def neg : FSM Unit := + { α := Unit, + i := by infer_instance, + initCarry := λ _ => true, + nextBitCirc := fun a => + match a with + | some () => Circuit.var false (inr ()) &&& Circuit.var true (inl ()) + | none => Circuit.var false (inr ()) ^^^ Circuit.var true (inl ()) } + +theorem carry_neg (x : Unit → BitStream) : ∀ (n : ℕ), neg.carry x (n+1) = + fun _ => (BitStream.negAux (x ()) n).2 + | 0 => by + simp [carry, nextBit, Function.funext_iff, BitStream.negAux, neg] + | n+1 => by + rw [carry, carry_neg _ n] + simp [nextBit, eval, neg, BitStream.neg, BitStream.negAux, Bool.xor_not_left'] + +@[simp] lemma eval_neg (x : Unit → BitStream) : neg.eval x = -(x ()) := by + show _ = BitStream.neg _ + ext n + cases n + · simp [eval, neg, nextBit, BitStream.neg, BitStream.negAux, carry] + · rw [eval, carry_neg] + simp [nextBit, eval, neg, BitStream.neg, BitStream.negAux] + +def not : FSM Unit := + { α := Empty, + initCarry := Empty.elim, + nextBitCirc := fun _ => Circuit.var false (inr ()) } + +@[simp] lemma eval_not (x : Unit → BitStream) : not.eval x = ~~~(x ()) := by + ext; simp [eval, not, nextBit] + +def zero : FSM (Fin 0) := + { α := Empty, + initCarry := Empty.elim, + nextBitCirc := fun _ => Circuit.fals } + +@[simp] lemma eval_zero (x : Fin 0 → BitStream) : zero.eval x = BitStream.zero := by + ext; simp [zero, eval, nextBit] + +def one : FSM (Fin 0) := + { α := Unit, + i := by infer_instance, + initCarry := λ _ => true, + nextBitCirc := fun a => + match a with + | some () => Circuit.fals + | none => Circuit.var true (inl ()) } + +@[simp] theorem carry_one (x : Fin 0 → BitStream) (n : ℕ) : + one.carry x (n+1) = fun _ => false := by + simp [carry, nextBit, one] + +@[simp] lemma eval_one (x : Fin 0 → BitStream) : one.eval x = BitStream.one := by + ext n + cases n + · rfl + · simp [eval, carry_one, nextBit] + +def negOne : FSM (Fin 0) := + { α := Empty, + i := by infer_instance, + initCarry := Empty.elim, + nextBitCirc := fun _ => Circuit.tru } + +@[simp] lemma eval_negOne (x : Fin 0 → BitStream) : negOne.eval x = BitStream.negOne := by + ext; simp [negOne, eval, nextBit] + +def ls (b : Bool) : FSM Unit := + { α := Unit, + initCarry := fun _ => b, + nextBitCirc := fun x => + match x with + | none => Circuit.var true (inl ()) + | some () => Circuit.var true (inr ()) } + +theorem carry_ls (b : Bool) (x : Unit → BitStream) : ∀ (n : ℕ), + (ls b).carry x (n+1) = fun _ => x () n + | 0 => by + simp [carry, nextBit, Function.funext_iff, ls] + | n+1 => by + rw [carry, carry_ls _ _ n] + simp [nextBit, eval, ls] + +@[simp] lemma eval_ls (b : Bool) (x : Unit → BitStream) : + (ls b).eval x = (x ()).concat b := by + ext n + cases n + · rfl + · simp [carry_ls, eval, nextBit, BitStream.concat] + +def var (n : ℕ) : FSM (Fin (n+1)) := + { α := Empty, + i := by infer_instance, + initCarry := Empty.elim, + nextBitCirc := λ _ => Circuit.var true (inr (Fin.last _)) } + +@[simp] lemma eval_var (n : ℕ) (x : Fin (n+1) → BitStream) : (var n).eval x = x (Fin.last n) := by + ext m; cases m <;> simp [var, eval, carry, nextBit] + +def incr : FSM Unit := + { α := Unit, + initCarry := fun _ => true, + nextBitCirc := fun x => + match x with + | none => (Circuit.var true (inr ())) ^^^ (Circuit.var true (inl ())) + | some _ => (Circuit.var true (inr ())) &&& (Circuit.var true (inl ())) } + +theorem carry_incr (x : Unit → BitStream) : ∀ (n : ℕ), + incr.carry x (n+1) = fun _ => (BitStream.incrAux (x ()) n).2 + | 0 => by + simp [carry, nextBit, Function.funext_iff, BitStream.incrAux, incr] + | n+1 => by + rw [carry, carry_incr _ n] + simp [nextBit, eval, incr, incr, BitStream.incrAux] + +@[simp] lemma eval_incr (x : Unit → BitStream) : incr.eval x = (x ()).incr := by + ext n + cases n + · simp [eval, incr, nextBit, carry, BitStream.incr, BitStream.incrAux] + · rw [eval, carry_incr]; rfl + +def decr : FSM Unit := + { α := Unit, + i := by infer_instance, + initCarry := λ _ => true, + nextBitCirc := fun x => + match x with + | none => (Circuit.var true (inr ())) ^^^ (Circuit.var true (inl ())) + | some _ => (Circuit.var false (inr ())) &&& (Circuit.var true (inl ())) } + +theorem carry_decr (x : Unit → BitStream) : ∀ (n : ℕ), decr.carry x (n+1) = + fun _ => (BitStream.decrAux (x ()) n).2 + | 0 => by + simp [carry, nextBit, Function.funext_iff, BitStream.decrAux, decr] + | n+1 => by + rw [carry, carry_decr _ n] + simp [nextBit, eval, decr, BitStream.decrAux] + +@[simp] lemma eval_decr (x : Unit → BitStream) : decr.eval x = BitStream.decr (x ()) := by + ext n + cases n + · simp [eval, decr, nextBit, carry, BitStream.decr, BitStream.decrAux] + · rw [eval, carry_decr]; rfl + +theorem evalAux_eq_zero_of_set {arity : Type _} (p : FSM arity) + (R : Set (p.α → Bool)) (hR : ∀ x s, (p.nextBit s x).1 ∈ R → s ∈ R) + (hi : p.initCarry ∉ R) (hr1 : ∀ x s, (p.nextBit s x).2 = true → s ∈ R) + (x : arity → BitStream) (n : ℕ) : p.eval x n = false ∧ p.carry x n ∉ R := by + simp (config := {singlePass := true}) only [← not_imp_not] at hR hr1 + simp only [Bool.not_eq_true] at hR hr1 + induction n with + | zero => + simp only [eval, carry] + exact ⟨hr1 _ _ hi, hi⟩ + | succ n ih => + simp only [eval, carry] at ih ⊢ + exact ⟨hr1 _ _ (hR _ _ ih.2), hR _ _ ih.2⟩ + +theorem eval_eq_zero_of_set {arity : Type _} (p : FSM arity) + (R : Set (p.α → Bool)) (hR : ∀ x s, (p.nextBit s x).1 ∈ R → s ∈ R) + (hi : p.initCarry ∉ R) (hr1 : ∀ x s, (p.nextBit s x).2 = true → s ∈ R) : + p.eval = fun _ _ => false := by + ext x n + rw [eval] + exact (evalAux_eq_zero_of_set p R hR hi hr1 x n).1 + +def repeatBit : FSM Unit where + α := Unit + initCarry := fun () => false + nextBitCirc := fun _ => + .or (.var true <| .inl ()) (.var true <| .inr ()) + +@[simp] theorem eval_repeatBit : + repeatBit.eval x = BitStream.repeatBit (x ()) := by + unfold BitStream.repeatBit + rw [eval_eq_eval', eval'] + apply BitStream.corec_eq_corec + (R := fun a b => a.1 () = b.2 ∧ (a.2 ()) = b.1) + · simp [repeatBit] + · intro ⟨y, a⟩ ⟨b, x⟩ h + simp at h + simp [h, nextBit, BitStream.head] + +end FSM + +structure FSMSolution (t : Term) extends FSM (Fin t.arity) := + ( good : t.evalFinStream = toFSM.eval ) + +def composeUnary + (p : FSM Unit) + {t : Term} + (q : FSMSolution t) : + FSM (Fin t.arity) := + p.compose + (Fin t.arity) + _ + (λ _ => id) + (λ _ => q.toFSM) + +def composeBinary + (p : FSM Bool) + {t₁ t₂ : Term} + (q₁ : FSMSolution t₁) + (q₂ : FSMSolution t₂) : + FSM (Fin (max t₁.arity t₂.arity)) := + p.compose (Fin (max t₁.arity t₂.arity)) + (λ b => Fin (cond b t₁.arity t₂.arity)) + (λ b i => Fin.castLE (by cases b <;> simp) i) + (λ b => match b with + | true => q₁.toFSM + | false => q₂.toFSM) + +def composeBinary' + (p : FSM Bool) + {n m : Nat} + (q₁ : FSM (Fin n)) + (q₂ : FSM (Fin m)) : + FSM (Fin (max n m)) := + p.compose (Fin (max n m)) + (λ b => Fin (cond b n m)) + (λ b i => Fin.castLE (by cases b <;> simp) i) + (λ b => match b with + | true => q₁ + | false => q₂) + +@[simp] lemma composeUnary_eval + (p : FSM Unit) + {t : Term} + (q : FSMSolution t) + (x : Fin t.arity → BitStream) : + (composeUnary p q).eval x = p.eval (λ _ => t.evalFinStream x) := by + rw [composeUnary, FSM.eval_compose, q.good]; rfl + +@[simp] lemma composeBinary_eval + (p : FSM Bool) + {t₁ t₂ : Term} + (q₁ : FSMSolution t₁) + (q₂ : FSMSolution t₂) + (x : Fin (max t₁.arity t₂.arity) → BitStream) : + (composeBinary p q₁ q₂).eval x = p.eval + (λ b => cond b (t₁.evalFinStream (fun i => x (Fin.castLE (by simp) i))) + (t₂.evalFinStream (fun i => x (Fin.castLE (by simp) i)))) := by + rw [composeBinary, FSM.eval_compose, q₁.good, q₂.good] + ext b + cases b <;> dsimp <;> congr <;> funext b <;> cases b <;> simp + +instance {α β : Type} [Fintype α] [Fintype β] (b : Bool) : + Fintype (cond b α β) := by + cases b <;> simp <;> infer_instance + +open Term + +def termEvalEqFSM : ∀ (t : Term), FSMSolution t + | var n => + { toFSM := FSM.var n, + good := by ext; simp [Term.evalFin] } + | zero => + { toFSM := FSM.zero, + good := by ext; simp [Term.evalFin] } + | one => + { toFSM := FSM.one, + good := by ext; simp [Term.evalFin] } + | negOne => + { toFSM := FSM.negOne, + good := by ext; simp [Term.evalFin] } + | Term.and t₁ t₂ => + let q₁ := termEvalEqFSM t₁ + let q₂ := termEvalEqFSM t₂ + { toFSM := composeBinary FSM.and q₁ q₂, + good := by ext; simp } + | Term.or t₁ t₂ => + let q₁ := termEvalEqFSM t₁ + let q₂ := termEvalEqFSM t₂ + { toFSM := composeBinary FSM.or q₁ q₂, + good := by ext; simp } + | Term.xor t₁ t₂ => + let q₁ := termEvalEqFSM t₁ + let q₂ := termEvalEqFSM t₂ + { toFSM := composeBinary FSM.xor q₁ q₂, + good := by ext; simp } + | Term.not t => + let q := termEvalEqFSM t + { toFSM := by dsimp [arity]; exact composeUnary FSM.not q, + good := by ext; simp } + | add t₁ t₂ => + let q₁ := termEvalEqFSM t₁ + let q₂ := termEvalEqFSM t₂ + { toFSM := composeBinary FSM.add q₁ q₂, + good := by ext; simp } + | sub t₁ t₂ => + let q₁ := termEvalEqFSM t₁ + let q₂ := termEvalEqFSM t₂ + { toFSM := composeBinary FSM.sub q₁ q₂, + good := by ext; simp } + | neg t => + let q := termEvalEqFSM t + { toFSM := by dsimp [arity]; exact composeUnary FSM.neg q, + good := by ext; simp } + | incr t => + let q := termEvalEqFSM t + { toFSM := by dsimp [arity]; exact composeUnary FSM.incr q, + good := by ext; simp } + | decr t => + let q := termEvalEqFSM t + { toFSM := by dsimp [arity]; exact composeUnary FSM.decr q, + good := by ext; simp } + +/-! +FSM that implement bitwise-and. Since we use `0` as the good state, +we keep the invariant that if both inputs are good and our state is `0`, then we produce a `0`. +If not, we produce an infinite sequence of `1`. +-/ +def and : FSM Bool := + { α := Unit, + initCarry := fun _ => false, + nextBitCirc := fun a => + match a with + | some () => + -- Only if both are `0` we produce a `0`. + (Circuit.var true (inr false) ||| + ((Circuit.var false (inr true) ||| + -- But if we have failed and have value `1`, then we produce a `1` from our state. + (Circuit.var true (inl ()))))) + | none => -- must succeed in both arguments, so we are `0` if both are `0`. + Circuit.var true (inr true) ||| + Circuit.var true (inr false) + } + +/-! +FSM that implement bitwise-or. Since we use `0` as the good state, +we keep the invariant that if either inputs is `0` then our state is `0`. +If not, we produce a `1`. +-/ +def or : FSM Bool := + { α := Unit, + initCarry := fun _ => false, + nextBitCirc := fun a => + match a with + | some () => + -- If either succeeds, then the full thing succeeds + ((Circuit.var true (inr false) &&& + ((Circuit.var false (inr true)) ||| + -- On the other hand, if we have failed, then propagate failure. + (Circuit.var true (inl ()))))) + | none => -- can succeed in either argument, so we are `0` if either is `0`. + Circuit.var true (inr true) &&& + Circuit.var true (inr false) + } + +/-! +FSM that implement logical not. +we keep the invariant that if the input ever fails and becomes a `1`, then we produce a `0`. +IF not, we produce an infinite sequence of `1`. + +EDIT: Aha, this doesn't work! +We need NFA to DFA here (as the presburger book does), +where we must produce an infinite sequence of`0` iff the input can *ever* become a `1`. +But here, since we phrase things directly in terms of producing sequences, it's a bit less clear +what we should do :) + +- Alternatively, we need to be able to decide `eventually always zero`. +- Alternatively, we push negations inside, and decide `⬝ ≠ ⬝` and `⬝ ≰ ⬝`. +-/ + +inductive Result : Type + | falseAfter (n : ℕ) : Result + | trueFor (n : ℕ) : Result + | trueForall : Result +deriving Repr, DecidableEq + +def card_compl [Fintype α] [DecidableEq α] (c : Circuit α) : ℕ := + Finset.card $ (@Finset.univ (α → Bool) _).filter (fun a => c.eval a = false) + +theorem decideIfZeroAux_wf {α : Type _} [Fintype α] [DecidableEq α] + {c c' : Circuit α} (h : ¬c' ≤ c) : card_compl (c' ||| c) < card_compl c := by + apply Finset.card_lt_card + simp [Finset.ssubset_iff, Finset.subset_iff] + simp only [Circuit.le_def, not_forall, Bool.not_eq_true] at h + rcases h with ⟨x, hx, h⟩ + use x + simp [hx, h] + +def decideIfZerosAux {arity : Type _} [DecidableEq arity] + (p : FSM arity) (c : Circuit p.α) : Bool := + if c.eval p.initCarry + then false + else + have c' := (c.bind (p.nextBitCirc ∘ some)).fst + if h : c' ≤ c then true + else + have _wf : card_compl (c' ||| c) < card_compl c := + decideIfZeroAux_wf h + decideIfZerosAux p (c' ||| c) + termination_by card_compl c + +def decideIfZeros {arity : Type _} [DecidableEq arity] + (p : FSM arity) : Bool := + decideIfZerosAux p (p.nextBitCirc none).fst + +theorem decideIfZerosAux_correct {arity : Type _} [DecidableEq arity] + (p : FSM arity) (c : Circuit p.α) + (hc : ∀ s, c.eval s = true → + ∃ m y, (p.changeInitCarry s).eval y m = true) + (hc₂ : ∀ (x : arity → Bool) (s : p.α → Bool), + (FSM.nextBit p s x).snd = true → Circuit.eval c s = true) : + decideIfZerosAux p c = true ↔ ∀ n x, p.eval x n = false := by + rw [decideIfZerosAux] + split_ifs with h + · simp + exact hc p.initCarry h + · dsimp + split_ifs with h' + · simp only [true_iff] + intro n x + rw [p.eval_eq_zero_of_set {x | c.eval x = true}] + · intro y s + simp [Circuit.le_def, Circuit.eval_fst, Circuit.eval_bind] at h' + simp [Circuit.eval_fst, FSM.nextBit] + apply h' + · assumption + · exact hc₂ + · let c' := (c.bind (p.nextBitCirc ∘ some)).fst + have _wf : card_compl (c' ||| c) < card_compl c := + decideIfZeroAux_wf h' + apply decideIfZerosAux_correct p (c' ||| c) + simp [c', Circuit.eval_fst, Circuit.eval_bind] + intro s hs + rcases hs with ⟨x, hx⟩ | h + · rcases hc _ hx with ⟨m, y, hmy⟩ + use (m+1) + use fun a i => Nat.casesOn i x (fun i a => y a i) a + rw [FSM.eval_changeInitCarry_succ] + rw [← hmy] + simp only [FSM.nextBit, Nat.rec_zero, Nat.rec_add_one] + · exact hc _ h + · intro x s h + have := hc₂ _ _ h + simp only [Circuit.eval_bind, Bool.or_eq_true, Circuit.eval_fst, + Circuit.eval_or, this, or_true] +termination_by card_compl c + +theorem decideIfZeros_correct {arity : Type _} [DecidableEq arity] + (p : FSM arity) : decideIfZeros p = true ↔ ∀ n x, p.eval x n = false := by + apply decideIfZerosAux_correct + · simp only [Circuit.eval_fst, forall_exists_index] + intro s x h + use 0 + use (fun a _ => x a) + simpa [FSM.eval, FSM.changeInitCarry, FSM.nextBit, FSM.carry] + · simp only [Circuit.eval_fst] + intro x s h + use x + exact h + +end FSM + +/-- +The fragment of predicate logic that we support in `bv_automata`. +Currently, we support equality, conjunction, disjunction, and negation. +This can be expanded to also support arithmetic constraints such as unsigned-less-than. +-/ +inductive Predicate : Nat → Type _ where +| eq (t1 t2 : Term) : Predicate ((max t1.arity t2.arity)) +| and (p : Predicate n) (q : Predicate m) : Predicate (max n m) +| or (p : Predicate n) (q : Predicate m) : Predicate (max n m) +-- For now, we can't prove `not`, because it needs NFA → DFA conversion +-- the way Sid knows how to build it, or negation normal form, +-- both of which is machinery we lack. +-- | not (p : Predicate n) : Predicate n diff --git a/SSA/Experimental/Bits/AutoStructs/ForLean.lean b/SSA/Experimental/Bits/AutoStructs/ForLean.lean new file mode 100644 index 000000000..e2ae26fa8 --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/ForLean.lean @@ -0,0 +1,9 @@ +/- +Released under Apache 2.0 license as described in the file LICENSE. +-/ + +theorem ofBool_1_iff_true : BitVec.ofBool b = 1#1 ↔ b := by + cases b <;> simp + +theorem ofBool_0_iff_false : BitVec.ofBool b = 0#1 ↔ ¬ b := by + cases b <;> simp diff --git a/SSA/Experimental/Bits/AutoStructs/FormulaToAuto.lean b/SSA/Experimental/Bits/AutoStructs/FormulaToAuto.lean new file mode 100644 index 000000000..d3c01282b --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/FormulaToAuto.lean @@ -0,0 +1,324 @@ +/- +Released under Apache 2.0 license as described in the file LICENSE. +-/ +import Std.Data.HashSet +import Std.Data.HashMap +import Mathlib.Data.Fintype.Basic +import Mathlib.Data.Finset.Basic +import Mathlib.Data.FinEnum +import Mathlib.Tactic.FinCases +import SSA.Experimental.Bits.AutoStructs.Basic +import SSA.Experimental.Bits.AutoStructs.Defs +import SSA.Experimental.Bits.AutoStructs.FinEnum +import SSA.Experimental.Bits.AutoStructs.FiniteStateMachine + +open AutoStructs + +section fsm + +abbrev Alphabet (arity: Type) [FinEnum arity] := BitVec (FinEnum.card arity + 1) + +variable {arity : Type} [FinEnum arity] + +private structure fsm.State (carryLen : Nat) where + m : NFA $ Alphabet (arity := arity) -- TODO: ugly all over... + map : Std.HashMap (BitVec carryLen) State := ∅ + worklist : Array (BitVec carryLen) := ∅ + +def finFunToBitVec (c : carry → Bool) [FinEnum carry] : BitVec (FinEnum.card carry) := + ((FinEnum.toList carry).enum.map (fun (i, x) => c x |> Bool.toNat * 2^i)).foldl (init := 0) Nat.add |> BitVec.ofNat _ + +def bitVecToFinFun [FinEnum ar] (bv : BitVec $ FinEnum.card ar) : ar → Bool := fun c => bv[FinEnum.equiv.toFun c] + +/-- +Transforms an `FSM` of arity `k` to an `NFA` of arity `k+1`. +This correponds to transforming a function with `k` inputs and +one output to a `k+1`-ary relation. By convention, the output +is the MSB of the alphabet. +-/ +partial +def NFA.ofFSM (p : FSM arity) [FinEnum p.α] : NFA (Alphabet (arity := arity)) := + let m := NFA.empty + let (s, m) := m.newState + let initState := finFunToBitVec p.initCarry + let m := m.addInitial s + let map := Std.HashMap.empty.insert initState s + let worklist := Array.singleton initState + let st : fsm.State (arity := arity) (FinEnum.card p.α) := { m, map, worklist } + go st +where go (st : fsm.State (arity := arity) (FinEnum.card p.α)) : NFA _ := Id.run do + let some carry := st.worklist.get? (st.worklist.size - 1) | return st.m + let some s := st.map.get? carry | return NFA.empty + let m := st.m.addFinal s + let st := { st with m, worklist := st.worklist.pop } + let st := (FinEnum.toList (BitVec (FinEnum.card arity))).foldl (init := st) fun st a => + let eval x := (p.nextBitCirc x).eval (Sum.elim (bitVecToFinFun carry) (bitVecToFinFun a)) + let res : Bool := eval none + let carry' : BitVec (FinEnum.card p.α) := finFunToBitVec (fun c => eval (some c)) + + let (s', st) := if let some s' := st.map.get? carry' then (s', st) else + let (s', m) := st.m.newState + let map := st.map.insert carry' s' + let worklist := st.worklist.push carry' + (s', {m, map, worklist}) + + let m := st.m.addTrans (a.cons res) s s' + { st with m } + go st + +end fsm + + +/- A bunch of NFAs that implement the relations we care about -/ +section nfas_relations + +def NFA.ofConst {w} (bv : BitVec w) : NFA (BitVec 1) := + let m := NFA.empty + let (s, m) := m.newState + let m := m.addInitial s + let (s', m) := (List.range w).foldl (init := (s, m)) fun (s, m) i => + let b := bv[i]?.getD false + let (s', m) := m.newState + let m := m.addTrans (BitVec.ofBool b) s s' + (s', m) + m.addFinal s' + +def NFA.autEq : NFA (BitVec 2) := + let m := NFA.empty + let (s, m) := m.newState + let m := m.addInitial s + let m := m.addFinal s + let m := m.addTrans 0 s s + let m := m.addTrans 3 s s + m + +def NFA.autSignedCmp (cmp: RelationOrdering) : NFA (BitVec 2) := + let m := NFA.empty + let (seq, m) := m.newState + let (sgt, m) := m.newState + let (slt, m) := m.newState + let (sgtfin, m) := m.newState + let (sltfin, m) := m.newState + let m := m.addInitial seq + let m := m.addManyTrans [0#2, 3#2] seq seq + let m := m.addTrans 1#2 seq sgt + let m := m.addTrans 2#2 seq slt + let m := m.addTrans 1#2 seq sltfin + let m := m.addTrans 2#2 seq sgtfin + let m := m.addManyTrans [0#2, 1#2, 3#2] sgt sgt + let m := m.addManyTrans [0#2, 2#2, 3#2] sgt sgtfin + let m := m.addTrans 1#2 sgt sltfin + let m := m.addTrans 2#2 sgt slt + let m := m.addManyTrans [0#2, 2#2, 3#2] slt slt + let m := m.addManyTrans [0#2, 1#2, 3#2] slt sltfin + let m := m.addTrans 2#2 slt sgtfin + let m := m.addTrans 1#2 slt sgt + match cmp with + | .lt => m.addFinal sltfin + | .le => (m.addFinal sltfin).addFinal seq + | .gt => m.addFinal sgtfin + | .ge => (m.addFinal sgtfin).addFinal seq + +def NFA.autUnsignedCmp (cmp: RelationOrdering) : NFA (BitVec 2) := + let m := NFA.empty + let (seq, m) := m.newState + let (sgt, m) := m.newState + let (slt, m) := m.newState + let m := m.addInitial seq + let m := m.addManyTrans [0#2, 3#2] seq seq + let m := m.addTrans 1#2 seq sgt + let m := m.addTrans 2#2 seq slt + let m := m.addManyTrans [0#2, 1#2, 3#2] sgt sgt + let m := m.addTrans 2#2 sgt slt + let m := m.addManyTrans [0#2, 2#2, 3#2] slt slt + let m := m.addTrans 1#2 slt sgt + match cmp with + | .lt => m.addFinal slt + | .le => (m.addFinal slt).addFinal seq + | .gt => m.addFinal sgt + | .ge => (m.addFinal sgt).addFinal seq + +def NFA.autMsbSet : NFA (BitVec 1) := + let m := NFA.empty + let (si, m) := m.newState + let (sf, m) := m.newState + let m := m.addInitial si + let m := m.addFinal sf + let m := m.addTrans 1 si sf + let m := m.addManyTrans [0, 1] si si + m.determinize + +end nfas_relations + +-- A bunch of maps from `Fin n` to `Fin m` that we use to +-- lift and project variables when we interpret formulas +def liftMaxSucc1 (n m : Nat) : Fin (n + 1) → Fin (max n m + 2) := + fun k => if _ : k = n then Fin.last (max n m) else k.castLE (by omega) +def liftMaxSucc2 (n m : Nat) : Fin (m + 1) → Fin (max n m + 2) := + fun k => if _ : k = m then Fin.last (max n m + 1) else k.castLE (by omega) +def liftLast2 n : Fin 2 → Fin (n + 2) +| 0 => n +| 1 => n + 1 +def liftExcecpt2 n : Fin n → Fin (n + 2) := + fun k => Fin.castLE (by omega) k +def liftMax1 (n m : Nat) : Fin n → Fin (max n m) := + fun k => k.castLE (by omega) +def liftMax2 (n m : Nat) : Fin m → Fin (max n m) := + fun k => k.castLE (by omega) + +-- TODO(leo): style? upstream? +@[simp] +lemma finEnumCardFin n : FinEnum.card (Fin n) = n := by + rw [FinEnum.card, FinEnum.fin, FinEnum.ofList, FinEnum.ofNodupList] + simp only + rw [List.Nodup.dedup] + · simp + · apply List.nodup_finRange + +def AutoStructs.Relation.autOfRelation : Relation → NFA (BitVec 2) +| .eq => NFA.autEq +| .signed ord => NFA.autSignedCmp ord +| .unsigned ord => NFA.autUnsignedCmp ord + +def unopNfa {A} [BEq A] [FinEnum A] [Hashable A] + (op : Unop) (m : NFA A) : NFA A := + match op with + | .neg => m.neg + +-- TODO(leo) : why is the typchecking so slow? +def binopNfa {A} [BEq A] [FinEnum A] [Hashable A] + (op : Binop) (m1 : NFA A) (m2 : NFA A) : NFA A := + match op with + | .and => m1.inter m2 + | .or => m1.union m2 + | .impl => m1.neg.union m2 + | .equiv => (m1.neg.union m2).inter (m2.neg.union m1) + +-- TODO(leo) : why is it so slow? 40 seconds on my machine +-- the slow part is the compilation apparently +def nfaOfFormula (φ : Formula) : NFA (BitVec φ.arity) := + match φ with + | .atom rel t1 t2 => + let m1 := (termEvalEqFSM t1).toFSM |> NFA.ofFSM + let m2 := (termEvalEqFSM t2).toFSM |> NFA.ofFSM + let f1 := liftMaxSucc1 (FinEnum.card $ Fin t1.arity) (FinEnum.card $ Fin t2.arity) + let m1' := m1.lift f1 + let f2 := liftMaxSucc2 (FinEnum.card $ Fin t1.arity) (FinEnum.card $ Fin t2.arity) + let m2' := m2.lift f2 + let meq := rel.autOfRelation.lift $ liftLast2 (max (FinEnum.card (Fin t1.arity)) (FinEnum.card (Fin t2.arity))) + let m := NFA.inter m1' m2' |> NFA.inter meq + let mfinal := m.proj (liftExcecpt2 _) + have h : (Formula.atom .eq t1 t2).arity = max (FinEnum.card (Fin t1.arity)) (FinEnum.card (Fin t2.arity)) := by simp [FinEnum.card] + h ▸ mfinal + | .msbSet t => + let m := (termEvalEqFSM t).toFSM |> NFA.ofFSM + let mMsb := NFA.autMsbSet.lift $ fun _ => Fin.last t.arity + have h : t.arity + 1 = FinEnum.card (Fin t.arity) + 1 := by + simp [FinEnum.card] + let m : NFA (BitVec (t.arity + 1)) := h.symm ▸ m + let res := m.inter mMsb + res.proj $ fun n => n.castLE (by rw [Formula.arity]; omega) + | .unop op φ => unopNfa op (nfaOfFormula φ) + | .binop op φ1 φ2 => + let m1 := (nfaOfFormula φ1).lift $ liftMax1 φ1.arity φ2.arity + let m2 := (nfaOfFormula φ2).lift $ liftMax2 φ1.arity φ2.arity + binopNfa op m1 m2 + +-- This is wrong, this is (hopefuly) true only for `w > 0` +axiom decision_procedure_is_correct {w} (φ : Formula) (env : Nat → BitVec w) : + nfaOfFormula φ |>.isUniversal' → φ.sat' env + +-- For testing the comparison operators. +def nfaOfCompareConstants (signed : Bool) {w : Nat} (a b : BitVec w) : NFA (BitVec 0) := + let m1 := NFA.ofConst a + let m2 := NFA.ofConst b + let f1 : Fin 1 → Fin 2 := fun 0 => 0 + let m1' := m1.lift f1 + let f2 : Fin 1 → Fin 2 := fun 0 => 1 + let m2' := m2.lift f2 + let meq := if signed then NFA.autSignedCmp .lt else NFA.autUnsignedCmp .lt + let m := NFA.inter m1' m2' |> NFA.inter meq + let mfinal := m.proj (liftExcecpt2 _) + mfinal + +/- This case is a bit weird because we have on the one hand an + automaton with a singleton alphabet, denoting a singleton set. + Hence the correpondance between a word and the unique bitvector in + `BitVec 0` is not super clear... This is why we check for non-emptiness + rather than universality. This shoud be clarified. +-/ +def testLeq (signed : Bool) (w : Nat) : Option (BitVec w × BitVec w) := + (List.range (2^w)).findSome? fun n => + (List.range (2^w)).findSome? fun m => + let bv := BitVec.ofNat w n + let bv' := BitVec.ofNat w m + if (if signed then bv <ₛ bv' else bv <ᵤ bv') == + (nfaOfCompareConstants signed bv bv' |> NFA.isNotEmpty) + then none else some (bv, bv') +/-- info: true -/ +#guard_msgs in #eval! (testLeq true 4 == none) + +def nfaOfMsb {w : Nat} (a : BitVec w) : NFA (BitVec 0) := + let m := NFA.ofConst a + let meq := NFA.autMsbSet + let m := m |> NFA.inter meq + let mfinal := m.proj $ fun _ => 0 + mfinal + +def testMsb (w : Nat) : Bool := + (List.range (2^w)).all fun n => + let bv := BitVec.ofNat w n + (bv.msb == true) == (nfaOfMsb bv |> NFA.isNotEmpty) +/-- info: true -/ +#guard_msgs in #eval! testMsb 8 + +-- -x = ~~~ (x - 1) +def ex_formula_neg_eq_neg_not_one : Formula := + open Term in + let x := var 0 + Formula.atom .eq (neg x) (not $ sub x 1) + +/-- info: true -/ +#guard_msgs in #eval! nfaOfFormula ex_formula_neg_eq_neg_not_one |> NFA.isUniversal + +-- x &&& ~~~ y = x - (x &&& y) +def ex_formula_and_not_eq_sub_add : Formula := + open Term in + let x := var 0 + let y := var 1 + Formula.atom .eq (and x (not y)) (sub x (and x y)) +/-- info: true -/ +#guard_msgs in #eval! nfaOfFormula ex_formula_and_not_eq_sub_add |> NFA.isUniversal + +/- x &&& y ≤ᵤ ~~~(x ^^^ y) -/ +def ex_formula_and_ule_not_xor : Formula := + open Term in + let x := var 0 + let y := var 1 + .atom (.unsigned .le) (.and x y) (.not (.xor x y)) + +/-- info: true -/ +#guard_msgs in #eval! nfaOfFormula ex_formula_and_ule_not_xor |> NFA.isUniversal + +-- Only true for `w > 0`! +-- x = 0 ↔ (~~~ (x ||| -x)).msb +def ex_formula_eq_zero_iff_not_or_sub : Formula := + open Term in + let x := var 0 + .binop .equiv + (.atom .eq x .zero) + (.msbSet (.not (.or x (.neg x)))) + +/-- info: true -/ +#guard_msgs in #eval! nfaOfFormula ex_formula_eq_zero_iff_not_or_sub |> NFA.isUniversal' + +-- (x <ₛ 0) ↔ x.msb := by +def ex_formula_lst_iff : Formula := + open Term in + let x := var 0 + .binop .equiv + (.atom (.signed .lt) x .zero) + (.msbSet x) + +/-- info: true -/ +#guard_msgs in #eval! nfaOfFormula ex_formula_lst_iff |> NFA.isUniversal diff --git a/SSA/Experimental/Bits/AutoStructs/Tactic.lean b/SSA/Experimental/Bits/AutoStructs/Tactic.lean new file mode 100644 index 000000000..f21ca8898 --- /dev/null +++ b/SSA/Experimental/Bits/AutoStructs/Tactic.lean @@ -0,0 +1,311 @@ +/- +Released under Apache 2.0 license as described in the file LICENSE. +-/ + +import Lean.Meta.Tactic.Simp.BuiltinSimprocs +import Lean.Meta.KExprMap +import SSA.Experimental.Bits.AutoStructs.Basic +import SSA.Experimental.Bits.AutoStructs.Defs +import SSA.Experimental.Bits.AutoStructs.FormulaToAuto +import SSA.Experimental.Bits.SafeNativeDecide +import Qq.Macro + +open AutoStructs + +open Lean Elab Tactic +open Lean Meta +open scoped Qq + +def AutoStructs.Term.toExpr (t : Term) : Expr := + open Term in + match t with + | .var n => mkApp (mkConst ``var) (mkNatLit n) + | .zero => mkConst ``zero + | .one => mkConst ``one + | .negOne => mkConst ``negOne + | .and t1 t2 => mkApp2 (mkConst ``Term.and) t1.toExpr t2.toExpr + | .or t1 t2 => mkApp2 (mkConst ``Term.or) t1.toExpr t2.toExpr + | .xor t1 t2 => mkApp2 (mkConst ``Term.xor) t1.toExpr t2.toExpr + | .add t1 t2 => mkApp2 (mkConst ``add) t1.toExpr t2.toExpr + | .sub t1 t2 => mkApp2 (mkConst ``sub) t1.toExpr t2.toExpr + | .not t => mkApp (mkConst ``Term.not) t.toExpr + | .neg t => mkApp (mkConst ``neg) t.toExpr + | .incr t => mkApp (mkConst ``incr) t.toExpr + | .decr t => mkApp (mkConst ``decr) t.toExpr + +def AutoStructs.Relation.toExpr (rel : Relation) : Expr := + open Relation in + open RelationOrdering in + match rel with + | .eq => mkConst ``eq + | .unsigned .lt => mkApp (mkConst ``unsigned) (mkConst ``lt) + | .unsigned .le => mkApp (mkConst ``unsigned) (mkConst ``le) + | .unsigned .gt => mkApp (mkConst ``unsigned) (mkConst ``gt) + | .unsigned .ge => mkApp (mkConst ``unsigned) (mkConst ``ge) + | .signed .lt => mkApp (mkConst ``signed) (mkConst ``lt) + | .signed .le => mkApp (mkConst ``signed) (mkConst ``le) + | .signed .gt => mkApp (mkConst ``signed) (mkConst ``gt) + | .signed .ge => mkApp (mkConst ``signed) (mkConst ``ge) + +def AutoStructs.Binop.toExpr (rel : Binop) : Expr := + open Binop in + match rel with + | .and => mkConst ``Binop.and + | .or => mkConst ``Binop.or + | .impl => mkConst ``Binop.impl + | .equiv => mkConst ``Binop.equiv + +def AutoStructs.Unop.toExpr (rel : Unop) : Expr := + open Unop in + match rel with + | .neg => mkConst ``neg + +def AutoStructs.Formula.toExpr (φ : Formula) : Expr := + open AutoStructs in + open Formula in + match φ with + | .atom rel t1 t2 => mkApp3 (mkConst ``atom) rel.toExpr t1.toExpr t2.toExpr + | .binop op φ1 φ2 => mkApp3 (mkConst ``binop) op.toExpr φ1.toExpr φ2.toExpr + | .unop op φ => mkApp2 (mkConst ``unop) op.toExpr φ.toExpr + | .msbSet φ => mkApp (mkConst ``msbSet) φ.toExpr + +instance : ToExpr AutoStructs.Formula where + toExpr := AutoStructs.Formula.toExpr + toTypeExpr := mkConst ``AutoStructs.Formula + +namespace Tactic +structure State where + varMap : KExprMap Nat := {} + invMap : Array Expr := #[] +deriving Inhabited + +abbrev M := StateRefT State MetaM + +def addAsVar (e : Expr) : M AutoStructs.Term := do + if let some v ← (←get).varMap.find? e then + pure (.var v) + else + let s ← get + let v := s.invMap.size + let varMap := ← s.varMap.insert e v + let invMap := s.invMap.push e + set ({varMap, invMap } : State) + pure (.var v) + +-- TODO(leo): make this work +def checkBVs (es : List Expr) : M Bool := do + for e in es do + let_expr BitVec _ := e | return false + pure true + +-- TODO(leo): make the shortcuts better +partial def parseTerm (e : Expr) : M AutoStructs.Term := do + match_expr e with + | OfNat.ofNat α n _ => + let_expr BitVec _ ← α | addAsVar e + match n with + | .lit (.natVal 0) => pure AutoStructs.Term.zero + | .lit (.natVal 1) => pure AutoStructs.Term.one + | _ => logWarning m!"Unknown integer {n}"; addAsVar e -- TODO: all other integers... + | BitVec.ofNat _w n => + match n.nat? with + | some 0 => pure AutoStructs.Term.zero + | some 1 => pure AutoStructs.Term.one + | _ => logWarning m!"Unknown integer {n}"; addAsVar e -- TODO: all other integers... + | HXor.hXor α1 α2 α3 _ e1 e2 => + let_expr BitVec _ ← α1 | addAsVar e + let_expr BitVec _ ← α2 | addAsVar e + let_expr BitVec _ ← α3 | addAsVar e + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.xor t1 t2) + | HOr.hOr α1 α2 α3 _ e1 e2 => + let_expr BitVec _ ← α1 | addAsVar e + let_expr BitVec _ ← α2 | addAsVar e + let_expr BitVec _ ← α3 | addAsVar e + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.or t1 t2) + | HAnd.hAnd α1 α2 α3 _ e1 e2 => + let_expr BitVec _ ← α1 | addAsVar e + let_expr BitVec _ ← α2 | addAsVar e + let_expr BitVec _ ← α3 | addAsVar e + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.and t1 t2) + | HAdd.hAdd α1 α2 α3 _ e1 e2 => + let_expr BitVec _ ← α1 | addAsVar e + let_expr BitVec _ ← α2 | addAsVar e + let_expr BitVec _ ← α3 | addAsVar e + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.add t1 t2) + | HSub.hSub α1 α2 α3 _ e1 e2 => + let_expr BitVec _ ← α1 | addAsVar e + let_expr BitVec _ ← α2 | addAsVar e + let_expr BitVec _ ← α3 | addAsVar e + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.sub t1 t2) + | Neg.neg α _ e1 => + let_expr BitVec _ ← α | addAsVar e + let t1 ← parseTerm e1 + pure (.neg t1) + | Complement.complement α _ e' => + let_expr BitVec _ ← α | addAsVar e + let t ← parseTerm e' + pure (.not t) + | _ => + -- logInfo m!"term is {e} === {reprStr e}, let's treat is as a variable" + addAsVar e + +-- Note: we assume a preprocesing phase replaced boolean operations with +-- Prop operations, e.g. `(x < y && y ≤ z) = (x < z)` with `x < y ∧ y ≤ z ↔ x < z +partial def parseFormula (e : Expr) : M Formula := do + match_expr e with + | Eq α e1 e2 => + match_expr e2 with + | true => + let_expr Bool ← α | failure + parseFormula e1 + | false => + let_expr Bool ← α | failure + pure $ .unop .neg (←parseFormula e1) + | _ => + let_expr BitVec _ ← α | throwError m!"Equality {e} has a strange type" + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.atom .eq t1 t2) + | Not e => + let t ← parseFormula e + pure (.unop .neg t) + | Iff e1 e2 => + let t1 ← parseFormula e1 + let t2 ← parseFormula e2 + pure (.binop .equiv t1 t2) + | BEq.beq α _ e1 e2 => + match_expr α with + | BitVec _ => + let t1 ← parseTerm e1 + let t2 ← parseTerm e2 + pure (.atom .eq t1 t2) + | _ => throwError m!"Unexpected Beq type {α}" + -- | Impl => TODO + | And e1 e2 => + let t1 ← parseFormula e1 + let t2 ← parseFormula e2 + pure (.binop .and t1 t2) + | Or e1 e2 => + let t1 ← parseFormula e1 + let t2 ← parseFormula e2 + pure (.binop .or t1 t2) + | BitVec.slt _ e1 e2 => + pure (.atom (.signed .lt) (← parseTerm e1) (← parseTerm e2)) + | BitVec.sle _ e1 e2 => + pure (.atom (.signed .le) (← parseTerm e1) (← parseTerm e2)) + | BitVec.ult _ e1 e2 => + pure (.atom (.unsigned .lt) (← parseTerm e1) (← parseTerm e2)) + | BitVec.ule _ e1 e2 => + pure (.atom (.unsigned .le) (← parseTerm e1) (← parseTerm e2)) + | BitVec.msb _ e => + pure (.msbSet (← parseTerm e)) + | _ => throwError m!"Unsupported syntax {e} === {reprStr e}" + +axiom automaton_axiom (A : Prop) : A + +private def mkNativeAuxDecl (baseName : Name) (type value : Expr) : TermElabM Name := do + let auxName ← Term.mkAuxName baseName + let decl := Declaration.defnDecl { + name := auxName, levelParams := [], type, value + hints := .abbrev + safety := .safe + } + addDecl decl + compileDecl decl + pure auxName + +def mkFinLit (m n : Nat) : MetaM Expr := + let r := mkRawNatLit n + mkAppOptM ``OfNat.ofNat #[some (mkApp (mkConst ``Fin) (mkNatLit m)), some r, none] + +def mkBitVecLit0 (w : Expr) : MetaM Expr := + mkAppOptM ``BitVec.zero #[w] + + +-- instance {α : Type u} [ToExpr α] [ToLevel.{u}] : ToExpr (Array α) := +-- let type := toTypeExpr α +-- { toExpr := fun as => mkApp2 (mkConst ``List.toArray [toLevel.{u}]) type (toExpr as.toList) +-- toTypeExpr := mkApp (mkConst ``Array [toLevel.{u}]) type } + +-- let type := toTypeExpr α +-- let nil := mkApp (mkConst ``List.nil [levelZero]) type +-- let cons := mkApp (mkConst ``List.cons [levelZero]) type +-- { toExpr := List.toExprAux nil cons, +-- toTypeExpr := mkApp (mkConst ``List [levelZero]) type } + +def listExprExpr (es : List Expr) (type : Expr) : Expr := + let nil := mkApp (mkConst ``List.nil [levelZero]) type + let cons := mkApp (mkConst ``List.cons [levelZero]) type + es.foldl (init := nil) fun res e => mkApp2 cons e res + +def arrayExprExpr (es : Array Expr) (type : Expr) : Expr := + mkApp2 (mkConst ``List.toArray [levelZero]) type (listExprExpr es.toList type) + +private def buildEnv (es : Array Expr) : MetaM Expr := do + let some e0 := es[0]? | throwError "The goal contains no variables" + let type ← inferType e0 + let_expr BitVec _ ← type | throwError "Variables must be of BitVector type" + let a := arrayExprExpr es.reverse type + mkAppM ``envOfArray #[a] + +private partial def assertSame (φ : Formula) (st : State) : TacticM Unit:= do + withMainContext do + liftMetaTactic fun goal => do + let goalT ← goal.getType + let efor := toExpr φ + let ρ ← buildEnv st.invMap + let new ← mkAppM ``Formula.sat' #[efor, ρ] + -- let new ← mkAppM ``Eq #[new, mkConst ``true] + let newGoal ← mkAppM ``Iff #[new, goalT] + let mvar ← mkFreshExprMVar (some newGoal) + let mvar' ← mkFreshExprMVar (some new) + goal.assign (←mkAppM ``Iff.mp #[mvar, mvar']) + pure [mvar.mvarId!, mvar'.mvarId!] + +elab "bv_automata_inner" : tactic => do + withMainContext do + let mvar ← getMainGoal + let typ ← mvar.getType + let e ← instantiateMVars typ + let (φ, st) ← parseFormula e|>.run default + assertSame φ st + +-- TODO(leo): make the tactic more structured (in Coq `bv_inner; [by simp | by ...]) +macro "bv_automata'" : tactic => + `(tactic| + (bv_automata_inner; simp; apply decision_procedure_is_correct; safe_native_decide)) + +end Tactic + +section tests + +variable (w : Nat) (x y z : BitVec w) + +theorem dfadfa : (x <ₛ 0) ↔ x.msb := by bv_automata' + +theorem and_ule_not_xor : x &&& y ≤ᵤ ~~~(x ^^^ y) := by bv_automata' + +theorem xor_ule_or : x ^^^ y ≤ᵤ x ||| y := by bv_automata' + +theorem ult_iff_not_ule : (x <ᵤ y) ↔ ¬ (y ≤ᵤ x) := by bv_automata' + +theorem sub_neg_sub : (x - y) = - (y - x) := by bv_automata' + +theorem eq_iff_not_sub_or_sub : + x = y ↔ (~~~ (x - y ||| y - x)).msb := by bv_automata' + +theorem lt_iff_sub_xor_xor_and_sub_xor : + (x <ₛ y) ↔ ((x - y) ^^^ ((x ^^^ y) &&& ((x - y) ^^^ x))).msb := by + bv_automata' + +end tests diff --git a/SSA/Experimental/Bits/Fast/Tactic.lean b/SSA/Experimental/Bits/Fast/Tactic.lean index ac8622a0a..f3d3355e8 100644 --- a/SSA/Experimental/Bits/Fast/Tactic.lean +++ b/SSA/Experimental/Bits/Fast/Tactic.lean @@ -2,6 +2,7 @@ import Lean.Meta.Tactic.Simp.BuiltinSimprocs import SSA.Experimental.Bits.Fast.BitStream import SSA.Experimental.Bits.Fast.Decide import SSA.Experimental.Bits.Fast.Lemmas +import SSA.Experimental.Bits.SafeNativeDecide import Qq.Macro open Lean Elab Tactic @@ -295,40 +296,6 @@ def introduceMapIndexToFVar : TacticM Unit := withMainContext <| do elab "introduceMapIndexToFVar" : tactic => introduceMapIndexToFVar -/- Copy-pasted from Lean/Elab/Tactic/ElabTerm.lean --/ - -private def preprocessPropToDecide (expectedType : Expr) : TermElabM Expr := do - let mut expectedType ← instantiateMVars expectedType - if expectedType.hasFVar then - expectedType ← zetaReduce expectedType - if expectedType.hasFVar || expectedType.hasMVar then - throwError "expected type must not contain free or meta variables{indentExpr expectedType}" - return expectedType - -private def mkNativeAuxDecl (baseName : Name) (type value : Expr) : TermElabM Name := do - let auxName ← Term.mkAuxName baseName - let decl := Declaration.defnDecl { - name := auxName, levelParams := [], type, value - hints := .abbrev - safety := .safe - } - addDecl decl - compileDecl decl - pure auxName - -elab "safe_native_decide" : tactic => - Lean.Elab.Tactic.closeMainGoalUsing `safeNativeDecide fun expectedType => do - let expectedType ← preprocessPropToDecide expectedType - let d ← mkDecide expectedType - let auxDeclName ← mkNativeAuxDecl `_nativeDecide (Lean.mkConst `Bool) d - -- new lines - unless ← reduceBoolNative auxDeclName do - throwError "The statement is false" - let rflPrf ← mkEqRefl (toExpr true) - let s := d.appArg! -- get instance from `d` - return mkApp3 (Lean.mkConst ``of_decide_eq_true) expectedType s <| mkApp3 (Lean.mkConst ``Lean.ofReduceBool) (Lean.mkConst auxDeclName) (toExpr true) rflPrf - /-- Create bv_automata tactic which solves equalities on bitvectors. -/ diff --git a/SSA/Experimental/Bits/SafeNativeDecide.lean b/SSA/Experimental/Bits/SafeNativeDecide.lean new file mode 100644 index 000000000..62584ab65 --- /dev/null +++ b/SSA/Experimental/Bits/SafeNativeDecide.lean @@ -0,0 +1,38 @@ +import Lean.Meta.Tactic.Simp.BuiltinSimprocs +import Lean.Meta + +open Lean Elab Tactic +open Lean Meta + +/- Copy-pasted from Lean/Elab/Tactic/ElabTerm.lean -/ + +private def preprocessPropToDecide (expectedType : Expr) : TermElabM Expr := do + let mut expectedType ← instantiateMVars expectedType + if expectedType.hasFVar then + expectedType ← zetaReduce expectedType + if expectedType.hasFVar || expectedType.hasMVar then + throwError "expected type must not contain free or meta variables{indentExpr expectedType}" + return expectedType + +private def mkNativeAuxDecl (baseName : Name) (type value : Expr) : TermElabM Name := do + let auxName ← Term.mkAuxName baseName + let decl := Declaration.defnDecl { + name := auxName, levelParams := [], type, value + hints := .abbrev + safety := .safe + } + addDecl decl + compileDecl decl + pure auxName + +elab "safe_native_decide" : tactic => + Lean.Elab.Tactic.closeMainGoalUsing `safeNativeDecide fun expectedType => do + let expectedType ← preprocessPropToDecide expectedType + let d ← mkDecide expectedType + let auxDeclName ← mkNativeAuxDecl `_nativeDecide (Lean.mkConst `Bool) d + -- new lines + unless ← reduceBoolNative auxDeclName do + throwError "The statement is false" + let rflPrf ← mkEqRefl (toExpr true) + let s := d.appArg! -- get instance from `d` + return mkApp3 (Lean.mkConst ``of_decide_eq_true) expectedType s <| mkApp3 (Lean.mkConst ``Lean.ofReduceBool) (Lean.mkConst auxDeclName) (toExpr true) rflPrf diff --git a/SSA/Projects/InstCombine/TacticAuto.lean b/SSA/Projects/InstCombine/TacticAuto.lean index 15846dac8..e593c85da 100644 --- a/SSA/Projects/InstCombine/TacticAuto.lean +++ b/SSA/Projects/InstCombine/TacticAuto.lean @@ -5,6 +5,8 @@ import Mathlib.Tactic.Ring import SSA.Projects.InstCombine.ForLean import SSA.Projects.InstCombine.LLVM.EDSL import SSA.Experimental.Bits.Fast.Tactic +import SSA.Experimental.Bits.AutoStructs.Tactic +import SSA.Experimental.Bits.AutoStructs.ForLean import Std.Tactic.BVDecide -- import Leanwuzla @@ -156,6 +158,12 @@ macro "bv_auto": tactic => simp (config := {failIfUnchanged := false}) only [(BitVec.two_mul), ←BitVec.negOne_eq_allOnes] bv_automata ) + | ( + simp (config := {failIfUnchanged := false}) only [BitVec.two_mul, ←BitVec.negOne_eq_allOnes, ofBool_0_iff_false, ofBool_1_iff_true] + try rw [Bool.eq_iff_iff] + simp (config := {failIfUnchanged := false}) [Bool.or_eq_true_iff, Bool.and_eq_true_iff, beq_iff_eq] + bv_automata' + ) | simp (config := {failIfUnchanged := false}) only [BitVec.abs_eq_if] try split