Skip to content

Commit

Permalink
Add arbitrary bitvectors using Z
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeom committed Mar 2, 2025
1 parent 7f1417e commit 57fda39
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 0 deletions.
166 changes: 166 additions & 0 deletions src/smtml/bitvector.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
type t =
{ value : Z.t
; width : int
}

let mask width = Z.pred (Z.shift_left Z.one width)

let make v m =
let masked_value = Z.logand v (mask m) in
{ value = masked_value; width = m }

let view { value; _ } = value

let numbits { width; _ } = width

let equal a b = Z.equal a.value b.value && a.width = b.width

let compare a b = Z.compare a.value b.value

let msb bv = Z.testbit bv.value (bv.width - 1)

let to_signed bv =
let msb = msb bv in
if msb then Z.sub bv.value (Z.shift_left Z.one bv.width) else bv.value

let pp fmt bv = Z.pp_print fmt bv.value

(* Unop *)
let neg bv = make (Z.neg bv.value) bv.width

let lognot a = make (Z.lognot a.value) a.width

let clz bv =
let rec count_zeros i =
if i >= bv.width || Z.testbit bv.value (bv.width - 1 - i) then i
else count_zeros (i + 1)
in
make (Z.of_int @@ count_zeros 0) bv.width

let ctz bv =
let rec count_zeros i =
if i >= bv.width || Z.testbit bv.value i then i else count_zeros (i + 1)
in
make (Z.of_int @@ count_zeros 0) bv.width

let popcnt bv = make (Z.of_int @@ Z.popcount bv.value) bv.width

(* Binop *)
let add a b =
assert (a.width = b.width);
make (Z.add a.value b.value) a.width

let sub a b =
assert (a.width = b.width);
make (Z.sub a.value b.value) a.width

let mul a b =
assert (a.width = b.width);
make (Z.mul a.value b.value) a.width

let div a b =
assert (a.width = b.width);
if Z.equal b.value Z.zero then raise Division_by_zero;
make (Z.div (to_signed a) (to_signed b)) a.width

let div_u a b =
assert (a.width = b.width);
if Z.equal b.value Z.zero then raise Division_by_zero;
make (Z.div a.value b.value) a.width

let logand a b =
assert (a.width = b.width);
make (Z.logand a.value b.value) a.width

let logor a b =
assert (a.width = b.width);
make (Z.logor a.value b.value) a.width

let logxor a b =
assert (a.width = b.width);
make (Z.logxor a.value b.value) a.width

let shl a n =
let n = Z.to_int n.value in
make (Z.shift_left a.value n) a.width

let ashr a n =
let n = Z.to_int n.value in
let signed_value = to_signed a in
make (Z.shift_right signed_value n) a.width

let lshr a n =
let n = Z.to_int n.value in
make (Z.shift_right_trunc a.value n) a.width

let rem a b =
assert (a.width = b.width);
if Z.equal b.value Z.zero then raise Division_by_zero;
make (Z.rem (to_signed a) (to_signed b)) a.width

let rem_u a b =
assert (a.width = b.width);
if Z.equal b.value Z.zero then raise Division_by_zero;
make (Z.rem a.value b.value) a.width

let rotate_left bv n =
let n = Z.to_int n.value mod bv.width in
let left_part = Z.shift_left bv.value n in
let right_part = Z.shift_right bv.value (bv.width - n) in
let rotated = Z.logor left_part right_part in
make rotated bv.width

let rotate_right bv n =
let n = Z.to_int n.value mod bv.width in
let right_part = Z.shift_right bv.value n in
let left_part = Z.shift_left bv.value (bv.width - n) in
let rotated = Z.logor left_part right_part in
make rotated bv.width

(* Relop *)
let lt_u a b = Z.lt a.value b.value

let gt_u a b = Z.gt a.value b.value

let le_u a b = Z.leq a.value b.value

let ge_u a b = Z.geq a.value b.value

let lt a b = Z.lt (to_signed a) (to_signed b)

let gt a b = Z.gt (to_signed a) (to_signed b)

let le a b = Z.leq (to_signed a) (to_signed b)

let ge a b = Z.geq (to_signed a) (to_signed b)

(* Extract and concat *)
let concat a b =
let new_width = a.width + b.width in
let shifted = Z.shift_left a.value b.width in
let combined = Z.logor shifted b.value in
make combined new_width

let extract bv ~high ~low =
assert (high <= bv.width && low >= 0 && low < high);
let width = high - low + 1 in
let shifted = Z.shift_right bv.value low in
let extracted = Z.logand shifted (mask width) in
make extracted width

(* Cvtop *)
let zero_extend width bv =
let new_width = bv.width + width in
make bv.value new_width

let sign_extend width bv =
let new_width = bv.width + width in
let msb = msb bv in
let sign_mask =
if msb then
let shift_amount = bv.width in
Z.shift_left (mask width) shift_amount
else Z.zero
in
let extended = Z.logor bv.value sign_mask in
make extended new_width
77 changes: 77 additions & 0 deletions src/smtml/bitvector.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
type t

val make : Z.t -> int -> t

val view : t -> Z.t

val numbits : t -> int

val equal : t -> t -> bool

val compare : t -> t -> int

val pp : Format.formatter -> t -> unit

val neg : t -> t

val lognot : t -> t

val clz : t -> t

val ctz : t -> t

val popcnt : t -> t

val add : t -> t -> t

val sub : t -> t -> t

val mul : t -> t -> t

val div : t -> t -> t

val div_u : t -> t -> t

val logand : t -> t -> t

val logor : t -> t -> t

val logxor : t -> t -> t

val shl : t -> t -> t

val ashr : t -> t -> t

val lshr : t -> t -> t

val rem : t -> t -> t

val rem_u : t -> t -> t

val rotate_left : t -> t -> t

val rotate_right : t -> t -> t

val lt : t -> t -> bool

val lt_u : t -> t -> bool

val gt : t -> t -> bool

val gt_u : t -> t -> bool

val le : t -> t -> bool

val le_u : t -> t -> bool

val ge : t -> t -> bool

val ge_u : t -> t -> bool

val concat : t -> t -> t

val extract : t -> high:int -> low:int -> t

val zero_extend : int -> t -> t

val sign_extend : int -> t -> t
1 change: 1 addition & 0 deletions src/smtml/dune
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
altergo_mappings
ast
;axioms
bitvector
bitwuzla_mappings
binder
cache
Expand Down
3 changes: 3 additions & 0 deletions test/bitvector/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(test
(name test_bitvector)
(libraries smtml))
85 changes: 85 additions & 0 deletions test/bitvector/test_bitvector.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
open Smtml.Bitvector

let z n = Z.of_int n (* Helper to create Z.t values *)

let test_make () =
let bv = make (z 5) 8 in
assert (view bv = z 5);
assert (numbits bv = 8)

let test_neg () =
let bv = make (z 5) 8 in
assert (equal (neg bv) (make (z (-5)) 8))

let test_add () =
let bv1 = make (z 3) 8 in
let bv2 = make (z 5) 8 in
assert (view (add bv1 bv2) = z 8)

let test_sub () =
let bv1 = make (z 10) 8 in
let bv2 = make (z 3) 8 in
assert (view (sub bv1 bv2) = z 7)

let test_mul () =
let bv1 = make (z 4) 8 in
let bv2 = make (z 3) 8 in
assert (view (mul bv1 bv2) = z 12)

let test_div () =
let bv1 = make (z 10) 8 in
let bv2 = make (z 2) 8 in
assert (view (div bv1 bv2) = z 5)

let test_div_u () =
let bv1 = make (z 10) 8 in
let bv2 = make (z 3) 8 in
assert (view (div_u bv1 bv2) = z (10 / 3))

let test_logical_ops () =
let bv1 = make (z 0b1100) 4 in
let bv2 = make (z 0b1010) 4 in
assert (view (logand bv1 bv2) = z 0b1000);
assert (view (logor bv1 bv2) = z 0b1110);
assert (view (logxor bv1 bv2) = z 0b0110)

let test_shifts () =
let bv = make (z 0b0011) 4 in
assert (view (shl bv (make (z 1) 4)) = z 0b0110);
assert (view (lshr bv (make (z 1) 4)) = z 0b0001);
assert (view (ashr bv (make (z 1) 4)) = z 0b0001)

let test_comparisons () =
let bv1 = make (z 3) 4 in
let bv2 = make (z 5) 4 in
assert (lt bv1 bv2);
assert (le bv1 bv2);
assert (gt bv2 bv1);
assert (ge bv2 bv1);
assert (lt_u bv1 bv2);
assert (gt_u bv2 bv1)

let test_rotate () =
let bv = make (z 0b1101) 4 in
let one = make (z 1) 4 in
assert (view (rotate_left bv one) = z 0b1011);
assert (view (rotate_right bv one) = z 0b1110)

let test_extensions () =
let bv = make (z 0b1010) 4 in
assert (numbits (zero_extend 4 bv) = 8);
assert (numbits (sign_extend 4 bv) = 8)

let () =
test_make ();
test_neg ();
test_add ();
test_sub ();
test_mul ();
test_div ();
test_div_u ();
test_logical_ops ();
test_shifts ();
test_comparisons ();
test_rotate ();
test_extensions ()

0 comments on commit 57fda39

Please sign in to comment.