Skip to content

Commit

Permalink
feat: smt solving refinement (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Oct 12, 2023
1 parent 9ae36d1 commit 3f72a02
Show file tree
Hide file tree
Showing 74 changed files with 240 additions and 248 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-ffi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
include:
- testname: "ffi:tests"
- testname: "tests/ffi"

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test-long.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
fail-fast: false
matrix:
include:
- testname: "tests/solver"
- testname: "examples/simple"
- testname: "examples/tokens/ERC20"
- testname: "examples/tokens/ERC721"
Expand Down
3 changes: 1 addition & 2 deletions examples/simple/test/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ contract VaultTest is SymTest {
vault.setTotalShares(svm.createUint256("S1"));
}

// NOTE: currently timeout when --smt-div is enabled, while producing invalid counterexamples when --smt-div is not given
/// @custom:halmos --solver-timeout-assertion 10000
function check_deposit(uint assets) public {
uint A1 = vault.totalAssets();
uint S1 = vault.totalShares();
Expand All @@ -40,7 +40,6 @@ contract VaultTest is SymTest {
assert(A1 * S2 <= A2 * S1); // no counterexample
}

/// @custom:halmos --smt-div
function check_mint(uint shares) public {
uint A1 = vault.totalAssets();
uint S1 = vault.totalShares();
Expand Down
88 changes: 44 additions & 44 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from .sevm import *
from .utils import (
create_solver,
hexify,
indent_text,
NamedTimer,
Expand Down Expand Up @@ -155,11 +156,11 @@ def mk_this() -> Address:
return con_addr(magic_address + 1)


def mk_solver(args: Namespace):
# quantifier-free bitvector + array theory; https://smtlib.cs.uiowa.edu/logics.shtml
solver = SolverFor("QF_AUFBV")
solver.set(timeout=args.solver_timeout_branching)
return solver
def mk_solver(args: Namespace, logic="QF_AUFBV", ctx=None, assertion=False):
timeout = (
args.solver_timeout_assertion if assertion else args.solver_timeout_branching
)
return create_solver(logic, ctx, timeout, args.solver_max_memory)


def rendered_initcode(context: CallContext) -> str:
Expand Down Expand Up @@ -567,8 +568,7 @@ def run(
options = mk_options(args)
sevm = SEVM(options)

solver = SolverFor("QF_AUFBV")
solver.set(timeout=args.solver_timeout_branching)
solver = mk_solver(args)
solver.add(setup_ex.solver.assertions())

(exs, steps, logs) = sevm.run(
Expand Down Expand Up @@ -689,17 +689,12 @@ def run(
if is_valid:
print(red(f"Counterexample: {render_model(model)}"))
counterexamples.append(model)
elif args.print_potential_counterexample:
else:
warn(
COUNTEREXAMPLE_INVALID,
f"Counterexample (potentially invalid): {render_model(model)}",
)
counterexamples.append(model)
else:
warn(
COUNTEREXAMPLE_INVALID,
f"Counterexample (potentially invalid): (not displayed, use --print-potential-counterexample)",
)
else:
warn(COUNTEREXAMPLE_UNKNOWN, f"Counterexample: {result}")

Expand Down Expand Up @@ -946,13 +941,20 @@ class GenModelArgs:
sexpr: str


def solve(query: str, args: Namespace) -> Tuple[CheckSatResult, Model]:
solver = mk_solver(args, ctx=Context(), assertion=True)
solver.from_string(query)
result = solver.check()
model = solver.model() if result == sat else None
return result, model


def gen_model_from_sexpr(fn_args: GenModelArgs) -> ModelWithContext:
args, idx, sexpr = fn_args.args, fn_args.idx, fn_args.sexpr
solver = SolverFor("QF_AUFBV", ctx=Context())
solver.set(timeout=args.solver_timeout_assertion)
solver.from_string(sexpr)
res = solver.check()
model = solver.model() if res == sat else None
res, model = solve(sexpr, args)

if res == sat and not is_model_valid(model):
res, model = solve(refine(sexpr), args)

# TODO: handle args.solver_subprocess

Expand All @@ -963,38 +965,40 @@ def is_unknown(result: CheckSatResult, model: Model) -> bool:
return result == unknown or (result == sat and not is_model_valid(model))


def refine(query: str) -> str:
# replace uninterpreted abstraction with actual symbols for assertion solving
# TODO: replace `(evm_bvudiv x y)` with `(ite (= y (_ bv0 256)) (_ bv0 256) (bvudiv x y))`
# as bvudiv is undefined when y = 0; also similarly for evm_bvurem
query = re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query)
# remove the uninterpreted function symbols
# TODO: this will be no longer needed once is_model_valid is properly implemented
return re.sub(
r"\(\s*declare-fun\s+evm_(bv[a-z]+)(_[0-9]+)?\b",
r"(declare-fun dummy_\1\2",
query,
)


def gen_model(args: Namespace, idx: int, ex: Exec) -> ModelWithContext:
if args.verbose >= 1:
print(f"Checking path condition (path id: {idx+1})")

model = None

ex.solver.set(timeout=args.solver_timeout_assertion)
res = ex.solver.check()
if res == sat:
model = ex.solver.model()
model = ex.solver.model() if res == sat else None

if is_unknown(res, model) and args.solver_fresh:
if res == sat and not is_model_valid(model):
if args.verbose >= 1:
print(f" Checking again with a fresh solver")
sol2 = SolverFor("QF_AUFBV", ctx=Context())
# sol2.set(timeout=args.solver_timeout_assertion)
sol2.from_string(ex.solver.to_smt2())
res = sol2.check()
if res == sat:
model = sol2.model()

if is_unknown(res, model) and args.solver_subprocess:
print(f" Checking again with refinement")
res, model = solve(refine(ex.solver.to_smt2()), args)

if args.solver_subprocess and is_unknown(res, model):
if args.verbose >= 1:
print(f" Checking again in an external process")
fname = f"/tmp/{uuid.uuid4().hex}.smt2"
if args.verbose >= 1:
print(f" {args.solver_subprocess_command} {fname} >{fname}.out")
query = ex.solver.to_smt2()
# replace uninterpreted abstraction with actual symbols for assertion solving
# TODO: replace `(evm_bvudiv x y)` with `(ite (= y (_ bv0 256)) (_ bv0 256) (bvudiv x y))`
# as bvudiv is undefined when y = 0; also similarly for evm_bvurem
query = re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query)
query = refine(ex.solver.to_smt2())
with open(fname, "w") as f:
f.write("(set-logic QF_AUFBV)\n")
f.write(query)
Expand Down Expand Up @@ -1048,6 +1052,8 @@ def package_result(


def is_model_valid(model: AnyModel) -> bool:
# TODO: evaluate the path condition against the given model after excluding evm_* symbols,
# since the evm_* symbols may still appear in valid models.
for decl in model:
if str(decl).startswith("evm_"):
return False
Expand Down Expand Up @@ -1077,15 +1083,9 @@ def mk_options(args: Namespace) -> Dict:
"verbose": args.verbose,
"debug": args.debug,
"log": args.log,
"add": not args.no_smt_add,
"sub": not args.no_smt_sub,
"mul": not args.no_smt_mul,
"div": args.smt_div,
"mod": args.smt_mod,
"divByConst": args.smt_div_by_const,
"modByConst": args.smt_mod_by_const,
"expByConst": args.smt_exp_by_const,
"timeout": args.solver_timeout_branching,
"max_memory": args.solver_max_memory,
"sym_jump": args.symbolic_jump,
"print_steps": args.print_steps,
"unknown_calls_return_size": args.return_size_of_unknown_calls,
Expand Down
2 changes: 2 additions & 0 deletions src/halmos/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class InvalidJumpDestError(ExceptionalHalt):
`PUSH-N` opcodes.
"""

pass


class MessageDepthLimitError(ExceptionalHalt):
"""
Expand Down
29 changes: 7 additions & 22 deletions src/halmos/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,6 @@ def mk_arg_parser() -> argparse.ArgumentParser:
# smt solver options
group_solver = parser.add_argument_group("Solver options")

group_solver.add_argument(
"--no-smt-add", action="store_true", help="do not interpret `+`"
)
group_solver.add_argument(
"--no-smt-sub", action="store_true", help="do not interpret `-`"
)
group_solver.add_argument(
"--no-smt-mul", action="store_true", help="do not interpret `*`"
)
group_solver.add_argument("--smt-div", action="store_true", help="interpret `/`")
group_solver.add_argument("--smt-mod", action="store_true", help="interpret `mod`")
group_solver.add_argument(
"--smt-div-by-const", action="store_true", help="interpret division by constant"
)
group_solver.add_argument(
"--smt-mod-by-const", action="store_true", help="interpret constant modulo"
)
group_solver.add_argument(
"--smt-exp-by-const",
metavar="N",
Expand All @@ -205,6 +188,13 @@ def mk_arg_parser() -> argparse.ArgumentParser:
default=1000,
help="set timeout (in milliseconds) for solving assertion violation conditions; 0 means no timeout (default: %(default)s)",
)
group_solver.add_argument(
"--solver-max-memory",
metavar="SIZE",
type=int,
default=0,
help="set memory limit (in megabytes) for the solver; 0 means no limit (default: %(default)s)",
)
group_solver.add_argument(
"--solver-fresh",
action="store_true",
Expand Down Expand Up @@ -248,10 +238,5 @@ def mk_arg_parser() -> argparse.ArgumentParser:
group_experimental.add_argument(
"--symbolic-jump", action="store_true", help="support symbolic jump destination"
)
group_experimental.add_argument(
"--print-potential-counterexample",
action="store_true",
help="print potentially invalid counterexamples",
)

return parser
Loading

0 comments on commit 3f72a02

Please sign in to comment.