Skip to content

Commit

Permalink
First attempt at constant propagation for ext and slice
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Nov 3, 2024
1 parent a7a6436 commit e75400e
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,33 +361,38 @@ def __init__(self, nid, sid_line, domain, comment, line_no):
class Constant(Expression):
def __init__(self, nid, sid_line, value, comment, line_no):
super().__init__(nid, sid_line, {}, comment, line_no)
self.value = value
if not(0 <= value < 2**sid_line.size or -2**(sid_line.size - 1) <= value < 2**(sid_line.size - 1)):
self.print_value = value
self.signed_value = value
if 0 <= value < 2**sid_line.size:
self.value = value
if 2**(sid_line.size - 1) <= value:
self.signed_value = value - 2**sid_line.size
elif -2**(sid_line.size - 1) <= value < 2**(sid_line.size - 1):
assert value < 0
self.value = 2**sid_line.size + value
else:
raise model_error(f"{value} in range of {sid_line.size}-bit bitvector", line_no)

def get_mapped_array_expression_for(self, index):
return self

def get_value(self):
if isinstance(self.sid_line, Bool):
return bool(self.value)
else:
return self.value
return self

def get_z3(self):
if self.z3 is None:
if isinstance(self.sid_line, Bool):
self.z3 = z3.BoolVal(self.get_value())
self.z3 = z3.BoolVal(bool(self.value))
else:
self.z3 = z3.BitVecVal(self.get_value(), self.sid_line.size)
self.z3 = z3.BitVecVal(self.value, self.sid_line.size)
return self.z3

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
if isinstance(self.sid_line, Bool):
self.bitwuzla = tm.mk_true() if self.get_value() else tm.mk_false()
self.bitwuzla = tm.mk_true() if bool(self.value) else tm.mk_false()
else:
self.bitwuzla = tm.mk_bv_value(self.sid_line.get_bitwuzla(tm), self.get_value())
self.bitwuzla = tm.mk_bv_value(self.sid_line.get_bitwuzla(tm), self.value)
return self.bitwuzla

class Zero(Constant):
Expand Down Expand Up @@ -415,7 +420,7 @@ def __init__(self, nid, sid_line, value, comment, line_no):
super().__init__(nid, sid_line, value, comment, line_no)

def __str__(self):
return f"{self.nid} {Constd.keyword} {self.sid_line.nid} {self.value} {self.comment}"
return f"{self.nid} {Constd.keyword} {self.sid_line.nid} {self.print_value} {self.comment}"

class Const(Constant):
keyword = OP_CONST
Expand Down Expand Up @@ -551,9 +556,11 @@ def get_mapped_array_expression_for(self, index):
return super().get_mapped_array_expression_for(index)

def get_value(self):
assert self.value is not None
return self.value

def set_value(self, value):
assert self.sid_line.match_sorts(value.sid_line)
self.value = value

def get_step_name(self, step):
Expand Down Expand Up @@ -636,19 +643,28 @@ def copy(self, arg1_line):
else:
return self

def get_value(self):
arg1_value = self.arg1_line.get_value()
if isinstance(arg1_value, Constant):
return type(arg1_value)(next_nid(), self.sid_line, arg1_value.value, self.comment, self.line_no)
else:
return self.copy(arg1_value)

def get_z3(self):
if self.z3 is None:
if self.op == 'sext':
self.z3 = z3.SignExt(self.w, self.arg1_line.get_z3())
elif self.op == 'uext':
else:
assert self.op == 'uext'
self.z3 = z3.ZeroExt(self.w, self.arg1_line.get_z3())
return self.z3

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
if self.op == 'sext':
bitwuzla_op = bitwuzla.Kind.BV_SIGN_EXTEND
elif self.op == 'uext':
else:
assert self.op == 'uext'
bitwuzla_op = bitwuzla.Kind.BV_ZERO_EXTEND
self.bitwuzla = tm.mk_term(bitwuzla_op,
[self.arg1_line.get_bitwuzla(tm)], [self.w])
Expand Down Expand Up @@ -677,6 +693,14 @@ def copy(self, arg1_line):
else:
return self

def get_value(self):
arg1_value = self.arg1_line.get_value()
if isinstance(arg1_value, Constant):
return type(arg1_value)(next_nid(), self.sid_line,
(arg1_value.value & 2**(self.u + 1) - 1) >> self.l, self.comment, self.line_no)
else:
return self.copy(arg1_value)

def get_z3(self):
if self.z3 is None:
self.z3 = z3.Extract(self.u, self.l, self.arg1_line.get_z3())
Expand Down

0 comments on commit e75400e

Please sign in to comment.