Skip to content

Commit

Permalink
Save state
Browse files Browse the repository at this point in the history
  • Loading branch information
jnlt3 committed Aug 1, 2022
1 parent 25d68a1 commit 6616ad6
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

import dataclasses
import json
import pathlib
import time
from graph import Graph
from spsa import Param, SpsaParams, SpsaTuner
Expand All @@ -25,11 +27,38 @@ def spsa_from_config(config_path: str):
return SpsaParams(**config)


def save_state(spsa: SpsaTuner):
save_file = "./tuner/state.json"
spsa_params = spsa.spsa
uci_params = spsa.uci_params
t = spsa.t
with open(save_file, "w") as save_file:
spsa_params = dataclasses.asdict(spsa_params)
uci_params = [dataclasses.asdict(
uci_param) for uci_param in uci_params]

json.dump({"t": t, "spsa_params": spsa_params,
"uci_params": uci_params}, save_file)


def main():
params = params_from_config("config.json")

state_path = pathlib.Path("./tuner/state.json")
t = 0
if state_path.is_file():
print("hey")
with open(state_path) as state:
state_dict = json.load(state)
params = [Param(cfg["name"], cfg["value"], cfg["min_value"], cfg["max_value"], cfg["step"])
for cfg in state_dict["uci_params"]]
spsa_params = SpsaParams(**state_dict["spsa_params"])
t = state_dict["t"]
else:
params = params_from_config("config.json")
spsa_params = spsa_from_config("spsa.json")
cutechess = cutechess_from_config("cutechess.json")
spsa_params = spsa_from_config("spsa.json")
spsa = SpsaTuner(spsa_params, params, cutechess)
spsa.t = t
graph = Graph()

avg_time = 0
Expand All @@ -42,12 +71,14 @@ def main():

graph.update(spsa.t, copy.deepcopy(spsa.params))
graph.save("graph.png")
print(f"iterations: {spsa.t} ({(avg_time / spsa.t):.2f}s per iter)")
print(
f"iterations: {spsa.t} ({(avg_time / spsa.t):.2f}s per iter)")
for param in spsa.params:
print(param)
print()
finally:
print("Saving state...")
save_state(spsa)


if __name__ == "__main__":
Expand Down

0 comments on commit 6616ad6

Please sign in to comment.