Skip to content

Commit

Permalink
Preparing constant propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Oct 28, 2024
1 parent e4aceb1 commit 044eaf9
Showing 1 changed file with 108 additions and 68 deletions.
176 changes: 108 additions & 68 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,26 @@ def __init__(self, nid, sid_line, value, comment, line_no):
def get_mapped_array_expression_for(self, index):
return self

def get_value(self):
if isinstance(self.sid_line, Bool):
return bool(self.value)
else:
return self.value

def get_z3(self):
if self.z3 is None:
if isinstance(self.sid_line, Bool):
self.z3 = z3.BoolVal(bool(self.value))
self.z3 = z3.BoolVal(self.get_value())
else:
self.z3 = z3.BitVecVal(self.value, self.sid_line.size)
self.z3 = z3.BitVecVal(self.get_value(), self.sid_line.size)
return self.z3

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
if isinstance(self.sid_line, Bool):
self.bitwuzla = tm.mk_true() if bool(self.value) else tm.mk_false()
self.bitwuzla = tm.mk_true() if self.get_value() else tm.mk_false()
else:
self.bitwuzla = tm.mk_bv_value(self.sid_line.get_bitwuzla(tm), self.value)
self.bitwuzla = tm.mk_bv_value(self.sid_line.get_bitwuzla(tm), self.get_value())
return self.bitwuzla

class Zero(Constant):
Expand Down Expand Up @@ -482,6 +488,9 @@ def __init__(self, nid, sid_line, symbol, comment, line_no, index = None):
def __str__(self):
return f"{self.nid} {Input.keyword} {self.sid_line.nid} {self.symbol} {self.comment}"

def get_value(self):
return self

def get_z3_step(self, step):
return self.get_z3()

Expand Down Expand Up @@ -512,6 +521,7 @@ def __init__(self, nid, sid_line, symbol, comment, line_no, index = None):
self.name = f"state{self.nid}"
self.init_line = None
self.next_line = None
self.value = None
self.new_state(index)
# rotor-dependent program counter declaration
if comment == "; program counter":
Expand All @@ -531,6 +541,12 @@ def remove_state(self):
del State.states[key]
return

def get_value(self):
return self.value

def set_value(self, value):
self.value = value

def get_step_name(self, step):
return f"{self.name}-{step}"

Expand All @@ -539,17 +555,17 @@ def get_z3_step(self, step):
self.cache_z3[step] = z3.Const(self.get_step_name(step), self.sid_line.get_z3())
return self.cache_z3[step]

def get_z3_lambda(term, domain):
if domain:
return z3.Lambda([state.get_z3() for state in domain], term)
def get_z3_lambda(line):
if line.domain:
return z3.Lambda([state.get_z3() for state in line.domain], line.get_z3())
else:
return term
return line.get_z3()

def get_z3_select(term, domain, step):
def get_z3_select(line, domain, step):
if domain:
return z3.Select(term, *[state.get_z3_step(step) for state in domain])
return z3.Select(line.get_z3_lambda(), *[state.get_z3_step(step) for state in domain])
else:
return term
return line.get_z3_lambda()

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
Expand All @@ -562,19 +578,19 @@ def get_bitwuzla_step(self, step, tm):
self.get_step_name(step))
return self.cache_bitwuzla[step]

def get_bitwuzla_lambda(term, domain, tm):
if domain:
def get_bitwuzla_lambda(line, tm):
if line.domain:
return tm.mk_term(bitwuzla.Kind.LAMBDA,
[*[state.get_bitwuzla(tm) for state in domain], term])
[*[state.get_bitwuzla(tm) for state in line.domain], line.get_bitwuzla(tm)])
else:
return term
return line.get_bitwuzla(tm)

def get_bitwuzla_select(term, domain, step, tm):
def get_bitwuzla_select(line, domain, step, tm):
if domain:
return tm.mk_term(bitwuzla.Kind.APPLY,
[term, *[state.get_bitwuzla_step(step, tm) for state in domain]])
[line.get_bitwuzla_lambda(tm), *[state.get_bitwuzla_step(step, tm) for state in domain]])
else:
return term
return line.get_bitwuzla_lambda(tm)

class Indexed(Expression):
def __init__(self, nid, sid_line, arg1_line, comment, line_no):
Expand Down Expand Up @@ -1118,10 +1134,19 @@ def get_z3(self):
self.arg2_line.get_z3(), self.arg3_line.get_z3())
return self.z3

def get_z3_step(self, step):
def get_z3_lambda(self):
# only needed for branching
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(self.get_z3(), self.domain)
return State.get_z3_select(self.z3_lambda_line, self.domain, step)
self.z3_lambda_line = State.get_z3_lambda(self)
return self.z3_lambda_line

def get_z3_select(self, domain, step):
# only needed for branching
return State.get_z3_select(self, domain, step)

def get_z3_step(self, step):
# only needed for branching
return self.get_z3_select(self.domain, step)

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
Expand All @@ -1131,10 +1156,19 @@ def get_bitwuzla(self, tm):
self.arg3_line.get_bitwuzla(tm)])
return self.bitwuzla

def get_bitwuzla_step(self, step, tm):
def get_bitwuzla_lambda(self, tm):
# only needed for branching
if self.bitwuzla_lambda_line is None:
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self.get_bitwuzla(tm), self.domain, tm)
return State.get_bitwuzla_select(self.bitwuzla_lambda_line, self.domain, step, tm)
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self, tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, domain, step, tm):
# only needed for branching
return State.get_bitwuzla_select(self, domain, step, tm)

def get_bitwuzla_step(self, step, tm):
# only needed for branching
return self.get_bitwuzla_select(self.domain, step, tm)

class Write(Ternary):
keyword = OP_WRITE
Expand Down Expand Up @@ -1210,6 +1244,12 @@ def __init__(self, nid, comment, line_no):
self.bitwuzla_lambda_line = None
self.cache_bitwuzla_select = {}

def get_z3_select(self, domain, step):
return State.get_z3_select(self, domain, step)

def get_bitwuzla_select(self, domain, step, tm):
return State.get_bitwuzla_select(self, domain, step, tm)

class Transitional(Sequential):
def __init__(self, nid, sid_line, state_line, exp_line, comment, line_no, array_line, index):
super().__init__(nid, comment, line_no)
Expand Down Expand Up @@ -1258,6 +1298,29 @@ def new_transition(self, transitions, index):
assert self.nid not in transitions, f"transition nid {self.nid} already defined @ {self.line_no}"
transitions[self.nid] = self

def get_z3_lambda(self):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(self.exp_line)
return self.z3_lambda_line

def get_z3_select(self, step):
if step not in self.cache_z3_select:
self.cache_z3_select[step] = super().get_z3_select(self.exp_line.domain, step)
return self.cache_z3_select[step]

def get_bitwuzla_lambda(self, tm):
if self.bitwuzla_lambda_line is None:
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self.exp_line, tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, step, tm):
if step not in self.cache_bitwuzla_select:
self.cache_bitwuzla_select[step] = super().get_bitwuzla_select(self.exp_line.domain, step, tm)
return self.cache_bitwuzla_select[step]

def set_value(self):
self.state_line.set_value(self.exp_line.get_value())

class Init(Transitional):
keyword = OP_INIT

Expand All @@ -1284,27 +1347,26 @@ def get_z3_step(self, step):
# initialize with constant array
return self.state_line.get_z3_step(0) == z3.K(
self.sid_line.array_size_line.get_z3(),
self.exp_line.get_z3())
self.exp_line.get_z3())
else:
return self.state_line.get_z3_step(0) == State.get_z3_select(
State.get_z3_lambda(self.exp_line.get_z3(), self.exp_line.domain),
self.exp_line.domain, 0)
if isinstance(self.exp_line, Constant):
self.set_value()
return self.state_line.get_z3_step(0) == self.get_z3_select(0)

def get_bitwuzla_step(self, step, tm):
assert step == 0, f"bitwuzla init with {step} != 0"
if isinstance(self.sid_line, Array) and isinstance(self.exp_line.sid_line, Bitvec):
# initialize with constant array
return tm.mk_term(bitwuzla.Kind.EQUAL,
[self.state_line.get_bitwuzla_step(0, tm),
tm.mk_const_array(self.sid_line.get_bitwuzla(tm),
self.exp_line.get_bitwuzla(tm))])
tm.mk_const_array(self.sid_line.get_bitwuzla(tm),
self.exp_line.get_bitwuzla(tm))])
else:
if isinstance(self.exp_line, Constant):
self.set_value()
return tm.mk_term(bitwuzla.Kind.EQUAL,
[self.state_line.get_bitwuzla_step(0, tm),
State.get_bitwuzla_select(
State.get_bitwuzla_lambda(
self.exp_line.get_bitwuzla(tm), self.exp_line.domain, tm),
self.exp_line.domain, 0, tm)])
self.get_bitwuzla_select(0, tm)])

class Next(Transitional):
keyword = OP_NEXT
Expand All @@ -1326,18 +1388,6 @@ def __init__(self, nid, sid_line, state_line, exp_line, comment, line_no, array_
def __str__(self):
return f"{self.nid} {Next.keyword} {self.sid_line.nid} {self.state_line.nid} {self.exp_line.nid} {self.comment}"

def get_z3_lambda(self):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(
self.exp_line.get_z3(), self.exp_line.domain)
return self.z3_lambda_line

def get_z3_select(self, step):
if step not in self.cache_z3_select:
self.cache_z3_select[step] = State.get_z3_select(
self.get_z3_lambda(), self.exp_line.domain, step)
return self.cache_z3_select[step]

def get_z3_step(self, step):
if step not in self.cache_z3:
self.cache_z3[step] = self.state_line.get_z3_step(step + 1) == self.get_z3_select(step)
Expand All @@ -1353,18 +1403,6 @@ def get_z3_no_change(self, step):
self.cache_z3_no_change[step] = self.state_line.get_z3_step(step + 1) == self.state_line.get_z3_step(step)
return self.cache_z3_no_change[step]

def get_bitwuzla_lambda(self, tm):
if self.bitwuzla_lambda_line is None:
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(
self.exp_line.get_bitwuzla(tm), self.exp_line.domain, tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, step, tm):
if step not in self.cache_bitwuzla_select:
self.cache_bitwuzla_select[step] = State.get_bitwuzla_select(
self.get_bitwuzla_lambda(tm), self.exp_line.domain, step, tm)
return self.cache_bitwuzla_select[step]

def get_bitwuzla_step(self, step, tm):
if step not in self.cache_bitwuzla:
self.cache_bitwuzla[step] = tm.mk_term(bitwuzla.Kind.EQUAL,
Expand Down Expand Up @@ -1401,22 +1439,24 @@ def __init__(self, nid, property_line, symbol, comment, line_no):
def set_mapped_array_expression(self):
self.property_line = self.property_line.get_mapped_array_expression_for(None)

def get_z3_step(self, step):
def get_z3_lambda(self):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(
self.property_line.get_z3(), self.property_line.domain)
self.z3_lambda_line = State.get_z3_lambda(self.property_line)
return self.z3_lambda_line

def get_z3_step(self, step):
if step not in self.cache_z3:
self.cache_z3[step] = State.get_z3_select(
self.z3_lambda_line, self.property_line.domain, step)
self.cache_z3[step] = super().get_z3_select(self.property_line.domain, step)
return self.cache_z3[step]

def get_bitwuzla_step(self, step, tm):
def get_bitwuzla_lambda(self, tm):
if self.bitwuzla_lambda_line is None:
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(
self.property_line.get_bitwuzla(tm), self.property_line.domain, tm)
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self.property_line, tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_step(self, step, tm):
if step not in self.cache_bitwuzla:
self.cache_bitwuzla[step] = State.get_bitwuzla_select(
self.bitwuzla_lambda_line, self.property_line.domain, step, tm)
self.cache_bitwuzla[step] = super().get_bitwuzla_select(self.property_line.domain, step, tm)
return self.cache_bitwuzla[step]

class Constraint(Property):
Expand Down

0 comments on commit 044eaf9

Please sign in to comment.