utils.py
import inspect
import random

import numpy as np
import torch


class LinearSchedule(object):
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After that many timesteps have passed,
        final_p is returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps over which to linearly anneal
            initial_p to final_p
        final_p: float
            final output value
        initial_p: float
            initial output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def __call__(self, t):
        """Return the interpolated value at timestep t."""
        # Fraction of the schedule elapsed, clamped to 1.0 once complete.
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)
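
# Example usage (a minimal sketch; the numbers below are illustrative):
#   eps = LinearSchedule(schedule_timesteps=10000, final_p=0.05, initial_p=1.0)
#   eps(0)      -> 1.0
#   eps(5000)   -> 0.525
#   eps(10000)  -> 0.05 (and stays at final_p for all t >= schedule_timesteps)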


def set_global_seeds(seed):
    """
    Set the seed for python random, torch and numpy.

    :param seed: (int) the seed
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
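
# Example usage (sketch): call once at program start, before building models
# or sampling, so later torch/numpy/random draws are reproducible:
#   set_global_seeds(42)
# (In recent PyTorch versions, torch.manual_seed also seeds CUDA devices.)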


class Parameters:
    def save_parameters(self, ignore=()):
        """Save the calling function's arguments into class attributes."""
        # Grab the caller's frame (typically an __init__) and its local
        # variables, which include the function's arguments.
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        ignored = set(ignore) | {'self'}
        self.hparams = {k: v for k, v in local_vars.items()
                        if k not in ignored and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)
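
# Example usage (a minimal sketch; `MyModel` and its arguments are hypothetical):
#   class MyModel(Parameters):
#       def __init__(self, lr=1e-3, batch_size=32):
#           self.save_parameters()
#   m = MyModel()
#   m.lr       -> 0.001
#   m.hparams  -> {'lr': 0.001, 'batch_size': 32}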