Skip to content

Commit

Permalink
Merge pull request #88 from rindPHI/dev
Browse files Browse the repository at this point in the history
dev: Improved inference of numeric intervals, fixed bug in evaluation of concrete additions & multiplications.
  • Loading branch information
rindPHI authored Nov 1, 2023
2 parents 2c0157e + 4bd7766 commit 0be5292
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 14 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ This file contains the notable changes in the ISLa project since version 0.2a1

## [unreleased]

## [1.14.2] - 2023-11-01

### Changed

- Improved the inference of intervals from grammars for numeric variables (relevant for
the `enable_optimized_z3_queries` mode).
- Fixed a bug in the evaluation of (concrete) arithmetic expressions: So far, we assumed
that Z3 addition and multiplication operations can only have two children, while more
are possible. Thus, some elements were ignored when computing sums or products.

## [1.14.1] - 2023-07-10

### Changed
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "isla-solver"
version = "1.14.1"
version = "1.14.2"
authors = [
{ name = "Dominic Steinhoefel", email = "[email protected]" },
]
Expand Down
2 changes: 1 addition & 1 deletion src/isla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
# You should have received a copy of the GNU General Public License
# along with ISLa. If not, see <http://www.gnu.org/licenses/>.

__version__ = "1.14.1"
__version__ = "1.14.2"
8 changes: 4 additions & 4 deletions src/isla/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __init__(self, *bound_elements: Union[str, BoundVariable, List[str]]):
self.__flattened_elements: Dict[str, Tuple[Tuple[BoundVariable, ...], ...]] = {}

def __add__(self, other: Union[str, "BoundVariable"]) -> "BindExpression":
assert type(other) == str or type(other) == BoundVariable
assert isinstance(other, str) or type(other) is BoundVariable
result = BindExpression(*self.bound_elements)
result.bound_elements.append(other)
return result
Expand Down Expand Up @@ -294,7 +294,7 @@ def __combination_to_tree_prefix(
flattened_bind_expr_str = "".join(
map(
lambda elem: f"{{{elem.n_type} {elem.name}}}"
if type(elem) == BoundVariable
if type(elem) is BoundVariable
else str(elem),
bound_elements,
)
Expand Down Expand Up @@ -1230,7 +1230,7 @@ def __hash__(self):
return hash((type(self).__name__, self.args))

def __eq__(self, other):
return type(self) == type(other) and self.args == other.args
return type(self) is type(other) and self.args == other.args


class NegatedFormula(PropositionalCombinator):
Expand Down Expand Up @@ -1870,7 +1870,7 @@ def __hash__(self):
)

def __eq__(self, other):
return type(self) == type(other) and (
return type(self) is type(other) and (
self.bound_variable,
self.in_variable,
self.inner_formula,
Expand Down
221 changes: 216 additions & 5 deletions src/isla/z3_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,26 @@
#
# You should have received a copy of the GNU General Public License
# along with ISLa. If not, see <http://www.gnu.org/licenses/>.

import itertools
import logging
import operator
import random
import re
import sys
from functools import lru_cache, reduce, partial
from typing import Callable, Tuple, cast, List, Optional, Dict, Union, Generator, Set
from math import prod
from typing import (
Callable,
Tuple,
cast,
List,
Optional,
Dict,
Union,
Generator,
Set,
Sequence,
)

import z3
from z3.z3 import _coerce_exprs
Expand Down Expand Up @@ -400,7 +412,7 @@ def evaluate_z3_add(
if not z3.is_add(expr):
return Maybe.nothing()

return Maybe(construct_result(lambda args: args[0] + args[1], children_results))
return Maybe(construct_result(sum, children_results))


def evaluate_z3_sub(
Expand All @@ -418,7 +430,7 @@ def evaluate_z3_mul(
if not z3.is_mul(expr):
return Maybe.nothing()

return Maybe(construct_result(lambda args: args[0] * args[1], children_results))
return Maybe(construct_result(prod, children_results))


def evaluate_z3_div(
Expand Down Expand Up @@ -1008,7 +1020,11 @@ def numeric_intervals_from_regex(regex: z3.ReRef) -> Maybe[List[Tuple[int, int]]
The interval of strictly positive numbers is created by enforcing the presence of
a non-zero digit after any leading zeros:
>>> numeric_intervals_from_regex(z3.Concat(z3.Star(z3.Re("0")), z3.Range("1", "9"), z3.Star(z3.Range("0", "9"))))
>>> numeric_intervals_from_regex(
... z3.Concat(
... z3.Star(z3.Re("0")),
... z3.Range("1", "9"),
... z3.Star(z3.Range("0", "9"))))
Maybe(a=[(1, 9223372036854775807)])
If the 0-9 interval is inside a Plus, not a Star, we exclude the single-digit
Expand All @@ -1022,6 +1038,16 @@ def numeric_intervals_from_regex(regex: z3.ReRef) -> Maybe[List[Tuple[int, int]]
>>> numeric_intervals_from_regex(z3.Concat(z3.Re("-"), z3.Range("1", "9"), z3.Plus(z3.Range("0", "9"))))
Maybe(a=[(-9223372036854775807, -10)])
ISLa performs some simplifications to handle cases like `a* ++ a` (which is
equivalent to `a+`) as expected.
>>> numeric_intervals_from_regex(
... z3.Concat(
... z3.Concat(z3.Range("1", "9"), z3.Star(z3.Range("0", "9"))),
... z3.Range("0", "9") # <- this was a problem in ISLa <= 1.14.1
... ))
Maybe(a=[(10, 9223372036854775807)])
:param regex: The regular expression from which to extract the represented
intervals.
:return: An optional list of intervals represented by the regular expression. A
Expand Down Expand Up @@ -1197,6 +1223,29 @@ def numeric_intervals_from_concat(
List[z3.ReRef], z3_split_at_operator(regex, z3.Z3_OP_RE_CONCAT)
)

# Replace any `Range(c, c)` by `Re(c)`
children = cast(
List[z3.ReRef],
list(
map(
lambda r: replace_in_z3_expr(
r,
lambda s: (
(z3.Re(s.children()[0]))
if (
s.decl().kind() == z3.Z3_OP_RE_RANGE
and s.children()[0] == s.children()[1]
)
else None
),
),
children,
)
),
)

children = compress_concatenation_elements(children)

first_child = children[0]
if first_child.decl().kind() == z3.Z3_OP_RE_UNION or first_child == z3.Option(
z3.Re("-")
Expand Down Expand Up @@ -1297,3 +1346,165 @@ def numeric_intervals_from_concat(
return Maybe([(0, sys.maxsize)])
else:
return fallback(regex)


def compress_concatenation_elements(
    concat_elements: Sequence[z3.ReRef],
) -> List[z3.ReRef]:
    """
    Simplifies the elements of a regular expression concatenation.

    Runs of consecutive elements that are built from the same expression `r`
    (i.e., `r`, `r*`, and `r+`) are merged into a shorter, equivalent sequence.

    Example
    -------

    An expression next to its own star (in either order) becomes a plus.

    >>> compress_concatenation_elements([z3.Re("a"), z3.Star(z3.Re("a"))])
    [Plus(Re("a"))]

    >>> compress_concatenation_elements([z3.Star(z3.Re("a")), z3.Re("a")])
    [Plus(Re("a"))]

    Repeating the same starred expression adds nothing; a single star remains.

    >>> compress_concatenation_elements([z3.Star(z3.Re("a")), z3.Star(z3.Re("a"))])
    [Star(Re("a"))]

    This does not hold for plus expressions: `a+ ++ a+` demands at least two
    occurrences of `a`. All but one plus are therefore unwrapped to their child,
    and the remaining plus is placed last:

    >>> compress_concatenation_elements([
    ...     z3.Plus(z3.Re("a")),
    ...     z3.Plus(z3.Re("a")),
    ...     z3.Plus(z3.Re("a"))])
    [Re("a"), Re("a"), Plus(Re("a"))]

    As soon as a run contains a plus, its stars are dropped and the order is
    normalized.

    >>> compress_concatenation_elements([
    ...     z3.Plus(z3.Re("a")),
    ...     z3.Re("a"),
    ...     z3.Star(z3.Re("a"))])
    [Re("a"), Plus(Re("a"))]

    A star whose child differs from its neighbor is left alone.

    >>> compress_concatenation_elements([z3.Star(z3.Re("a")), z3.Re("b")])
    [Star(Re("a")), Re("b")]

    Plain repetitions of the same element are kept as they are.

    >>> compress_concatenation_elements([z3.Re("a"), z3.Re("a")])
    [Re("a"), Re("a")]

    A single element is returned unchanged.

    >>> compress_concatenation_elements([z3.Re("a")])
    [Re("a")]

    :param concat_elements: The elements of a concatenation. None of them should be a
        concatenation themselves.
    :return: An equivalent, compressed/simplified version of the concatenated
        expressions.
    """

    def is_star(r: z3.ReRef) -> bool:
        return r.decl().kind() == z3.Z3_OP_RE_STAR

    def is_plus(r: z3.ReRef) -> bool:
        return r.decl().kind() == z3.Z3_OP_RE_PLUS

    def repeated_expr(r: z3.ReRef) -> z3.ReRef:
        # Grouping key: `r*` and `r+` are keyed by `r`; anything else by itself.
        if is_star(r) or is_plus(r):
            return r.children()[0]

        return r

    def star_has_partner(star: z3.ReRef, run: List[z3.ReRef]) -> bool:
        # Some *other* (by identity) element of the run must be the star's child
        # itself or wrap that same child.
        inner = star.children()[0]
        return any(
            other == inner or other.children()[0] == inner
            for other in run
            if other is not star
        )

    def star_has_only_atoms_beside(star: z3.ReRef, run: List[z3.ReRef]) -> bool:
        # Ignoring elements structurally equal to the star, there must be at
        # least one other element, and all of them must equal the star's child.
        inner = star.children()[0]
        others = [other for other in run if not z3.AstRef.eq(other, star)]
        return len(others) > 0 and all(other == inner for other in others)

    # Every run holds an arbitrary mix of `r`, `r*`, and `r+` for one common `r`.
    runs: List[List[z3.ReRef]] = [
        list(members)
        for _, members in itertools.groupby(concat_elements, repeated_expr)
    ]

    compressed: List[z3.ReRef] = []
    for run in runs:
        if len(run) == 1 or all(is_star(elem) for elem in run):
            # Either a trivial run, or only repetitions of the same star
            # (`r* ++ r*` is just `r*`): one representative suffices.
            compressed.append(run[0])
            continue

        if any(is_plus(elem) for elem in run):
            # `r+ ++ r ++ r*` and the like collapse to `r ++ ... ++ r+`: stars are
            # absorbed by the plus, every plus except the last one is unwrapped
            # to `r`, and the surviving plus goes to the end.
            # Pure-star runs were handled above, so each star here must be
            # accompanied by a plus or atomic element over the same child.
            assert all(star_has_partner(elem, run) for elem in run if is_star(elem))

            atoms = [elem for elem in run if not is_star(elem) and not is_plus(elem)]
            pluses = [elem for elem in run if is_plus(elem)]

            compressed.extend(atoms)
            compressed.extend(plus.children()[0] for plus in pluses[:-1])
            compressed.append(pluses[-1])
            continue

        if any(is_star(elem) for elem in run):
            # Neither a pure-star run nor one containing a plus: the stars must
            # be accompanied by atomic elements with the same child. Drop the
            # stars and promote the last atomic element to a plus.
            assert all(
                star_has_only_atoms_beside(elem, run) for elem in run if is_star(elem)
            )

            atoms = [elem for elem in run if not is_star(elem)]
            promoted = z3.Plus(atoms[-1])

            compressed.extend(atoms[:-1])
            compressed.append(promoted)
            continue

        if all(elem == run[0] for elem in run[1:]):
            # Something like `a ++ a`; `a` is neither a star nor a plus here,
            # because those cases were dealt with above.
            compressed.extend(run)
            continue

        assert False, (
            "Seemingly impossible case triggered "
            "when compressing RegEx concatenation"
        )

    return compressed
Loading

0 comments on commit 0be5292

Please sign in to comment.