generated from klb2/reproducible-paper-python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
48 lines (43 loc) · 1.47 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os.path
from datetime import datetime
import ray
from ray import tune
from ray import train
from environment import SecretKeyEnv
def export_checkpoint(checkpoint, target_dir, metric):
    """Write ``checkpoint`` to ``target_dir`` and return ``to_directory``'s result.

    Args:
        checkpoint: Object exposing ``to_directory(path)``, or ``None`` when
            no checkpoint could be selected for ``metric``.
        target_dir: Destination directory handed to ``to_directory``.
        metric: Name of the metric the checkpoint was selected by; it is only
            used to build the error message.

    Returns:
        Whatever ``checkpoint.to_directory(target_dir)`` returns.

    Raises:
        RuntimeError: If ``checkpoint`` is ``None``, so the failure names the
            missing metric instead of surfacing as an ``AttributeError``.
    """
    if checkpoint is None:
        raise RuntimeError(
            f"No checkpoint available for metric {metric!r}; "
            f"nothing was written to {target_dir!r}."
        )
    return checkpoint.to_directory(target_dir)


if __name__ == "__main__":
    from configs import env_config
    from configs import RESULTS_DIR

    # Shut down before initialising; presumably this clears a Ray session
    # left over from an earlier run in the same process -- confirm.
    ray.shutdown()
    ray.init(num_cpus=63, num_gpus=0)

    # Everything after ``ray.init`` sits in ``try/finally`` so that Ray is
    # shut down even when tuning or the checkpoint export fails.
    try:
        algorithm = "PPO"
        tuner = tune.Tuner(
            algorithm,
            tune_config=tune.TuneConfig(
                metric="episode_reward_min",
                mode="max",
                num_samples=42,
                max_concurrent_trials=63,
            ),
            param_space={
                "env": SecretKeyEnv,
                "env_config": env_config,
                "model": {
                    "free_log_std": True,
                    "use_lstm": False,
                },
            },
            run_config=train.RunConfig(
                stop={"training_iteration": 5000},
                # NOTE(review): only two checkpoints are kept per trial, so
                # the "best" checkpoints selected below are chosen among the
                # retained ones only -- confirm this is intended.
                checkpoint_config=train.CheckpointConfig(
                    num_to_keep=2, checkpoint_at_end=True, checkpoint_frequency=4
                ),
            ),
        )
        results = tuner.fit()

        # No metric/mode is passed here; per the Ray Tune docs the call then
        # falls back to the ones set in ``TuneConfig`` above.
        best_result = results.get_best_result()
        best_checkpoint_mean = best_result.get_best_checkpoint(
            "episode_reward_mean", "max"
        )
        best_checkpoint_min = best_result.get_best_checkpoint(
            "episode_reward_min", "max"
        )

        # One timestamp is shared by both output directories so the two
        # exports of a run are recognisable as a pair.
        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M")
        _dirname = f"{timestamp}-{algorithm}"
        export_checkpoint(
            best_checkpoint_mean,
            os.path.join(RESULTS_DIR, _dirname),
            "episode_reward_mean",
        )
        export_checkpoint(
            best_checkpoint_min,
            os.path.join(RESULTS_DIR, f"{_dirname}-min"),
            "episode_reward_min",
        )
    finally:
        ray.shutdown()