Skip to content

Commit

Permalink
More chaching
Browse files Browse the repository at this point in the history
  • Loading branch information
rindPHI committed Sep 1, 2022
1 parent ff42135 commit 8d325b9
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 75 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "grammar_graph"
version = "0.1.13"
version = "0.1.14"
authors = [
{ name="Dominic Steinhöfel", email="[email protected]" },
]
Expand Down
2 changes: 1 addition & 1 deletion src/grammar_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.13"
__version__ = "0.1.14"
112 changes: 44 additions & 68 deletions src/grammar_graph/gg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,19 @@
import json
import re
import sys
from functools import lru_cache, wraps
from functools import lru_cache
from typing import List, Dict, Callable, Union, Optional, Tuple, cast, Set

import fibheap as fh
from graphviz import Digraph

from grammar_graph.helpers import traverse_tree, TRAVERSE_POSTORDER, unreachable_nonterminals
from grammar_graph.type_defs import ParseTree, Grammar
from grammar_graph.helpers import traverse_tree, TRAVERSE_POSTORDER, unreachable_nonterminals, parse_tree_arg_hashable, \
grammar_to_immutable
from grammar_graph.type_defs import ParseTree, Grammar, ImmutableGrammar

RE_NONTERMINAL = re.compile(r'(<[^<> ]*>)')


def parse_tree_arg_hashable(a_func: callable) -> callable:
# This assumes that the first argument of the decorated function is a `ParseTree`

@wraps(a_func)
def decorated(*args, **kwargs):
assert isinstance(args[1], ParseTree)
args = (args[0], parse_tree_to_hashable(args[1]),) + args[2:]
return a_func(*args, **kwargs)

return decorated


def parse_tree_to_hashable(elem: ParseTree) -> ParseTree:
stack: List[ParseTree] = []

def action(_, node: ParseTree):
if not node[1]:
# noinspection PyTypeChecker
stack.append((node[0], None if node[1] is None else ()))
else:
children = []
for _ in range(len(node[1])):
children.append(stack.pop())
# noinspection PyTypeChecker
stack.append((node[0], tuple(children)))

traverse_tree(elem, action, kind=TRAVERSE_POSTORDER, reverse=True)

assert len(stack) == 1
return stack[0]


@lru_cache(maxsize=None)
def is_nonterminal(s):
return RE_NONTERMINAL.match(s)
Expand Down Expand Up @@ -136,27 +105,40 @@ def __init__(
root: Node,
grammar: Optional[Grammar] = None,
all_nodes: Optional[Set[Node]] = None,
all_edges: Optional[Set[Tuple[Node, Node]]] = None,
reachable: Optional[Dict[Tuple[Node, Node], bool]] = None):
all_edges: Optional[Set[Tuple[Node, Node]]] = None):
assert isinstance(root, Node)
self.root = root
self._grammar = grammar
self.__grammar: Optional[Grammar] = grammar
self.__immutable_grammar: Optional[ImmutableGrammar] = \
grammar_to_immutable(grammar) if grammar is not None else None
self.__all_nodes: Optional[Set[Node]] = all_nodes
self.__all_edges: Optional[Set[Tuple[Node, Node]]] = all_edges
self.__reachable: Dict[Tuple[Node, Node], bool] = reachable or {}
self.__hash = None

@property
def grammar(self):
if self._grammar is None:
self._grammar = self._compute_grammar()
def grammar(self) -> Grammar:
if self.__grammar is None:
self.__grammar = self._compute_grammar()
self.__immutable_grammar = grammar_to_immutable(self.__grammar)

return self._grammar
return self.__grammar

@grammar.setter
def grammar(self, grammar: Grammar):
raise NotImplementedError()

@property
def immutable_grammar(self) -> ImmutableGrammar:
if self.__immutable_grammar is None:
self.__grammar = self._compute_grammar()
self.__immutable_grammar = grammar_to_immutable(self.__grammar)

return self.__immutable_grammar

@immutable_grammar.setter
def immutable_grammar(self, grammar: ImmutableGrammar):
raise NotImplementedError()

def __repr__(self):
return f"GrammarGraph({repr(self.root)})"

Expand All @@ -165,7 +147,7 @@ def __eq__(self, other):

def __hash__(self):
if self.__hash is None:
self.__hash = hash(json.dumps(self.to_grammar()))
self.__hash = hash(self.immutable_grammar)
return self.__hash

def bfs(self, action: Callable[[Node], Union[None, bool]], start_node: Union[None, Node] = None):
Expand Down Expand Up @@ -224,23 +206,12 @@ def all_edges(self, val: Set[Tuple[Node, Node]]) -> None:
self.__all_edges = val

def reachable(self, from_node: Union[str, Node], to_node: Union[str, Node]) -> bool:
# Note: Reachability is not reflexive!
def node_in_children(node: Node) -> bool:
return isinstance(node, NonterminalNode) and to_node in node.children

if isinstance(from_node, str):
from_node = self.get_node(from_node)
if isinstance(to_node, str):
to_node = self.get_node(to_node)

assert from_node in self.all_nodes
assert to_node in self.all_nodes

if (from_node, to_node) not in self.__reachable:
sources = self.filter(node_in_children, node_in_children, from_node=from_node)
self.__reachable[(from_node, to_node)] = len(sources) > 0

return self.__reachable[(from_node, to_node)]
return reachable(self, from_node, to_node)

def shortest_non_trivial_path(self, source: Node, target: Node,
nodes_filter: Optional[Callable[[Node], bool]] =
Expand Down Expand Up @@ -346,7 +317,7 @@ def to_grammar(self):
"""Deprecated; use the `grammar` property."""
return self.grammar

def _compute_grammar(self):
def _compute_grammar(self) -> Grammar:
result: Grammar = {}

def action(node: Node):
Expand Down Expand Up @@ -376,7 +347,6 @@ def subgraph(self, nonterminal: Union[NonterminalNode, str]):

all_nodes: Optional[Set[Node]] = copy.copy(self.__all_nodes)
all_edges: Optional[Set[Tuple[Node, Node]]] = copy.copy(self.__all_edges)
reachable: Dict[Tuple[Node, Node], bool] = copy.copy(self.__reachable)

new_grammar = copy.deepcopy(self.grammar)
new_grammar['<start>'] = [nonterminal.symbol]
Expand All @@ -389,12 +359,6 @@ def subgraph(self, nonterminal: Union[NonterminalNode, str]):
all_edges = set(filter(
lambda t: t[0].symbol not in unreachable_symbols and t[1].symbol not in unreachable_symbols,
all_edges))
# noinspection PyTypeChecker
reachable = dict(filter(
lambda item: (
item[0][0].symbol not in unreachable_symbols and
item[0][1].symbol not in unreachable_symbols),
reachable.items()))

for unreachable_symbol in unreachable_symbols:
del new_grammar[unreachable_symbol]
Expand All @@ -403,8 +367,7 @@ def subgraph(self, nonterminal: Union[NonterminalNode, str]):
root_node,
grammar=new_grammar,
all_nodes=all_nodes,
all_edges=all_edges,
reachable=reachable)
all_edges=all_edges)

def parents(self, node: Node) -> List[Node]:
result = []
Expand Down Expand Up @@ -678,7 +641,7 @@ def k_paths(
include_terminals=True) -> set[Tuple[Node, ...]]:
assert k > 0
k += k - 1 # Each path of k terminal/nonterminal nodes includes k-1 choice nodes
result: set[Tuple[Node, ...]] = set([])
result: List[Tuple[Node, ...]] = []

if not start_node:
all_nodes = self.all_nodes
Expand All @@ -701,7 +664,7 @@ def k_paths(

node_result = new_node_result

result.update(node_result)
result.extend(node_result)

return {
kpath for kpath in result
Expand Down Expand Up @@ -791,3 +754,16 @@ def path_to_string(p, include_choice_node=True) -> str:
else n.symbol
for n in p
if include_choice_node or not isinstance(n, ChoiceNode)])


@lru_cache(maxsize=None)
def reachable(graph: GrammarGraph, from_node: Node, to_node: Node) -> bool:
# Note: Reachability is not reflexive!
def node_in_children(node: Node) -> bool:
return isinstance(node, NonterminalNode) and to_node in node.children

assert from_node in graph.all_nodes
assert to_node in graph.all_nodes

sources = graph.filter(node_in_children, node_in_children, from_node=from_node)
return len(sources) > 0
41 changes: 39 additions & 2 deletions src/grammar_graph/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
from typing import TypeVar, Callable, List, Tuple, Set
from functools import wraps
from typing import TypeVar, Callable, List, Tuple, Set, cast

from grammar_graph.type_defs import Tree, Path, ParseTree, Grammar
from grammar_graph.type_defs import Tree, Path, ParseTree, Grammar, ImmutableGrammar

TRAVERSE_PREORDER = 0
TRAVERSE_POSTORDER = 1
Expand Down Expand Up @@ -37,6 +38,10 @@ def delete_unreachable(grammar: Grammar) -> None:
del grammar[unreachable]


def grammar_to_immutable(grammar: Grammar) -> ImmutableGrammar:
return cast(ImmutableGrammar, tuple({k: tuple(v) for k, v in grammar.items()}.items()))


def traverse_tree(
tree: ParseTree,
action: Callable[[Path, ParseTree], None],
Expand Down Expand Up @@ -71,3 +76,35 @@ def traverse_tree(
if kind == TRAVERSE_POSTORDER:
while stack_2:
action(*stack_2.pop())


def parse_tree_arg_hashable(a_func: callable) -> callable:
# This assumes that the first argument of the decorated function is a `ParseTree`

@wraps(a_func)
def decorated(*args, **kwargs):
assert isinstance(args[1], ParseTree)
args = (args[0], parse_tree_to_immutable(args[1]),) + args[2:]
return a_func(*args, **kwargs)

return decorated


def parse_tree_to_immutable(elem: ParseTree) -> ParseTree:
stack: List[ParseTree] = []

def action(_, node: ParseTree):
if not node[1]:
# noinspection PyTypeChecker
stack.append((node[0], None if node[1] is None else ()))
else:
children = []
for _ in range(len(node[1])):
children.append(stack.pop())
# noinspection PyTypeChecker
stack.append((node[0], tuple(children)))

traverse_tree(elem, action, kind=TRAVERSE_POSTORDER, reverse=True)

assert len(stack) == 1
return stack[0]
1 change: 1 addition & 0 deletions src/grammar_graph/type_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

NonterminalType = str
Grammar = Dict[NonterminalType, List[str]]
ImmutableGrammar = Tuple[Tuple[str, Tuple[str, ...]], ...]


class ParseTree(ABC):
Expand Down
4 changes: 1 addition & 3 deletions tests/test_gg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fuzzingbook.Grammars import JSON_GRAMMAR, US_PHONE_GRAMMAR, is_nonterminal, srange
from fuzzingbook.Parser import CSV_GRAMMAR, EarleyParser

from grammar_graph.gg import GrammarGraph, Node, NonterminalNode, ChoiceNode, TerminalNode, path_to_string
from grammar_graph.gg import GrammarGraph, Node, NonterminalNode, ChoiceNode, TerminalNode, path_to_string, reachable
from grammar_graph.helpers import delete_unreachable


Expand Down Expand Up @@ -81,7 +81,6 @@ def test_get_subgraph(self):
self.assertTrue(graph.reachable('<member>', '<object>'))
self.assertFalse(graph.reachable('<member>', '<json>'))
self.assertTrue(graph.reachable('<string>', '<character>'))
self.assertEqual(3, len(graph._GrammarGraph__reachable))
self.assertEqual(
set(JSON_GRAMMAR.keys()),
{n.symbol for n in graph.all_nodes if type(n) is NonterminalNode})
Expand All @@ -92,7 +91,6 @@ def test_get_subgraph(self):
delete_unreachable(sub_grammar)

self.assertEqual(sub_grammar, sub_graph.grammar)
self.assertEqual(1, len(sub_graph._GrammarGraph__reachable))
self.assertEqual(
set(sub_grammar.keys()),
{n.symbol for n in sub_graph._GrammarGraph__all_nodes if type(n) is NonterminalNode})
Expand Down

0 comments on commit 8d325b9

Please sign in to comment.