-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform_synthesis.py
145 lines (135 loc) · 6.6 KB
/
transform_synthesis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# taskName: str, vocab: VocabFactory, oeManager: OEValuesManager, contexts: List[Map[String, Any]]):
from transform import *
from task import *
import unittest
from typing import Union, List, Dict, Iterator
from enum import Enum
from VocabMaker import VocabFactory
from lookaheadIterator import LookaheadIterator
from childrenIterator import ChildrenIterator
from itertools import chain
class TSizeEnumerator:
    """Bottom-up, size-ordered enumerator of transform programs.

    Programs are produced one cost level at a time. Level 1 enumerates only
    vocabulary leaves. Every later level walks the leaves, then the non-leaves,
    then ``Transforms`` as candidate roots. Each root is combined with children
    drawn from ``self.bank``, which holds the programs found at earlier levels.

    Leaf results are always kept. Other results are kept only if
    ``oeManager.is_representative`` accepts their values
    (observational-equivalence pruning).

    Iteration uses a Java-style ``hasNext()`` / ``next()`` protocol.
    """

    def __init__(self, task: Task, vocab: VocabFactory, oeManager, filterast=None,
                 contexts=None, maxCost=25):
        """Set up enumeration state at cost level 1.

        Args:
            task: passed to every root maker's ``apply``. Its ``values_to_apply``
                and ``all_specs`` are read when variable programs are expanded.
            vocab: vocabulary providing ``leaves()`` and ``nonLeaves()`` makers.
            oeManager: object exposing ``is_representative(values)``, used to
                prune observationally-equivalent programs.
            filterast: optional filter forwarded to every ``apply`` call.
            contexts: optional list of contexts. It is stored on the instance
                but not otherwise used by this class. Defaults to a new empty
                list.
            maxCost: enumeration stops once the cost level exceeds this value.
        """
        self.task = task
        self.vocab = vocab
        self.oeManager = oeManager
        # Fresh list per instance: a ``[]`` default would be shared by every
        # enumerator constructed without an explicit ``contexts`` argument.
        self.contexts = [] if contexts is None else contexts
        self.filter = filterast
        self.maxCost = maxCost
        self.nextProgram = None
        # Programs from finished levels, keyed by program size.
        self.bank: Dict[int, List[TransformASTNode]] = {}
        self.costLevel = 1
        # Programs produced at the level currently being enumerated.
        self.currLevelProgs: List[TransformASTNode] = []
        # Level 1 only enumerates leaves. Later levels are set up in changeLevel().
        self.currIter = LookaheadIterator(iter(vocab.leaves()))
        # The first leaf is built once, without the size check advanceRoot() applies.
        self.rootMaker = self.currIter.next()
        # A single ``None`` item means "build this root with no children".
        self.childrenIterator = LookaheadIterator(iter([None]))
        self.childrenIterators = []
        self.currentChildIteratorIndex = 0
        # Within this class, only a disabled heuristic in changeLevel() refers
        # to this value. default=0 avoids a crash on a vocabulary with no non-leaves.
        self.maxterminals = max(
            (nonleaf.default_size + nonleaf.arity for nonleaf in vocab.nonLeaves()),
            default=0)
        # Index of the next value set when expanding a variable program.
        self.programCounter = 0
        self.currentValueSets = None
        self.currentProgram = None

    def hasNext(self) -> bool:
        """Return True if another program is available, computing it if needed."""
        if self.nextProgram:
            return True
        self.nextProgram = self.getNextProgram()
        return self.nextProgram is not None

    def next(self) -> TransformASTNode:
        """Return the next program, or None when exhausted, and consume it."""
        if not self.nextProgram:
            self.nextProgram = self.getNextProgram()
        res = self.nextProgram
        self.nextProgram = None
        return res

    def advanceRoot(self) -> bool:
        """Move to the next root maker at the current cost level.

        Installs the children iterator(s) that match the new root.

        Returns:
            False if the current level has no more roots, True otherwise.
        """
        if not self.currIter.hasNext():
            return False
        self.rootMaker = self.currIter.next()
        # Drop iterators left over from the previous root, so getNextProgram()
        # can never pair them with this one.
        self.childrenIterators = []
        self.currentChildIteratorIndex = 0
        if self.rootMaker.arity == 0 and (
                self.rootMaker.size == self.costLevel
                or self.rootMaker.nodeType == Types.TRANSFORMS):
            # A leaf whose size matches this level, or the TRANSFORMS leaf:
            # build it once, with no children.
            self.childrenIterator = LookaheadIterator(iter([None]))
        elif self.rootMaker.arity > 0:  # TODO: Cost-based enumeration
            # Roots whose children are (TRANSFORMS, TRANSFORMS) also take this
            # path; a separate branch for them could never be reached.
            childrenCost = self.costLevel - self.rootMaker.default_size
            # One iterator per childTypes entry; getNextProgram() walks them in order.
            self.childrenIterators = [
                ChildrenIterator(childType, childrenCost, self.bank)
                for childType in self.rootMaker.childTypes]
            self.childrenIterator = self.childrenIterators[0]
        else:
            # A leaf whose size does not match this level: nothing to build.
            self.childrenIterator = LookaheadIterator(iter([]))
        return True

    def changeLevel(self) -> bool:
        """Advance to the next cost level.

        Programs from the level just finished are moved into the bank so they
        can serve as children. Root iteration then restarts over the leaves,
        the non-leaves and finally ``Transforms``.

        Returns:
            The result of advanceRoot() for the first root of the new level.
        """
        self.costLevel += 1
        # Disabled heuristic: once costLevel exceeds self.maxterminals + 2,
        # iterate only the Transforms root.
        self.currIter = LookaheadIterator(
            chain(self.vocab.leaves(), self.vocab.nonLeaves(), [Transforms]))
        for program in self.currLevelProgs:
            self.updateBank(program)
        self.currLevelProgs.clear()
        return self.advanceRoot()

    def _childrenMatchRoot(self, children) -> bool:
        """Return True if ``children`` is a valid argument list for the current root.

        ``None`` means "no children" and only suits arity-0 roots. Otherwise
        the number of children must equal the root's arity. Each child's
        ``nodeType`` must also equal the matching entry of
        ``childTypes[currentChildIteratorIndex]``.
        """
        if children is None:
            return self.rootMaker.arity == 0
        if self.rootMaker.arity != len(children):
            return False
        signature = self.rootMaker.childTypes[self.currentChildIteratorIndex]
        return all(child.nodeType == childType
                   for child, childType in zip(children, signature))

    def getNextProgram(self):
        """Produce the next program, or None once the cost level exceeds ``maxCost``.

        Value sets left over from a previously expanded variable program are
        emitted first. Otherwise roots and children are combined until a
        program is accepted.
        """
        if self.currentValueSets is not None and self.currentValueSets.hasNext():
            return self.createProgram(self.currentValueSets.get_nextValue())
        while not self.nextProgram:
            if self.costLevel > self.maxCost:
                break
            if self.childrenIterator.hasNext():
                children = self.childrenIterator.next()
                if self._childrenMatchRoot(children):
                    prog = self.rootMaker.apply(self.task, children, self.filter)
                    if isinstance(prog.values, VariableIterator):
                        # Variable program: it is expanded into one concrete
                        # copy per value set via createProgram().
                        self.programCounter = 0
                        self.currentValueSets = prog.values  # save the iterator values
                        value = self.currentValueSets.get_nextValue()  # first value in the iterator
                        self.currentProgram = prog
                        # NOTE(review): only this first value goes through the OE
                        # check. If it is rejected, the iterator has still advanced
                        # but programCounter stays 0. The remaining values are then
                        # emitted unchecked by the branch at the top of this method,
                        # and each is paired with values_to_apply/all_specs one
                        # index too low. A second variable program reached in this
                        # same loop also replaces currentValueSets. Confirm the
                        # intended semantics.
                        if self.oeManager.is_representative(value):  # OE
                            return self.createProgram(value)
                    elif children is None or self.oeManager.is_representative(prog.values):
                        self.nextProgram = prog
                        # If any child's code contains "Var" (presumably a
                        # variable node), use the first values_to_apply entry.
                        if children is not None and any("Var" in child.code for child in children):
                            self.nextProgram.values_apply = self.task.values_to_apply[0]
            elif self.currentChildIteratorIndex + 1 < len(self.childrenIterators):
                # Current childTypes entry exhausted; move on to the next one.
                self.currentChildIteratorIndex += 1
                self.childrenIterator = self.childrenIterators[self.currentChildIteratorIndex]
            elif self.currIter.hasNext():
                if not self.advanceRoot():
                    return None
            else:
                # Level exhausted. changeLevel() returns advanceRoot()'s result.
                # Because each new level's roots end with Transforms, this retry
                # should not trigger in practice.
                if not self.changeLevel():
                    self.changeLevel()
            if self.nextProgram:
                self.currLevelProgs.append(self.nextProgram)
                res = self.nextProgram
                self.nextProgram = None
                return res
        return None

    def createProgram(self, value_set):
        """Materialise ``self.currentProgram`` for one concrete value set.

        The copy gets:
        - ``value_set`` as its only values;
        - a code name ending in ``_<programCounter>``, which replaces any
          existing numeric suffix;
        - the task's ``values_to_apply`` and ``all_specs`` entries at that index.

        The counter is incremented, and the copy is recorded as a program of
        the current level and returned.
        """
        base_code, sep, suffix = self.currentProgram.code.rpartition('_')
        if not (sep and suffix.isdigit()):
            base_code = self.currentProgram.code
        new_program = self.currentProgram.custom_copy()
        new_program.values = [value_set]
        new_program.code = f"{base_code}_{self.programCounter}"
        new_program.values_apply = self.task.values_to_apply[self.programCounter]
        new_program.spec = self.task.all_specs[self.programCounter]
        self.programCounter += 1
        self.currLevelProgs.append(new_program)
        return new_program

    def updateBank(self, program):
        """Add ``program`` to the bank under its ``size``."""
        self.bank.setdefault(program.size, []).append(program)