From 044eaf9cd1a09d1d3f4d633fef583bda807c9209 Mon Sep 17 00:00:00 2001 From: Christoph Kirsch Date: Mon, 28 Oct 2024 15:53:23 +0100 Subject: [PATCH] Preparing constant propagation --- tools/bitme.py | 176 ++++++++++++++++++++++++++++++------------------- 1 file changed, 108 insertions(+), 68 deletions(-) diff --git a/tools/bitme.py b/tools/bitme.py index afb7730e..8b8509e4 100755 --- a/tools/bitme.py +++ b/tools/bitme.py @@ -365,20 +365,26 @@ def __init__(self, nid, sid_line, value, comment, line_no): def get_mapped_array_expression_for(self, index): return self + def get_value(self): + if isinstance(self.sid_line, Bool): + return bool(self.value) + else: + return self.value + def get_z3(self): if self.z3 is None: if isinstance(self.sid_line, Bool): - self.z3 = z3.BoolVal(bool(self.value)) + self.z3 = z3.BoolVal(self.get_value()) else: - self.z3 = z3.BitVecVal(self.value, self.sid_line.size) + self.z3 = z3.BitVecVal(self.get_value(), self.sid_line.size) return self.z3 def get_bitwuzla(self, tm): if self.bitwuzla is None: if isinstance(self.sid_line, Bool): - self.bitwuzla = tm.mk_true() if bool(self.value) else tm.mk_false() + self.bitwuzla = tm.mk_true() if self.get_value() else tm.mk_false() else: - self.bitwuzla = tm.mk_bv_value(self.sid_line.get_bitwuzla(tm), self.value) + self.bitwuzla = tm.mk_bv_value(self.sid_line.get_bitwuzla(tm), self.get_value()) return self.bitwuzla class Zero(Constant): @@ -482,6 +488,9 @@ def __init__(self, nid, sid_line, symbol, comment, line_no, index = None): def __str__(self): return f"{self.nid} {Input.keyword} {self.sid_line.nid} {self.symbol} {self.comment}" + def get_value(self): + return self + def get_z3_step(self, step): return self.get_z3() @@ -512,6 +521,7 @@ def __init__(self, nid, sid_line, symbol, comment, line_no, index = None): self.name = f"state{self.nid}" self.init_line = None self.next_line = None + self.value = None self.new_state(index) # rotor-dependent program counter declaration if comment == "; program counter": @@ -531,6 +541,12 @@ def remove_state(self): del State.states[key] return + def get_value(self): + return self.value + + def set_value(self, value): + self.value = value + def get_step_name(self, step): return f"{self.name}-{step}" @@ -539,17 +555,17 @@ def get_z3_step(self, step): self.cache_z3[step] = z3.Const(self.get_step_name(step), self.sid_line.get_z3()) return self.cache_z3[step] - def get_z3_lambda(term, domain): - if domain: - return z3.Lambda([state.get_z3() for state in domain], term) + def get_z3_lambda(line): + if line.domain: + return z3.Lambda([state.get_z3() for state in line.domain], line.get_z3()) else: - return term + return line.get_z3() - def get_z3_select(term, domain, step): + def get_z3_select(line, domain, step): if domain: - return z3.Select(term, *[state.get_z3_step(step) for state in domain]) + return z3.Select(line.get_z3_lambda(), *[state.get_z3_step(step) for state in domain]) else: - return term + return line.get_z3_lambda() def get_bitwuzla(self, tm): if self.bitwuzla is None: @@ -562,19 +578,19 @@ def get_bitwuzla_step(self, step, tm): self.get_step_name(step)) return self.cache_bitwuzla[step] - def get_bitwuzla_lambda(term, domain, tm): - if domain: + def get_bitwuzla_lambda(line, tm): + if line.domain: return tm.mk_term(bitwuzla.Kind.LAMBDA, - [*[state.get_bitwuzla(tm) for state in domain], term]) + [*[state.get_bitwuzla(tm) for state in line.domain], line.get_bitwuzla(tm)]) else: - return term + return line.get_bitwuzla(tm) - def get_bitwuzla_select(term, domain, step, tm): + def get_bitwuzla_select(line, domain, step, tm): if domain: return tm.mk_term(bitwuzla.Kind.APPLY, - [term, *[state.get_bitwuzla_step(step, tm) for state in domain]]) + [line.get_bitwuzla_lambda(tm), *[state.get_bitwuzla_step(step, tm) for state in domain]]) else: - return term + return line.get_bitwuzla_lambda(tm) class Indexed(Expression): def __init__(self, nid, sid_line, arg1_line, comment, line_no): @@ -1118,10 +1134,19 @@ def get_z3(self): self.arg2_line.get_z3(), self.arg3_line.get_z3()) return self.z3 - def get_z3_step(self, step): + def get_z3_lambda(self): + # only needed for branching if self.z3_lambda_line is None: - self.z3_lambda_line = State.get_z3_lambda(self.get_z3(), self.domain) - return State.get_z3_select(self.z3_lambda_line, self.domain, step) + self.z3_lambda_line = State.get_z3_lambda(self) + return self.z3_lambda_line + + def get_z3_select(self, domain, step): + # only needed for branching + return State.get_z3_select(self, domain, step) + + def get_z3_step(self, step): + # only needed for branching + return self.get_z3_select(self.domain, step) def get_bitwuzla(self, tm): if self.bitwuzla is None: @@ -1131,10 +1156,19 @@ def get_bitwuzla(self, tm): self.arg3_line.get_bitwuzla(tm)]) return self.bitwuzla - def get_bitwuzla_step(self, step, tm): + def get_bitwuzla_lambda(self, tm): + # only needed for branching if self.bitwuzla_lambda_line is None: - self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self.get_bitwuzla(tm), self.domain, tm) - return State.get_bitwuzla_select(self.bitwuzla_lambda_line, self.domain, step, tm) + self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self, tm) + return self.bitwuzla_lambda_line + + def get_bitwuzla_select(self, domain, step, tm): + # only needed for branching + return State.get_bitwuzla_select(self, domain, step, tm) + + def get_bitwuzla_step(self, step, tm): + # only needed for branching + return self.get_bitwuzla_select(self.domain, step, tm) class Write(Ternary): keyword = OP_WRITE @@ -1210,6 +1244,12 @@ def __init__(self, nid, comment, line_no): self.bitwuzla_lambda_line = None self.cache_bitwuzla_select = {} + def get_z3_select(self, domain, step): + return State.get_z3_select(self, domain, step) + + def get_bitwuzla_select(self, domain, step, tm): + return State.get_bitwuzla_select(self, domain, step, tm) + class Transitional(Sequential): def __init__(self, nid, sid_line, state_line, exp_line, comment, line_no, array_line, index): super().__init__(nid, comment, line_no) @@ -1258,6 +1298,29 @@ def new_transition(self, transitions, index): assert self.nid not in transitions, f"transition nid {self.nid} already defined @ {self.line_no}" transitions[self.nid] = self + def get_z3_lambda(self): + if self.z3_lambda_line is None: + self.z3_lambda_line = State.get_z3_lambda(self.exp_line) + return self.z3_lambda_line + + def get_z3_select(self, step): + if step not in self.cache_z3_select: + self.cache_z3_select[step] = super().get_z3_select(self.exp_line.domain, step) + return self.cache_z3_select[step] + + def get_bitwuzla_lambda(self, tm): + if self.bitwuzla_lambda_line is None: + self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self.exp_line, tm) + return self.bitwuzla_lambda_line + + def get_bitwuzla_select(self, step, tm): + if step not in self.cache_bitwuzla_select: + self.cache_bitwuzla_select[step] = super().get_bitwuzla_select(self.exp_line.domain, step, tm) + return self.cache_bitwuzla_select[step] + + def set_value(self): + self.state_line.set_value(self.exp_line.get_value()) + class Init(Transitional): keyword = OP_INIT @@ -1284,11 +1347,11 @@ def get_z3_step(self, step): # initialize with constant array return self.state_line.get_z3_step(0) == z3.K( self.sid_line.array_size_line.get_z3(), - self.exp_line.get_z3()) + self.exp_line.get_z3()) else: - return self.state_line.get_z3_step(0) == State.get_z3_select( - State.get_z3_lambda(self.exp_line.get_z3(), self.exp_line.domain), - self.exp_line.domain, 0) + if isinstance(self.exp_line, Constant): + self.set_value() + return self.state_line.get_z3_step(0) == self.get_z3_select(0) def get_bitwuzla_step(self, step, tm): assert step == 0, f"bitwuzla init with {step} != 0" @@ -1296,15 +1359,14 @@ def get_bitwuzla_step(self, step, tm): # initialize with constant array return tm.mk_term(bitwuzla.Kind.EQUAL, [self.state_line.get_bitwuzla_step(0, tm), - tm.mk_const_array(self.sid_line.get_bitwuzla(tm), - self.exp_line.get_bitwuzla(tm))]) + tm.mk_const_array(self.sid_line.get_bitwuzla(tm), + self.exp_line.get_bitwuzla(tm))]) else: + if isinstance(self.exp_line, Constant): + self.set_value() return tm.mk_term(bitwuzla.Kind.EQUAL, [self.state_line.get_bitwuzla_step(0, tm), - State.get_bitwuzla_select( - State.get_bitwuzla_lambda( - self.exp_line.get_bitwuzla(tm), self.exp_line.domain, tm), - self.exp_line.domain, 0, tm)]) + self.get_bitwuzla_select(0, tm)]) class Next(Transitional): keyword = OP_NEXT @@ -1326,18 +1388,6 @@ def __init__(self, nid, sid_line, state_line, exp_line, comment, line_no, array_ def __str__(self): return f"{self.nid} {Next.keyword} {self.sid_line.nid} {self.state_line.nid} {self.exp_line.nid} {self.comment}" - def get_z3_lambda(self): - if self.z3_lambda_line is None: - self.z3_lambda_line = State.get_z3_lambda( - self.exp_line.get_z3(), self.exp_line.domain) - return self.z3_lambda_line - - def get_z3_select(self, step): - if step not in self.cache_z3_select: - self.cache_z3_select[step] = State.get_z3_select( - self.get_z3_lambda(), self.exp_line.domain, step) - return self.cache_z3_select[step] - def get_z3_step(self, step): if step not in self.cache_z3: self.cache_z3[step] = self.state_line.get_z3_step(step + 1) == self.get_z3_select(step) @@ -1353,18 +1403,6 @@ def get_z3_no_change(self, step): self.cache_z3_no_change[step] = self.state_line.get_z3_step(step + 1) == self.state_line.get_z3_step(step) return self.cache_z3_no_change[step] - def get_bitwuzla_lambda(self, tm): - if self.bitwuzla_lambda_line is None: - self.bitwuzla_lambda_line = State.get_bitwuzla_lambda( - self.exp_line.get_bitwuzla(tm), self.exp_line.domain, tm) - return self.bitwuzla_lambda_line - - def get_bitwuzla_select(self, step, tm): - if step not in self.cache_bitwuzla_select: - self.cache_bitwuzla_select[step] = State.get_bitwuzla_select( - self.get_bitwuzla_lambda(tm), self.exp_line.domain, step, tm) - return self.cache_bitwuzla_select[step] - def get_bitwuzla_step(self, step, tm): if step not in self.cache_bitwuzla: self.cache_bitwuzla[step] = tm.mk_term(bitwuzla.Kind.EQUAL, @@ -1401,22 +1439,24 @@ def __init__(self, nid, property_line, symbol, comment, line_no): def set_mapped_array_expression(self): self.property_line = self.property_line.get_mapped_array_expression_for(None) - def get_z3_step(self, step): + def get_z3_lambda(self): if self.z3_lambda_line is None: - self.z3_lambda_line = State.get_z3_lambda( - self.property_line.get_z3(), self.property_line.domain) + self.z3_lambda_line = State.get_z3_lambda(self.property_line) + return self.z3_lambda_line + + def get_z3_step(self, step): if step not in self.cache_z3: - self.cache_z3[step] = State.get_z3_select( - self.z3_lambda_line, self.property_line.domain, step) + self.cache_z3[step] = super().get_z3_select(self.property_line.domain, step) return self.cache_z3[step] - def get_bitwuzla_step(self, step, tm): + def get_bitwuzla_lambda(self, tm): if self.bitwuzla_lambda_line is None: - self.bitwuzla_lambda_line = State.get_bitwuzla_lambda( - self.property_line.get_bitwuzla(tm), self.property_line.domain, tm) + self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self.property_line, tm) + return self.bitwuzla_lambda_line + + def get_bitwuzla_step(self, step, tm): if step not in self.cache_bitwuzla: - self.cache_bitwuzla[step] = State.get_bitwuzla_select( - self.bitwuzla_lambda_line, self.property_line.domain, step, tm) + self.cache_bitwuzla[step] = super().get_bitwuzla_select(self.property_line.domain, step, tm) return self.cache_bitwuzla[step] class Constraint(Property):