-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
99 lines (81 loc) · 3.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Some useful tools:
1) seed every random number generator in use (Python, NumPy, PyTorch);
2) a lightweight configuration class inspired by yacs.
References:
0) Karpathy's minGPT code: https://github.com/karpathy/minGPT
"""
import os
import random
from ast import literal_eval
import numpy as np
import torch
# -----------------------------------------------------------------------------
def set_seed(seed):
    """Seed every random number generator this module knows about.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator, the CUDA generators and, when the backend is
    usable, the Apple-silicon MPS generator), exports ``PYTHONHASHSEED`` and
    switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this export
    # presumably only matters for processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Apple Chip: only touch the MPS generator when this PyTorch build ships
    # the module and the backend reports itself usable, so the call cannot
    # fail on builds or machines without MPS.
    mps_module = getattr(torch, "mps", None)
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_module is not None and mps_backend is not None and mps_backend.is_available():
        mps_module.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class ConfigNode:
    """ a lightweight configuration class inspired by yacs

    Settings live as plain instance attributes; a nested group of settings
    is an attribute whose value is itself a ConfigNode.
    """

    def __init__(self, **kwargs):
        """ store every keyword argument as an attribute of this node """
        self.__dict__.update(kwargs)

    def __str__(self):
        return self._str_helper(0)

    def _str_helper(self, indent):
        """ need to have a helper to support nested indentation for pretty printing

        Returns the text for this node with every line indented by
        `indent` levels of four spaces. Nested nodes indent their own
        lines, so their output is appended as-is (re-prefixing it would
        over-indent the first line of every deeper block).
        """
        pad = ' ' * (indent * 4)
        parts = []
        for k, v in self.__dict__.items():
            if isinstance(v, ConfigNode):
                parts.append("%s%s:\n" % (pad, k))
                parts.append(v._str_helper(indent + 1))
            else:
                parts.append("%s%s: %s\n" % (pad, k, v))
        return "".join(parts)

    def to_dict(self):
        """ return a dict representation of the config """
        return {k: v.to_dict() if isinstance(v, ConfigNode) else v for k, v in self.__dict__.items()}

    def merge_from_dict(self, d):
        """ shallow update: each key of `d` replaces the attribute of the same name """
        self.__dict__.update(d)

    def merge_from_args(self, args):
        """
        update the configuration from a list of strings that is expected
        to come from the command line, i.e. sys.argv[1:].

        The arguments are expected to be in the form of `--arg=value`, and
        the arg can use . to denote nested sub-attributes. Example:

        --model.n_layer=10 --trainer.batch_size=32

        Only the first `=` separates key from value, so the value itself
        may contain `=`. Raises AssertionError when an argument is not of
        the form `--arg=value` or names a leaf attribute that does not
        exist, and AttributeError when an intermediate attribute of a
        dotted key is missing.
        """
        for arg in args:
            # split on the first '=' only, keeping any later '=' in the value
            keyval = arg.split('=', 1)
            assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg
            key, val = keyval  # unpack

            # first translate val into a python object:
            # - if val represents a thing (like 3, 3.14, [1,2,3], False, None, etc.)
            #   it will get created
            # - if val is simply a string, literal_eval throws a ValueError (e.g. `gpt2`)
            #   or a SyntaxError (e.g. `./out/dir`, `hello world`), and the raw
            #   string is kept
            try:
                val = literal_eval(val)
            except (ValueError, SyntaxError):
                pass

            # find the appropriate object to insert the attribute into
            assert key[:2] == '--'
            key = key[2:]  # strip the '--'
            keys = key.split('.')
            obj = self
            for k in keys[:-1]:
                obj = getattr(obj, k)
            leaf_key = keys[-1]

            # ensure that this attribute exists
            assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config"

            # overwrite the attribute
            print("command line overwriting config attribute %s with %s" % (key, val))
            setattr(obj, leaf_key, val)