diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 1d12641f0abb..33106dbccfd1 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -531,6 +531,8 @@ SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value. -/ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s) +def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat + instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩ instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩ @@ -614,6 +616,13 @@ theorem ofBool_append (msb : Bool) (lsbs : BitVec w) : ofBool msb ++ lsbs = (cons msb lsbs).cast (Nat.add_comm ..) := rfl +/-- +`twoPow i` is the bitvector `2^i` if `i < w`, and `0` otherwise. +That is, 2 to the power `i`. +For the bitwise point of view, it has the `i`th bit as `1` and all other bits as `0`. +-/ +def twoPow (w : Nat) (i : Nat) : BitVec w := 1#w <<< i + end bitwise section normalization_eqs diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 1ca892057551..5f142f389ce3 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -159,6 +159,20 @@ theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := b theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by rw [← add_not_self x, BitVec.add_comm, add_sub_cancel] +/-- Adding two bitvectors equals or-ing them if they are 1 in mutually exclusive locations. -/ +theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w) + (h : x &&& y = 0#w) : x + y = x ||| y := by + rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (x ||| y)] + · rfl + · simp [adcb, atLeastTwo, h] + intros i + replace h : (x &&& y).getLsb i = (0#w).getLsb i := by rw [h] + simp only [getLsb_and, getLsb_zero, and_eq_false_imp] at h + constructor + · intros hx + simp_all [hx] + · by_cases hx : x.getLsb i <;> simp_all [hx] + /-! ### Negation -/ theorem bit_not_testBit (x : BitVec w) (i : Fin w) : @@ -235,4 +249,1321 @@ theorem sle_eq_carry (x y : BitVec w) : x.sle y = !((x.msb == y.msb).xor (carry w y (~~~x) true)) := by rw [sle_eq_not_slt, slt_eq_not_carry, beq_comm] +/-! ### mul recurrence for bitblasting -/ + +def mulRec (l r : BitVec w) (s : Nat) : BitVec w := + let cur := if r.getLsb s then (l <<< s) else 0 + match s with + | 0 => cur + | s + 1 => mulRec l r s + cur + +theorem mulRec_zero_eq (l r : BitVec w) : + mulRec l r 0 = if r.getLsb 0 then l else 0 := by + simp [mulRec] + +theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : + mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := by + simp [mulRec] + +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false + {x : BitVec w} {i : Nat} {hx : x.getLsb i = false} : + zeroExtend w₂ (x.truncate (i + 1)) = + zeroExtend w₂ (x.truncate i) := by + ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hik:i = k + · subst hik + simp [hx] + · by_cases hik' : k < i + 1 <;> simp [hik'] <;> omega + +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true + (x : BitVec w) (i : Nat) (hx : x.getLsb i = true) : + zeroExtend w₂ (x.truncate (i + 1)) = + zeroExtend w₂ (x.truncate i) ||| (twoPow w₂ i) := by + ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hik : i = k + · subst hik + simp [hx] + · by_cases hik' : k < i + 1 <;> simp [hik, hik'] <;> omega + +/-- Recurrence lemma: truncating to `i+1` bits and then zero extending to `w` +equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. -/ +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) : + zeroExtend w (x.truncate (i + 1)) = + zeroExtend w (x.truncate i) + (x &&& twoPow w i) := by + rw [add_eq_or_of_and_eq_zero] + · ext k + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] + by_cases hik:i = k + · subst hik + simp + · simp [hik] + /- Really, 'omega' should be able to do this-/ + by_cases hik' : k < (i + 1) + · have hik'' : k < i := by omega + simp [hik', hik''] + · have hik'' : ¬ (k < i) := by omega + simp [hik', hik''] + · ext k + simp + by_cases hi : x.getLsb i <;> simp [hi] <;> omega + +theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : + mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by + induction s + case zero => + simp [mulRec_zero_eq] + by_cases r.getLsb 0 + case pos hr => + simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, + hr, ofBool_true, ofNat_eq_ofNat] + rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp + case neg hr => + simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] + case succ s' hs => + rw [mulRec_succ_eq, hs] + have heq : + (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = + (l * (r &&& (BitVec.twoPow w (s' + 1)))) := by + simp only [ofNat_eq_ofNat, and_twoPow_eq_getLsb] + by_cases hr : r.getLsb (s' + 1) <;> simp [hr] + rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow] + +/-- Zero extending by number of bits larger than the bitwidth has no effect. -/ +theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : + (x.zeroExtend i).zeroExtend j = x.zeroExtend j := by + ext k + simp + intros hx; + have hi' : k < w := BitVec.lt_of_getLsb _ _ hx + omega + +/-- Zero extending by the bitwidth has no effect. -/ +theorem zeroExtend_eq_self {x : BitVec w} : x.zeroExtend w = x := by + ext i + simp [getLsb_zeroExtend] + +theorem getLsb_mul (x y : BitVec w) (i : Nat) : + (x * y).getLsb i = (mulRec x y w).getLsb i := by + simp [mulRec_eq_mul_signExtend_truncate] + rw [truncate, zeroExtend_of_ge (by omega), zeroExtend_eq_self] +/- ## Shift left for arbitrary bit width -/ + +@[simp] +theorem shiftLeft_zero (x : BitVec w) : x <<< 0 = x := by + simp [bv_toNat] + +@[simp] +theorem zero_shiftLeft (n : Nat) : (0#w) <<< n = 0 := by + simp [bv_toNat] + +@[simp] +theorem truncate_one_eq_ofBool_getLsb (x : BitVec w) : + x.truncate 1 = ofBool (x.getLsb 0) := by + ext i + simp [show i = 0 by omega] + +/-## shiftLeft recurrence -/ + +def shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x <<< shiftAmt + | n + 1 => (shiftLeftRec x y n) <<< shiftAmt + +@[simp] +theorem shiftLeftRec_zero (x : BitVec w₁) (y : BitVec w₂) : + shiftLeftRec x y 0 = x <<< (y &&& twoPow w₂ 0) := by + simp [shiftLeftRec] + +@[simp] +theorem shiftLeftRec_succ (x : BitVec w₁) (y : BitVec w₂) : + shiftLeftRec x y (n + 1) = + (shiftLeftRec x y n) <<< (y &&& twoPow w₂ (n + 1)) := by + simp [shiftLeftRec] + +-- | TODO: should this be a simp-lemma? Probably not. +theorem shiftLeft_eq' (x : BitVec w) (y : BitVec w₂) : + x <<< y = x <<< y.toNat := by rfl + +-- | TODO: what to name these theorems? +@[simp] +theorem shiftLeft_zero' (x : BitVec w) : + x <<< (0#w₂) = x := by + simp [shiftLeft_eq'] + +@[simp] +theorem getLsb_ofNat_one (w i : Nat) : + (1#w).getLsb i = (decide (i = 0) && decide (i < w)) := by + rcases w with rfl | w + · simp; + · simp [getLsb] + by_cases hi : i = 0 + · simp [hi] + · simp [hi] + intros _; simp [testBit, shiftRight_eq_div_pow]; + suffices 1 / 2^i = 0 by simp [this] + apply Nat.div_eq_of_lt; + exact Nat.one_lt_two_pow_iff.mpr hi + +theorem shiftLeft'_shiftLeft' {x y z : BitVec w} : + x <<< y <<< z = x <<< (y.toNat + z.toNat) := by + simp [shiftLeft_eq', shiftLeft_shiftLeft] + +theorem shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): + x <<< (y ||| z) = x <<< y <<< z := by + simp [← add_eq_or_of_and_eq_zero _ _ h, shiftLeft_eq', shiftLeft_shiftLeft, + toNat_add, Nat.mod_eq_of_lt h'] + + +theorem getLsb_shiftLeft' (x : BitVec w) (y : BitVec w₂) (i : Nat) : + (x <<< y).getLsb i = (decide (i < w) && !decide (i < y.toNat) && x.getLsb (i - y.toNat)) := by + simp [shiftLeft_eq', getLsb_shiftLeft] + +theorem shiftLeftRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + shiftLeftRec x y n = x <<< (y.truncate (n + 1)).zeroExtend w₂ := by + induction n generalizing x y + case zero => + ext i + simp only [shiftLeftRec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by + ext i + by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] + simp [heq] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] + rw [shiftLeft_or_eq_shiftLeft_shiftLeft_of_and_eq_zero] + · simp + · simp; + have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + have h₂ : 2 ^ (n + 1) % 2 ^ w₂ = 2 ^ (n + 1) := Nat.mod_eq_of_lt (by omega) + have h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + apply Nat.mod_lt + apply Nat.pow_pos (by decide); omega + obtain h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) <;> omega + rw [h₁, h₂] + rcases w₂ with rfl | w₂ + · omega + · apply Nat.add_lt_add_of_lt_of_le + · simp only [pow_eq, Nat.mul_eq, Nat.mul_one] + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · simp + apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] + simp [h] + +#print axioms shiftLeftRec_eq + +theorem shiftLeft_eq_shiftLeft_rec (x : BitVec ℘) (y : BitVec w₂) : + x <<< y = shiftLeftRec x y (w₂ - 1) := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [shiftLeftRec_eq x y w₂ (by omega)] + + +/-## (Logical) ushiftRight recurrence -/ + +def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x >>> shiftAmt + | n + 1 => (ushiftRight_rec x y n) >>> shiftAmt + +@[simp] +theorem ushiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) : + ushiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by + simp [ushiftRight_rec] + +@[simp] +theorem ushiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) : + ushiftRight_rec x y (n + 1) = + (ushiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by + simp [ushiftRight_rec] + +-- | TODO: should this be a simp-lemma? Probably not. +theorem ushiftRight_eq' (x : BitVec w) (y : BitVec w₂) : + x >>> y = x >>> y.toNat := by rfl + + +@[simp] +theorem BitVec.ushiftRight_zero (x : BitVec w) : x >>> 0 = x := by + simp [bv_toNat] + +-- | TODO: what to name these theorems? +@[simp] +theorem ushiftRight_zero' (x : BitVec w) : + x >>> (0#w₂) = x := by + simp [ushiftRight_eq'] + +theorem ushiftRight'_ushiftRight' {x y z : BitVec w} : + x >>> y >>> z = x >>> (y.toNat + z.toNat) := by + simp [ushiftRight_eq', shiftRight_shiftRight] + +theorem ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): + x >>> (y ||| z) = x >>> y >>> z := by + simp [← add_eq_or_of_and_eq_zero _ _ h, ushiftRight_eq', shiftRight_shiftRight, + toNat_add, Nat.mod_eq_of_lt h'] + +theorem getLsb_ushiftRight' (x : BitVec w) (y : BitVec w₂) (i : Nat) : + (x >>> y).getLsb i = x.getLsb (y.toNat + i) := by + simp [ushiftRight_eq', getLsb_ushiftRight] + +theorem ushiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by + induction n generalizing x y + case zero => + ext i + simp only [ushiftRight_rec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by + ext i + by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] + simp [heq] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] + rw [ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero] + · simp + · simp; + have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + have h₂ : 2 ^ (n + 1) % 2 ^ w₂ = 2 ^ (n + 1) := Nat.mod_eq_of_lt (by omega) + have h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + apply Nat.mod_lt + apply Nat.pow_pos (by decide); omega + obtain h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) <;> omega + rw [h₁, h₂] + rcases w₂ with rfl | w₂ + · omega + · apply Nat.add_lt_add_of_lt_of_le + · simp only [pow_eq, Nat.mul_eq, Nat.mul_one] + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · simp + apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] + simp [h] + +#print axioms ushiftRight_rec_eq + +theorem shiftRight_eq_shiftRight_rec (x : BitVec ℘) (y : BitVec w₂) : + x >>> y = ushiftRight_rec x y (w₂ - 1) := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [ushiftRight_rec_eq x y w₂ (by omega)] + + +/- ### Arithmetic (sshiftRight) recurrence -/ + +@[simp] +theorem sshiftRight'_zero (x : BitVec w) : + x.sshiftRight' (0#w₂) = x := by + ext i + rw [sshiftRight', getLsb_sshiftRight] + simp + +def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x.sshiftRight' shiftAmt + | n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt + +@[simp] +theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) : + sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by + simp [sshiftRightRec] + +@[simp] +theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) : + sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by + simp [sshiftRightRec] + +/-- The msb after arithmetic shifting right equals the original msb. -/ +theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} : + (x.sshiftRight n).msb = x.msb := by + rw [msb_eq_getLsb_last, getLsb_sshiftRight] + rcases w with rfl | w + · simp + · simp only [Nat.add_sub_cancel] + simp [show ¬ (w + 1 ≤ w) by omega] + intros h + rw [msb_eq_getLsb_last] + simp only [Nat.add_sub_cancel] + simp [show n + w = w by omega] + +theorem sshiftRight_sshiftRight {x : BitVec w} {m n : Nat} : + (x.sshiftRight m).sshiftRight n = x.sshiftRight (m + n) := by + ext i + simp only [getLsb_sshiftRight] + simp only [Nat.add_assoc] + by_cases h₁ : w ≤ (i : Nat) + · simp [h₁] + · simp only [h₁, decide_False, Bool.not_false, Bool.true_and] + by_cases h₂ : n + ↑i < w + · simp [h₂] + · simp only [h₂, ↓reduceIte] + by_cases h₃ : m + (n + ↑i) < w + · simp [h₃] + omega + · simp [h₃] + apply sshiftRight_msb_eq_msb + +theorem sshiftRight'_sshiftRight' {x : BitVec w₁} {y : BitVec w₂} {z : BitVec w₃} : + (x.sshiftRight' y).sshiftRight' z = x.sshiftRight (y.toNat + z.toNat) := by + simp [sshiftRight', shiftRight_shiftRight, sshiftRight_sshiftRight] + + +theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) (h' : y.toNat + z.toNat < 2^w₂): + x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by + simp [← add_eq_or_of_and_eq_zero _ _ h] + simp [BitVec.sshiftRight'] + simp [sshiftRight_sshiftRight] + rw [Nat.mod_eq_of_lt h'] + +theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) (hn : n + 1 ≤ w₂) : + sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by + induction n generalizing x y + case zero => + ext i + simp only [ushiftRight_rec_zero, twoPow_zero_eq_one, Nat.reduceAdd, truncate_one_eq_ofBool_getLsb] + have heq : (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) := by + ext i + by_cases h : (↑i : Nat) = 0 <;> simp [h, Bool.and_comm] + simp [heq] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true _ _ h] + rw [sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero] + · simp + · simp; + have hpow : 2 ^ (n + 1) < 2 ^ w₂ := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + have h₂ : 2 ^ (n + 1) % 2 ^ w₂ = 2 ^ (n + 1) := Nat.mod_eq_of_lt (by omega) + have h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + apply Nat.mod_lt + apply Nat.pow_pos (by decide); omega + obtain h₁ : y.toNat % 2 ^ (n + 1) % 2 ^ w₂ = y.toNat % 2 ^ (n + 1) := by + apply Nat.mod_eq_of_lt + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) <;> omega + rw [h₁, h₂] + rcases w₂ with rfl | w₂ + · omega + · apply Nat.add_lt_add_of_lt_of_le + · simp only [pow_eq, Nat.mul_eq, Nat.mul_one] + apply Nat.lt_of_lt_of_le (m := 2 ^ (n + 1)) + · apply Nat.mod_lt + · apply Nat.pow_pos (by decide) + · apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · simp + apply Nat.pow_le_pow_of_le_right (by decide) (by omega) + · rw [ih (hn := by omega)] + rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)] + simp [h] + + +theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) : + (x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [sshiftRightRec_eq x y w₂ (by omega)] + +/- ## udiv/urem bitblasting -/ + +/- +r = n - d * q +r = n - d * (∑ i, 2^i * q.getLsb i) + +-/ + + +/-! +Let us study an instructive counterexample to the claim that + `n = d * q + r` for (`0 ≤ r < d`) uniquely determining q and r *over bitvectors*. + +- Let `bitwidth = 3` +- Let `n = 0, d = 3` +- If we choose `q = 2, r = 2`, then d * q + r = 6 + 2 = 8 ≃ 0 (mod 8) so satisfies. +- But see that `q = 0, r = 0` also satisfies, as 0 * 3 + 0 = 0. +- So for (`n = 0, d = 3`), both: + `q = 2, r = 2` as well as + `q = 0, r = 0` are solutions! + +It's easy to cook up such examples, by chosing `(q, r)` for a fixed `(d, n)` +such that `(d * q + r)` overflows. +-/ + +/-! +References: +- Fast 32-bit Division on the DSP56800E: Minimized nonrestoring division algorithm by David Baca +- Bitwuzla sources for bitblasting.h +-/ + + +/-- TODO: This theorem surely exists somewhere. -/ +theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz : 0 < z): + (x + y) / z = x / z := by + refine Nat.div_eq_of_lt_le ?lo ?hi + · apply Nat.le_trans + · exact div_mul_le_self x z + · omega + · simp only [succ_eq_add_one, Nat.add_mul, Nat.one_mul] + apply Nat.add_lt_add_of_le_of_lt + · apply Nat.le_of_eq + exact (Nat.div_eq_iff_eq_mul_left hz hx).mp rfl + · exact hy + +theorem div_characterized_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d) + (hrd : r < d) + (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : + (n.udiv d = q ∧ n.umod d = r) := by + constructor + · apply BitVec.eq_of_toNat_eq + rw [toNat_udiv hd] + replace hdqnr : (d.toNat * q.toNat + r.toNat) / d.toNat = n.toNat / d.toNat := by + simp [hdqnr] + rw [Nat.div_add_eq_left_of_lt] at hdqnr + · rw [← hdqnr] + exact mul_div_right q.toNat hd + · exact Nat.dvd_mul_right d.toNat q.toNat + · exact hrd + · exact hd + · apply BitVec.eq_of_toNat_eq + rw [toNat_umod] + replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by + simp [hdqnr] + rw [Nat.add_mod, Nat.mul_mod_right] at hdqnr + simp at hdqnr + replace hrd : r.toNat < d.toNat := by + rw [BitVec.lt_def] at hrd + exact hrd -- TODO: golf + rw [Nat.mod_eq_of_lt hrd] at hdqnr + simp [hdqnr] + +theorem div_characterized_of_mul_add_of_lt {d n q r : BitVec w} (hd : 0 < d) + (hrd : r < d) + (hdqnr : d * q + r = n) + (hlt : d.toNat * q.toNat + r.toNat < 2^w) : + (n.udiv d = q ∧ n.umod d = r) := by + apply div_characterized_of_mul_add_toNat <;> try assumption + apply Eq.symm + have hlt' : d.toNat * q.toNat < 2^w := by omega + calc + n.toNat = (d * q + r).toNat := by rw [← hdqnr] + _ = ((d * q).toNat + r.toNat) % 2^w := by simp [BitVec.toNat_add] + _ = ((d.toNat * q.toNat) % 2^w + r.toNat) % 2^w := by simp [BitVec.toNat_mul] + _ = ((d.toNat * q.toNat) + r.toNat) % 2^w := by simp [Nat.mod_eq_of_lt hlt'] + _ = ((d.toNat * q.toNat) + r.toNat) := by simp [Nat.mod_eq_of_lt hlt] + +theorem div_characterized_toNat_of_eq_udiv_of_eq_umod {d n q r : BitVec w} (hd : 0 < d) + (hq : n.udiv d = q) (hr : n.umod d = r) : + (d.toNat * q.toNat + r.toNat = n.toNat) := by + have hdiv : n.toNat / d.toNat = q.toNat := by + rw [← toNat_udiv hd] -- TODO: squeeze + rw [(toNat_eq _ _).mp hq] + have hmod : n.toNat % d.toNat = r.toNat := by + rw [← toNat_umod] -- TODO: squeeze + rw [(toNat_eq _ _).mp hr] + rw [← hdiv, ← hmod] -- TODO: flip + rw [div_add_mod] + +theorem div_characterized_toNat_of_eq_udiv_of_eq_umod_of_lt {d n q r : BitVec w} (hd : 0 < d) + (hq : n.udiv d = q) (hr : n.umod d = r) + (hlt : d.toNat * q.toNat + r.toNat < 2^w) : + d * q + r = n := by + apply eq_of_toNat_eq + simp [toNat_add, toNat_mul] + rw [Nat.mod_eq_of_lt hlt] + apply div_characterized_toNat_of_eq_udiv_of_eq_umod hd hq hr + +theorem div_iff_add_mod_of_lt {d n q r : BitVec w} (hd : 0 < d) + (hrd : r < d) + (hlt : d.toNat * q.toNat + r.toNat < 2^w) : + (n.udiv d = q ∧ n.umod d = r) ↔ (d * q + r = n) := by + constructor + · intros h; obtain ⟨h₁, h₂⟩ := h + apply div_characterized_toNat_of_eq_udiv_of_eq_umod_of_lt <;> assumption + · intros h + apply div_characterized_of_mul_add_of_lt <;> assumption + +/-# Tons of Lemmas for Proving Bitblasting Correct -/ + + + +theorem BitVec.shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : + x <<< n = x * (BitVec.twoPow w n) := by + ext i + simp + + +@[simp] +theorem BitVec.or_zero (x : BitVec w) : x ||| 0#w = x := by + ext i + simp + + +theorem BitVec.sub_le_self_of_le {x y : BitVec w} (hx : y ≤ x) : x - y ≤ x := by + simp [BitVec.lt_def, BitVec.le_def] at hx ⊢ + rw [← Nat.add_sub_assoc (by omega)] + rw [Nat.add_comm] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.add_mod] + simp only [mod_self, Nat.zero_add, mod_mod] + rw [Nat.mod_eq_of_lt] <;> omega + +theorem BitVec.sub_lt_self_of_lt_of_lt {x y : BitVec w} (hx : y < x) (hy : 0 < y): x - y < x := by + simp [BitVec.lt_def] at hx hy ⊢ + rw [← Nat.add_sub_assoc (by omega)] + rw [Nat.add_comm] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.add_mod] + simp only [mod_self, Nat.zero_add, mod_mod] + rw [Nat.mod_eq_of_lt] <;> omega + +theorem BitVec.le_iff_not_lt {x y : BitVec w} : (¬ x < y) ↔ y ≤ x := by + constructor <;> + (intro h; simp [BitVec.lt_def, BitVec.le_def] at h ⊢; omega) + +@[simp] +theorem BitVec.le_refl (x : BitVec w) : x ≤ x := by + simp [BitVec.le_def] + + +theorem BitVec.shiftLeft_mul_comm (x y : BitVec w) (n : Nat) : + x <<< n * y = x * y <<< n := by + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.mul_assoc] + congr 1 + apply BitVec.mul_comm + +theorem BitVec.shiftLeft_mul_assoc (x y : BitVec w) (n : Nat) : + x * y <<< n = (x * y) <<< n := by + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.shiftLeft_eq_mul_twoPow] + rw [BitVec.mul_assoc] + +theorem BitVec.add_mul (x y z : BitVec w) : (y + z) * x = y * x + z * x := by + conv => + lhs + rw [BitVec.mul_comm, BitVec.mul_add] + congr 1 <;> rw [BitVec.mul_comm] + +theorem BitVec.add_assoc {x y z : BitVec w} : x + y + z = x + (y + z) := by + apply eq_of_toNat_eq + simp[Nat.add_assoc] + +theorem BitVec.add_sub_assoc {m k : BitVec w} (h : k ≤ m) (n : BitVec w) : + n + m - k = n + (m - k) := by + apply BitVec.eq_of_toNat_eq + simp only [toNat_sub, toNat_add, mod_add_mod, add_mod_mod, Nat.add_assoc] + +/-- +Bitwise or of (x <<< 1) with 1 is the same as addition. +This is useful to reason in mixed-arithmetic bitwise contexts. +-/ +private theorem BitVec.shiftLeft_one_or_one_eq_shiftLeft_one_add_one {x : BitVec w} : + x <<< 1 ||| 1#w = (x <<< 1) + 1#w := by + rw [BitVec.add_eq_or_of_and_eq_zero] + ext i + simp + intro i _ hi' + omega + +theorem BitVec.add_sub_self_left {x y : BitVec w} : x + y - x = y := by + apply eq_of_toNat_eq + simp + calc + (x.toNat + y.toNat + (2 ^ w - x.toNat)) % 2 ^ w = (x.toNat + y.toNat + 2 ^ w - x.toNat) % 2 ^ w := by + rw [Nat.add_sub_assoc (Nat.le_of_lt x.isLt)] + _ = (x.toNat + y.toNat - x.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.sub_add_comm]; omega + _ = (y.toNat + 2 ^ w) % 2 ^ w := by rw [Nat.add_sub_self_left] + _ = y.toNat % 2 ^ w := by simp + _ = y.toNat := by simp [Nat.mod_eq_of_lt] + +theorem BitVec.add_sub_self_right {x y : BitVec w} : x + y - y = x := by + rw [BitVec.add_comm] + rw [BitVec.add_sub_self_left] + +@[simp] +theorem BitVec.le_of_not_lt {x y : BitVec w} : ¬ x < y → y ≤ x := by + simp [BitVec.lt_def, BitVec.le_def] + +/-- +if the MSB is false, then the arithmetic value of shifting +is the same as the original value times 2. +That is, if the msb is false, then shifting by 1 does not overflow. +Can be generalized to talk about shifting by `k` if the top `k` bits are false. +-/ +theorem BitVec.toNat_shiftLeft_one_eq_mul_two_of_msb_false + (x : BitVec w) + (h : x.msb = false) : + (x <<< 1).toNat = x.toNat * 2 := by + simp only [toNat_shiftLeft] + have h := (BitVec.msb_eq_false_iff_two_mul_lt x).mp h + rw [Nat.shiftLeft_eq, Nat.mod_eq_of_lt (by omega)] + +/- upon shifting left by one, if times 2 is less than 2^w, then we cannot overflow. -/ +theorem BitVec.toNat_shiftLeft_one_eq_mul_two_of_lt + (x : BitVec w) + (hlt : x.toNat * 2 < 2 ^ w) : + (x <<< 1).toNat = x.toNat * 2 := by + simp only [toNat_shiftLeft] + rw [Nat.shiftLeft_eq, Nat.mod_eq_of_lt (by omega)] + +/-- +The arithmetic version of: +If `n : Bitvec w` has only the low `k < w` bits set, +then `(n <<< 1 | b)` does not overflow. +-/ +theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two + (hn : n < 2 ^ k) (hb : b < 2) (hk : k < w) : + n * 2 + b < 2 ^ w := by + have : k + 1 ≤ w := by omega + have : 2^(k + 1) ≤ 2 ^w := by + apply Nat.pow_le_pow_of_le_right (by decide) (by assumption) + have : n ≤ 2 ^k - 1 := by omega + have : n * 2 ≤ 2^k * 2 - 2 := by omega + have : n * 2 + b ≤ 2^k * 2 - 1 := by omega + have : n * 2 + b ≤ 2 ^(k + 1) - 1 := by omega + have : n * 2 + b ≤ 2 ^w - 1 := by omega + have : n * 2 + b < 2^w := by omega + assumption + +/-- +This is used when proving the correctness of the divison algorithm, +where we know that `r < d`. +We then want to show that `r <<< 1 | b - d < d` as the loop invariant. +In arithmethic, this is the same as showing that +`r * 2 + 1 - d < d`, which this theorem establishes. +-/ +theorem two_mul_add_sub_lt_of_lt_of_lt_two -- HERE HERE + (h : a < x) (hy : y < 2): + 2 * a + y - x < x := by omega + +/-- +Variant of `BitVec.toNat_sub` that does not introduce a modulo. +-/ +theorem BitVec.toNat_sub_of_lt {x y : BitVec w} (hy : y ≤ x) : + (x - y).toNat = x.toNat - y.toNat := by + simp only [toNat_sub] + rw [← Nat.add_sub_assoc] + · rw [Nat.sub_add_comm] + · rw [Nat.add_mod] + simp only [mod_self, Nat.add_zero, mod_mod] + rw [Nat.mod_eq_of_lt] + omega + · simp only [le_def] at hy + omega + · omega + +/-- +If `n : Bitvec w` has only the low `k < w` bits set, +then `(n <<< 1 | b)` does not overflow, and we can compute its value +as a multiply and add. +-/ +theorem toNat_shiftLeft_or_zeroExtend_ofBool_eq (w : Nat) + (r : BitVec w) + (b : Bool) + (hk : k < w) + (hr : r.toNat < 2 ^ k) : + (r <<< 1 ||| zeroExtend w (ofBool b)).toNat = + (r.toNat * 2 + b.toNat) := by + have : b.toNat = if b then 1 else 0 := by rcases b <;> rfl + rw [this] + have hk' : 2^k < 2^w := by + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + rcases w with rfl | w + · omega -- contradiction, k < w + · rw [← BitVec.add_eq_or_of_and_eq_zero] + · simp only [toNat_add, toNat_shiftLeft, toNat_truncate, toNat_ofBool, toNat, add_mod_mod, + mod_add_mod] + rw [Nat.shiftLeft_eq] + simp only [show (2 ^ 1 = 2) by decide] + rw [Nat.mod_eq_of_lt] + · rcases b with rfl | rfl <;> simp + · apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two + · exact hr + · rcases b <;> decide + · assumption + · ext i + simp only [getLsb_and, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool, getLsb_zero, and_eq_false_imp, and_eq_true, not_eq_true', + decide_eq_false_iff_not, Nat.not_lt, decide_eq_true_eq, and_imp] + intros hi _ hi' + omega + +/- # DivRem, V3 -/ +structure DivRemInput (w wr wn : Nat) + (n : BitVec w) + (d : BitVec w) : Type where + q : BitVec w + r : BitVec w + hwr : wr ≤ w + hwn : wn ≤ w + hwrn : wr + wn = w + hd : 0 < d + hrd : r.toNat < d.toNat + hrwr : r.toNat < 2^wr + hqwr : q.toNat < 2^wr + hdiv : n.toNat >>> wn = d.toNat * q.toNat + r.toNat + +/-- In a valid DivRemInput, it is implied that `w > 0`. -/ +def DivRemInput.hw (h : DivRemInput w wr wn n d) : 0 < w := by + have hd := h.hd + rcases w with rfl | w + · have hcontra : d = 0#0 := by apply Subsingleton.elim + rw [hcontra] at hd + simp at hd + · omega + +/-- +Make an initial state of the DivRemInput, for a given choice of +`n, d, q, r`. -/ +def DivRemInput_init (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + DivRemInput w 0 w n d:= { + q := 0#w + r := 0#w + hwr := by omega, + hwn := by omega, + hwrn := by omega, + hd := by assumption + hrd := by simp [BitVec.lt_def] at hd ⊢; assumption + hrwr := by simp, + hqwr := by simp, + hdiv := by + simp; + rw [Nat.shiftRight_eq_div_pow] + apply Nat.div_eq_of_lt n.isLt +} + +@[simp] +theorem DivRemInput_init_q (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + (DivRemInput_init w n d hw hd).q = 0#w := by + rfl + +@[simp] +theorem DivRemInput_init_r (w : Nat) (n d : BitVec w) (hw : 0 < w) (hd : 0#w < d) : + (DivRemInput_init w n d hw hd).r = 0#w := by + rfl + +theorem DivRemInput_implies_udiv_urem + (h : DivRemInput w w 0 n d) : + n.udiv d = h.q ∧ n.umod d = h.r := by + apply div_characterized_of_mul_add_toNat + (n := n) (d := d) (q := h.q) (r := h.r) + (h.hd) + (h.hrd) + (by + have hdiv := h.hdiv + simp at hdiv + omega + ) + +structure ShiftSubtractInput (w wr wn : Nat) (n d: BitVec w) + extends DivRemInput w wr wn n d : Type where + hwn_lt : 0 < wn -- we can only call this function legally if we have dividend bits. + + +/-- +In the shift subtract input, we have one more bit to spare, +so we do not overflow. +-/ +def ShiftSubtractInput.wr_add_one_le_w + (h : ShiftSubtractInput w wr wn n d) : wr + 1 ≤ w := by + have hwrn := h.hwrn + have hwn_lt := h.hwn_lt + omega + +def ShiftSubtractInput.wr_lt_w + (h : ShiftSubtractInput w wr wn n d) : wr < w := by + have hwr := h.wr_add_one_le_w + omega + +/-- +In the shift subtract input, we have one more bit to spare, +so we do not overflow. +-/ +def ShiftSubtractInput.wr_le_wr_sub_one + (h : ShiftSubtractInput w wr wn n d) : wr ≤ w - 1 := by + have hw := h.hw + have hwrn := h.hwrn + have hwn_lt := h.hwn_lt + omega + +/-- If we have extra bits to spare in `n`, +then the div rem input can be converted into a shift subtract input +to run a round of the shift subtracter. -/ +def DivRemInput.toShiftSubtractInput + (h : DivRemInput w wr (wn + 1) n d) : + ShiftSubtractInput w wr (wn + 1) n d := { + q := h.q, + r := h.r + hwr := h.hwr, + hwn := h.hwn, + hwrn := by have := h.hwrn; omega, + hd := h.hd, + hrd := h.hrd, + hrwr := h.hrwr, + hqwr := h.hqwr, + hdiv := h.hdiv, + hwn_lt := by omega + } + +def ShiftSubtractInput.nmsb (_ : ShiftSubtractInput w wr wn n d) : + Bool := n.getLsb (wn - 1) + +def DivRemInput.wr_eq_w_of_wn_eq_zero + (h : DivRemInput w wr 0 n d) : DivRemInput w w 0 n d := + { + q := h.q, + r := h.r, + hwr := by have := h.hwr; omega, + hwn := h.hwn, + hwrn := by have := h.hwrn; omega, + hd := h.hd, + hrd := h.hrd, + hrwr := by have := h.hrwr; omega, + hqwr := by have := h.hqwr; omega, + hdiv := h.hdiv + } + + +/- # Division Recurrence for Bitblasting (V2 )-/ + +def concatBit' (x : BitVec w) (b : Bool) : BitVec w := + x <<< 1 ||| (BitVec.ofBool b).zeroExtend w + +theorem concatBit'_lt (x : BitVec w) (b : Bool) : + (concatBit' x b).toNat < 2 ^ w := (concatBit' x b).isLt + +theorem toNat_concatBit'_eq (x : BitVec w) (b : Bool) (k : Nat) + (hk : k < w) (hx : x.toNat < 2 ^ k) : + (concatBit' x b).toNat = x.toNat * 2 + b.toNat:= by + simp only [concatBit'] + rw [toNat_shiftLeft_or_zeroExtend_ofBool_eq (k := k)] + · omega + · omega + +theorem toNat_concatBit'_false_eq (x : BitVec w) (k : Nat) + (hk : k < w) (hx : x.toNat < 2 ^ k) : + (concatBit' x false).toNat = x.toNat * 2 := by + rw [toNat_concatBit'_eq (k := k) (hk := hk) (hx := hx)] + simp + +theorem toNat_concatBit'_lt (x : BitVec w) (b : Bool) (k : Nat) + (hk : k < w) (hx : x.toNat < 2 ^ k) : + (concatBit' x b).toNat < 2 ^ (k + 1) := by + rw [toNat_concatBit'_eq x b k hk hx] + apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hx + · rcases b with rfl | rfl <;> decide + · omega + +private theorem BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb + {x : BitVec w} {k : Nat} (hk' : 0 < k) : + x >>> (k - 1) = ((x >>> k <<< 1) ||| ((BitVec.ofBool (x.getLsb (k - 1))).zeroExtend w)) := by + ext i + simp only [getLsb_ushiftRight, getLsb_or, getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsb_zeroExtend, getLsb_ofBool] + by_cases (i : Nat) < 1 + case pos h => + simp only [h, decide_True, Bool.not_true, Bool.false_and] + have hi : (i : Nat) = 0 := by omega + simp [hi] + case neg h => + simp only [h, decide_False, Bool.not_false, Bool.true_and] + have hi : (i : Nat) ≠ 0 := by omega + simp only [hi, decide_False, Bool.false_and, Bool.or_false] + congr 1 + omega + +theorem ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb + (h : ShiftSubtractInput w wr wn n d) : + n >>> (wn - 1) = (n >>> wn).concatBit' (ShiftSubtractInput.nmsb h) := by + rw [concatBit'] + rw [ShiftSubtractInput.nmsb] + rw [BitVec.shiftLeft_sub_eq_shiftLeft_shiftRight_or_zeroExtend_getLsb] + have hwn_lt := h.hwn_lt + omega + +/-- +Shifting right by `n < w` yields a bitvector whose value +is less than `2^(w - n)` +-/ +theorem BitVec.ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) : + (x >>> n).toNat < 2 ^ (w - n) := by + rw [toNat_ushiftRight] + rw [shiftRight_eq_div_pow] + rw [Nat.div_lt_iff_lt_mul] + · rw [Nat.pow_sub_mul_pow] + · apply x.isLt + · apply hn + · apply Nat.pow_pos (by decide) + +/-- The value of shifting by `wn - 1` equals +shifting by `wn` and grabbing the lsb at (wn - 1) -/ +theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb + (h : ShiftSubtractInput w wr wn n d) : + n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + h.nmsb.toNat := by + have hn := ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb h + obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).concatBit' h.nmsb).toNat := by + simp [hn] + simp at hn + rw [toNat_concatBit'_eq (k := w - wn)] at hn + · rw [hn] + rw [toNat_ushiftRight] + · have := h.hwn_lt + have := h.hw + omega + · apply BitVec.ushiftRight_lt + have := h.hwrn + omega + +/-- +One round of the division algorithm, that tries to perform a subtract shift. +Note that this is only called when `r.msb = false`, so we will not overflow. +This means that `r'.toNat = r.toNat *2 + q.toNat` +-/ +def divSubtractShift (h : ShiftSubtractInput w wr wn n d) : + DivRemInput w (wr + 1) (wn - 1) n d := + let r' := concatBit' h.r h.nmsb + let rltd : Bool := r' < d -- true if r' < d. In this case, we don't have a quotient bit. + let q := h.q.concatBit' !rltd -- if r ≥ d, then we have a quotient bit. + if hrltd : rltd + then { + q := q, + r := r', + hwr := by + have := h.hwr + have := h.wr_add_one_le_w + omega, + hwn := by + have := h.hwn + omega, + hwrn := by + have := h.hwrn + have := h.wr_add_one_le_w + omega, + hd := h.hd, + hrd := by + simp [rltd] at hrltd + simp [BitVec.lt_def] at hrltd + assumption, + hrwr := by + simp [r'] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hrwr, + hqwr := by + simp [q] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hqwr, + hdiv := by + rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] + simp only [r'] + rw [h.hdiv] + rw [toNat_concatBit'_eq (x := h.r) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] + simp only [q] + simp only [hrltd, Bool.not_true] + have hq' := toNat_concatBit'_false_eq h.q wr h.wr_lt_w h.hqwr + rw [hq'] + rw [← Nat.mul_assoc] + rw [Nat.add_mul] + rw [Nat.add_assoc] + } + else { + q := q, + r := r' - d, + hwr := by + have := h.hwr + have := h.wr_add_one_le_w + omega, + hwn := by + have := h.hwn + omega, + hwrn := by + have := h.hwrn + have := h.wr_add_one_le_w + omega, + hd := h.hd, + hrd := by + simp [rltd] at hrltd + simp [BitVec.lt_def] at hrltd + have hr := h.hrd + -- | TODO: make this a field. + have hr' : h.r < d := by simp [BitVec.lt_def]; exact hr + rw [BitVec.toNat_sub_of_lt hrltd] + simp only [r'] + rw [toNat_concatBit'_eq (x := h.r) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] + rw [Nat.mul_comm] -- TODO: canonicalize an order between w*2 and 2*w + apply two_mul_add_sub_lt_of_lt_of_lt_two + · exact hr + · apply Bool.toNat_lt + hrwr := by + simp only [r'] + /- TODO: this proof is repeated, lift it to above the structure building. -/ + have hdr' : ¬ (r' < d) := by + simp [rltd] at hrltd + assumption + have hdr' : d ≤ r' := BitVec.le_iff_not_lt.mp hdr' + rw [BitVec.toNat_sub_of_lt hdr'] + have hr' : r'.toNat < 2 ^ (wr + 1) := by + simp [r'] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hrwr + omega + hqwr := by + simp [q] + apply toNat_concatBit'_lt + · exact h.wr_add_one_le_w + · exact h.hqwr, + hdiv := by + rw [h.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb] + have hdr' : ¬ (r' < d) := by + simp [rltd] at hrltd + assumption + have hdr' : d ≤ r' := BitVec.le_iff_not_lt.mp hdr' + rw [BitVec.toNat_sub_of_lt hdr'] + simp only [r'] + rw [h.hdiv] + rw [toNat_concatBit'_eq (x := h.r) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] + simp only [q] + rw [toNat_concatBit'_eq (x := h.q) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hqwr)] + simp only [hrltd, Bool.not_false, toNat_true] + simp [Nat.mul_add] + apply Eq.symm + calc + _ = d.toNat * (h.q.toNat * 2) + d.toNat + (h.r.toNat * 2 + h.nmsb.toNat - d.toNat) := + by rfl + _ = d.toNat * (h.q.toNat * 2) + d.toNat - d.toNat + (h.r.toNat * 2 + h.nmsb.toNat) := by + simp + rw [Nat.add_assoc] + congr 1 + rw [Nat.add_sub_cancel'] + simp only [r'] at hdr' + simp only [BitVec.le_def] at hdr' + rw [BitVec.toNat_concatBit'_eq + (x := h.r) + (b := h.nmsb) + (k := wr) + (hk := h.wr_lt_w) + (hx := h.hrwr)] at hdr' + assumption + _ = d.toNat * (h.q.toNat * 2) + (h.r.toNat * 2 + h.nmsb.toNat) := by + rw [Nat.add_sub_cancel] + _ = (d.toNat * h.q.toNat + h.r.toNat) * 2 + h.nmsb.toNat := by + rw [← Nat.add_assoc] + rw [← Nat.mul_assoc] + rw [Nat.add_mul] + _ = (d.toNat * h.q.toNat + h.r.toNat) * 2 + h.nmsb.toNat := rfl + } + +/-- info: 'BitVec.divSubtractShift' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms divSubtractShift + +/-- +Core divsion recurrence. +We have three widths at play: +- w, the total bitwidth +- wr, the effective bitwidth of the reminder +- wn, the effective bitwidth of the dividend. +- We have the invariant that wn + wr = w. + +See that when it is called, we will know that : + - r < [2^wr = 2^(w - wn)] + which allows us to safely shift left, since it is of length n. + In particular, since 'wn' decreases in the course of the recursion, + will will allow larger and larger values, and at the step where 'wn = 0', + we will have `r < 2^w`, which is no longer sufficient to allow for a shift left. + Thus, at this step, we will stop and return a full remainder. + So, the remainder is morally of length `w - wn`. + - d > 0 + - r < d + - n.toNat >>> wr = +-/ +def divRec' (h : DivRemInput w wr wn n d) : + DivRemInput w w 0 n d := + match wn with + | 0 => h.wr_eq_w_of_wn_eq_zero + | _ + 1 => + let new := divSubtractShift h.toShiftSubtractInput + divRec' new + +/-- info: 'BitVec.divRec'' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms divRec' + +theorem divRec'_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : + let out := divRec' (DivRemInput_init w n d hw hd) + n.udiv d = out.q ∧ n.umod d = out.r := by + simp + apply DivRemInput_implies_udiv_urem + +def divSubtractShiftNonDep (n q r d : BitVec w) (wn : Nat) : BitVec w × BitVec w := + let r' := concatBit' r (n.getLsb (wn - 1)) + let rltd : Bool := r' < d + let q := q.concatBit' !rltd + if rltd + then (q, r') + else (q, r' - d) + +@[simp] +theorem DivRemInput.toShiftSubtractInput_r_eq_r + (h : DivRemInput w wr (wn + 1) n d) : + (h.toShiftSubtractInput).r = h.r := by + simp [toShiftSubtractInput] + +@[simp] +theorem DivRemInput.toShiftSubtractInput_q_eq_q + (h : DivRemInput w wr (wn + 1) n d) : + (h.toShiftSubtractInput).q = h.q := by + simp only [toShiftSubtractInput] + +theorem divSubtractShift_eq_divSubtractShiftNonDep + (h : ShiftSubtractInput w wr wn n d) : + ((divSubtractShift h).q, (divSubtractShift h).r) = divSubtractShiftNonDep n h.q h.r d wn := by + simp [divSubtractShift, divSubtractShiftNonDep, ShiftSubtractInput.nmsb] + by_cases h : h.r.concatBit' (n.getLsb (wn - 1)) < d <;> + simp only [h, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +@[simp] +theorem q_divSubtractShift_eq_fst_divSubtractShiftNonDep' + (h : DivRemInput w wr (wn + 1) n d) : + (divSubtractShift h.toShiftSubtractInput).q = + (divSubtractShiftNonDep n h.q h.r d (wn + 1)).fst := by + simp [divSubtractShift, + divSubtractShiftNonDep, + ShiftSubtractInput.nmsb] + by_cases cond : h.r.concatBit' (n.getLsb wn) < d <;> + simp only [cond, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +@[simp] +theorem r_divSubtractShift_eq_snd_divSubtractShiftNonDep' + (h : DivRemInput w wr (wn + 1) n d) : + (divSubtractShift h.toShiftSubtractInput).r = + (divSubtractShiftNonDep n h.q h.r d (wn + 1)).snd := by + simp [divSubtractShift, + divSubtractShiftNonDep, + ShiftSubtractInput.nmsb] + by_cases cond : h.r.concatBit' (n.getLsb wn) < d <;> + simp only [cond, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +theorem divSubtractShift_eq_divSubtractShiftNonDep' + (h : DivRemInput w wr (wn + 1) n d) : + ((divSubtractShift h.toShiftSubtractInput).q, (divSubtractShift h.toShiftSubtractInput).r) = + divSubtractShiftNonDep n h.q h.r d (wn + 1) := by + simp [divSubtractShift, divSubtractShiftNonDep, ShiftSubtractInput.nmsb] + by_cases h : h.r.concatBit' (n.getLsb wn) < d <;> + simp only [h, ↓reduceDite, decide_True, Bool.not_true, ↓reduceIte] + +def divRecNondep (n q r d : BitVec w) (wn : Nat) : + BitVec w × BitVec w := + match wn with + | 0 => (q, r) + | wn + 1 => + let (q', r') := divSubtractShiftNonDep n q r d (wn + 1) + divRecNondep n q' r' d wn + +theorem divRec_eq_divRecNonDep (h h' : DivRemInput w wr wn n d) + (hh' : h.q = h'.q ∧ h.r = h'.r): + ((divRec' h).q, (divRec' h).r) = divRecNondep n h'.q h'.r d wn := by + induction wn generalizing w wr n d + case zero => + simp [divRec', divRecNondep, DivRemInput.wr_eq_w_of_wn_eq_zero] + simp [hh'.1, hh'.2] + case succ wn ih => + simp [divRecNondep, divRec'] + rw[← divSubtractShift_eq_divSubtractShiftNonDep'] + apply ih <;> + simp [q_divSubtractShift_eq_fst_divSubtractShiftNonDep', + r_divSubtractShift_eq_snd_divSubtractShiftNonDep', + hh'.1, hh'.2] + +-- def concatBit' (x : BitVec w) (b : Bool) : BitVec w := +-- x <<< 1 ||| (BitVec.ofBool b).zeroExtend w + +theorem divSubtractShiftNonDep_fst (n q r d : BitVec w) (wn : Nat) : + (divSubtractShiftNonDep n q r d wn).fst = + q.concatBit' !decide (r.concatBit' (n.getLsb (wn - 1)) < d) := by + simp [divSubtractShiftNonDep] + by_cases h : r.concatBit' (n.getLsb (wn - 1)) < d <;> + simp [h] + +theorem divSubtractShiftNonDep_snd (n q r d : BitVec w) (wn : Nat) : + (divSubtractShiftNonDep n q r d wn).snd = + if r.concatBit' (n.getLsb (wn - 1)) < d then r.concatBit' (n.getLsb (wn - 1)) + else r.concatBit' (n.getLsb (wn - 1)) - d := by + simp [divSubtractShiftNonDep] + by_cases h : r.concatBit' (n.getLsb (wn - 1)) < d <;> simp [h] + +theorem divRecNonDep_zero (n q r d : BitVec w) : divRecNondep n q r d 0 = (q, r) := by simp [divRecNondep] + +theorem divRecNonDep_succ (n q r d : BitVec w) (wn : Nat) : + (divRecNondep n q r d (wn + 1) = + divRecNondep n (divSubtractShiftNonDep n q r d (wn + 1)).1 + (divSubtractShiftNonDep n q r d (wn + 1)).2 d wn) := by + simp [divRecNondep, divSubtractShiftNonDep] + +theorem divRecNonDep_correct (n d : BitVec w) (hw : 0 < w) (hd : 0 < d) : + let out := divRecNondep n 0#w 0#w d w + n.udiv d = out.fst ∧ n.umod d = out.snd := by + simp + have heq := divRec_eq_divRecNonDep (DivRemInput_init w n d hw hd) (DivRemInput_init w n d hw hd) + (by simp) + simp at heq + have hcorrect := divRec'_correct n d hw hd + obtain ⟨hqcorrect, hrcorrect⟩ := hcorrect + rw [hqcorrect, hrcorrect] + have heq_q : (divRec' (DivRemInput_init w n d hw hd)).q = + (n.divRecNondep (0#w) (0#w) d w).fst := by + rw [← heq] + have heq_r : (divRec' (DivRemInput_init w n d hw hd)).r = + (n.divRecNondep (0#w) (0#w) d w).snd := by + rw [← heq] + simp [heq_q, heq_r] +/-- +info: 'BitVec.divRecNonDep_correct' depends on axioms: [propext, Classical.choice, Quot.sound] +-/ +#guard_msgs in #print axioms divRecNonDep_correct + end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 4b205933df73..48afd622d602 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -163,6 +163,16 @@ theorem toNat_zero (n : Nat) : (0#n).toNat = 0 := by trivial private theorem lt_two_pow_of_le {x m n : Nat} (lt : x < 2 ^ m) (le : m ≤ n) : x < 2 ^ n := Nat.lt_of_lt_of_le lt (Nat.pow_le_pow_of_le_right (by trivial : 0 < 2) le) +@[simp] +theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by + rcases b with rfl | rfl + · simp [ofBool] + · simp [ofBool, getLsb_ofNat] + by_cases hi : (i = 0) + · simp [hi] + · simp [hi] + omega + /-! ### msb -/ @[simp] theorem msb_zero : (0#w).msb = false := by simp [BitVec.msb, getMsb] @@ -408,6 +418,28 @@ theorem msb_zeroExtend (x : BitVec w) : (x.zeroExtend v).msb = (decide (0 < v) & theorem msb_zeroExtend' (x : BitVec w) (h : w ≤ v) : (x.zeroExtend' h).msb = (decide (0 < v) && x.getLsb (v - 1)) := by rw [zeroExtend'_eq, msb_zeroExtend] +/-- zero extending a bitvector to width 1 equals the boolean of the lsb. -/ +theorem zeroExtend_one_eq_ofBool_getLsb_zero (x : BitVec w) : + x.zeroExtend 1 = BitVec.ofBool (x.getLsb 0) := by + ext i + simp [getLsb_zeroExtend, Fin.fin_one_eq_zero i] + +/-- `testBit 1 i` is true iff the index `i` equals 0. -/ +private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : + Nat.testBit 1 i = true ↔ i = 0 := by + cases i <;> simp + +/-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ +theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): + (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by + ext i + obtain ⟨i, hilt⟩ := i + simp only [getLsb_zeroExtend, hilt, decide_True, getLsb_ofNat, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq] + intros hi1 + have hv := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi1 + omega + /-! ## extractLsb -/ @[simp] @@ -593,6 +625,11 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl @[simp] theorem toFin_shiftLeft {n : Nat} (x : BitVec w) : BitVec.toFin (x <<< n) = Fin.ofNat' (x.toNat <<< n) (Nat.two_pow_pos w) := rfl +@[simp] +theorem shiftLeft_zero_eq (x : BitVec w) : x <<< 0 = x := by + apply eq_of_toNat_eq + simp + @[simp] theorem getLsb_shiftLeft (x : BitVec m) (n) : getLsb (x <<< n) i = (decide (i < m) && !decide (i < n) && getLsb x (i - n)) := by rw [← testBit_toNat, getLsb] @@ -726,6 +763,49 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : Nat.not_lt, decide_eq_true_eq] omega +/-- The arithmetic shift right equals the msb when `s + i ≥ w`, and equals the logical shift right when `s + i < w. -/ +theorem getLsb_sshiftRight_eq_getLsb_ushiftRight (x : BitVec w) (s i : Nat) : + getLsb (x.sshiftRight s) i = (!decide (w ≤ i) && if s + i < w then (x >>> s).getLsb i else x.msb) := by + have h : (x >>> s).getLsb i = x.getLsb (s + i) := by + simp only [getLsb_ushiftRight] + rw [h] + simp [getLsb_sshiftRight] + +theorem getLsb_sshift'_eq_getLsb_sshiftRight : + getLsb (sshiftRight' x y) i = getLsb (x.sshiftRight y.toNat) i := by + simp [sshiftRight'] + +/-! ### udiv -/ + +theorem udiv_eq {x y : BitVec n} : + x.udiv y = BitVec.ofNat n (x.toNat / y.toNat) := by + apply BitVec.eq_of_toNat_eq + simp only [udiv, toNat_ofNatLt, toNat_ofNat] + rw [Nat.mod_eq_of_lt] + exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega) + +theorem toNat_udiv {x y : BitVec n} (hy : 0 < y): + (x.udiv y).toNat = x.toNat / y.toNat := by + rw [udiv_eq] + simp only [toNat_ofNat] + rw [Nat.mod_eq_of_lt] + rw [Nat.div_lt_iff_lt_mul hy] + apply Nat.lt_of_lt_of_le x.isLt + apply Nat.le_mul_of_pos_right _ hy + +/-! ### umod -/ + +theorem umod_eq {x y : BitVec n} : + x.umod y = BitVec.ofNat n (x.toNat % y.toNat) := by + apply BitVec.eq_of_toNat_eq + simp only [umod, toNat_ofNatLt, toNat_ofNat] + rw [Nat.mod_eq_of_lt (b := 2^n)] + apply Nat.lt_of_le_of_lt (Nat.mod_le _ _) x.isLt + +@[simp] +theorem toNat_umod {x y : BitVec n} : + (x.umod y).toNat = x.toNat % y.toNat := by rfl + /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) : @@ -1085,6 +1165,23 @@ theorem ofInt_mul {n} (x y : Int) : BitVec.ofInt n (x * y) = apply eq_of_toInt_eq simp +@[simp] +theorem BitVec.mul_one {x : BitVec w} : x * 1#w = x := by + apply eq_of_toNat_eq + simp [toNat_mul, Nat.mod_eq_of_lt x.isLt] + +@[simp] +theorem BitVec.mul_zero {x : BitVec w} : x * 0#w = 0#w := by + apply eq_of_toNat_eq + simp [toNat_mul] + +theorem BitVec.mul_add {x y z : BitVec w} : + x * (y + z) = x * y + x * z := by + apply eq_of_toNat_eq + simp + rw [Nat.mul_mod, Nat.mod_mod (y.toNat + z.toNat), + ← Nat.mul_mod, Nat.mul_add] + /-! ### le and lt -/ @[bv_toNat] theorem le_def (x y : BitVec n) : @@ -1311,4 +1408,69 @@ theorem getLsb_rotateRight {x : BitVec w} {r i : Nat} : · simp · rw [← rotateRight_mod_eq_rotateRight, getLsb_rotateRight_of_le (Nat.mod_lt _ (by omega))] +/- ## twoPow -/ + +@[simp] +theorem toNat_twoPow (w : Nat) (i : Nat) : (twoPow w i).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [twoPow, toNat_shiftLeft] + have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega) + rw [Nat.mod_eq_of_lt h1] + rw [Nat.shiftLeft_eq, Nat.one_mul] + +@[simp] +theorem getLsb_twoPow (i j : Nat) : (twoPow w i).getLsb j = ((i < w) && (i = j)) := by + rcases w with rfl | w + · simp only [twoPow, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · simp only [twoPow, getLsb_shiftLeft, getLsb_ofNat] + by_cases hj : j < i + · simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq, + Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not] + omega + · by_cases hi : Nat.testBit 1 (j - i) + · obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi + have hij : j = i := by omega + simp_all + · have hij : i ≠ j := by + intro h; subst h + simp at hi + simp_all + +@[simp] +theorem and_twoPow_eq_getLsb (x : BitVec w) (i : Nat) : + x &&& (twoPow w i) = if x.getLsb i then twoPow w i else 0#w := by + ext j + simp only [getLsb_and, getLsb_twoPow] + by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all + +@[simp] +theorem mul_twoPow_eq_shiftLeft (x : BitVec w) (i : Nat) : + x * (twoPow w i) = x <<< i := by + apply eq_of_toNat_eq + simp only [toNat_mul, toNat_twoPow, toNat_shiftLeft, Nat.shiftLeft_eq] + by_cases hi : i < w + · have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega) + rw [Nat.mod_eq_of_lt hpow] + · have hpow : 2 ^ i % 2 ^ w = 0 := by + rw [Nat.mod_eq_zero_of_dvd] + apply Nat.pow_dvd_pow 2 (by omega) + simp [Nat.mul_mod, hpow] + +theorem BitVec.toNat_twoPow (w : Nat) (i : Nat) : (twoPow w i).toNat = 2^i % 2^w := by + rcases w with rfl | w + · simp [Nat.mod_one] + · simp [twoPow, toNat_shiftLeft] + have hone : 1 < 2 ^ (w + 1) := by + rw [show 1 = 2^0 by simp[Nat.pow_zero]] + exact Nat.pow_lt_pow_of_lt (by omega) (by omega) + simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq] + +@[simp] +theorem twoPow_zero_eq_one (w : Nat) : twoPow w 0 = 1#w := by + apply eq_of_toNat_eq + simp + end BitVec diff --git a/src/Init/Data/BitVec/div_invariant.py b/src/Init/Data/BitVec/div_invariant.py new file mode 100644 index 000000000000..999eb3f56ac3 --- /dev/null +++ b/src/Init/Data/BitVec/div_invariant.py @@ -0,0 +1,60 @@ +def get_lsb(n, j): + return int(bool(n & (1 << j))) + +def print_bits(w, n): + return ("{0:0%sb}" % (w)).format(n) + +def check_pre_invariant(w, n, d, q, r, j): + qright = n // d + rright = n % d + assert r < d + +# n / d <-> n = q * d + r +def check_post_invariant(w, n, d, q, r, j): + qright = n // d + rright = n % d + assert r < d + nhigh = n >> j + print(" n >> j = %s | q(%s) * d(%s) + r(%s) = (%s)" % + (print_bits(w, nhigh), print_bits(w, q), print_bits(w, d), print_bits(w, r), print_bits(w, q * d + r))) + assert nhigh == d * q + r + +def shift_subtract(w, n, d, q, r, j): + print(f"shift_subtract> n: '%s' | d: '%s' | q : '%s' | r : '%s' | j : '%s'" % + (print_bits(w, n), print_bits(w, d), print_bits(w, q), print_bits(w, r), j)) + print(f" n[%s] = %s" % (j, get_lsb(n, j))) + check_pre_invariant(w, n, d, q, r, j) + + r = (r << 1) | get_lsb(n, j) + print(f" r = %s" % print_bits(w, r)) + if r >= d: + print(f" r > d.") + r -= d + q = (q << 1) | 1 + print(f" r.new = %s" % print_bits(w, r)) + print(f" q.new = %s" % print_bits(w, q)) + else: + print(f" r < d.") + q = (q << 1) + print(f" r.new = %s" % print_bits(w, r)) + print(f" q = %s" % print_bits(w, q)) + check_post_invariant(w, n, d, q, r, j) + if j == 0: + (qout, rout) = (q, r) + else: + (qout, rout) = shift_subtract(w, n, d, q, r, j-1) + return (qout, rout) + +# 10 / 3 = 3 +for n in range(1, 32): + for d in range(1, 32): + w = 6 + (q, r) = shift_subtract(w, n, d, 0, 0, w-1) + assert n == d * q + r + if n == d * q + r and r < d: + print ("verified correct invariant for n: '%s' | d : '%s' | q : '%s' r: '%s'" % + (n, d, q, r)) + else: + raise RuntimeError("verification failed for n: '%s' | d: '%s'" % (n, d)) + +