Skip to content

Commit

Permalink
tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
namin committed Mar 22, 2022
1 parent 65ed468 commit 054f694
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
19 changes: 14 additions & 5 deletions generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ def generate_input(N=3, LB=5, UB=8):
return [inp]


def run_bustle_cache(dsl, typ, inp, N, src="cache_bustle.pt"):
try:
return torch.load(src)
except FileNotFoundError:
x = run_bustle(dsl, typ, inp, N)
torch.save(x, src)
return x

def run_bustle(dsl, typ, inp, N):
all_search = bustle(dsl, typ, inp, ["dummy" for _ in inp[0]], N=N, print_stats=True)
search = [v for i in range(2, N) for v in all_search[i]["str"]]
Expand All @@ -106,15 +114,16 @@ def run_bustle(dsl, typ, inp, N):


def generate_dataset():
return generate_dataset_cheat()

return generate_dataset_cheat([0])

def generate_dataset_cheat():
def generate_dataset_cheat(only=None):
from dslparser import parse

dsl = stringdsl
progs = [parse(dsl, prog) for prog in stringprogs.stringprogs]
progs1 = [prog for prog in progs if dsl.numInputs(prog) == 1]
if only is not None:
progs1 = [prog for i,prog in enumerate(progs1) if i in only]
exps = list(itertools.chain(*(subexpressions(prog) for prog in progs1)))

N = 6
Expand All @@ -125,15 +134,15 @@ def generate_dataset_cheat():
inp = [stringprogs.input]
print("")
print("BUSTLE")
search, all_search = run_bustle(dsl, typ, inp, N)
search, all_search = run_bustle_cache(dsl, typ, inp, N)
samples = [(e, dsl.evalIO(e, inp)) for e in exps]
all_search = all_search + samples
samples = [(e, o) for (e, o) in samples if type(o[0]) is str]
print("")
for sample in track(samples, description="Samples ..."):
for i in range(N_selected):
data.append(build_sample(sample, all_search, search, dsl, inp))
for j in track(range(10*N_selected), description="Extra samples ..."):
for j in track(range(100*N_selected), description="Extra samples ..."):
data.append(select_expression(all_search, search, dsl, inp))
print()
random.shuffle(data)
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def saveModel(Ms):
loss = BCELoss()

print()
for epoch in range(10):
for epoch in range(100):
print(f"Epoch {epoch + 1}")
Ts = {}
for key in dataset:
Expand Down

0 comments on commit 054f694

Please sign in to comment.