-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feature: awesome config * fix * rm easydict in requirements * update version * fix * rename * fix
- Loading branch information
Showing
18 changed files
with
409 additions
and
208 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
from .config import import_config | ||
from .config import Config, import_config | ||
from .core import Runner, AvgMeter, MeterPool | ||
from .launcher import launch_runner, launch_training | ||
from .version import __version__ | ||
|
||
__all__ = [ | ||
'import_config', 'Runner', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner', | ||
'Config', 'import_config', 'Runner', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner', | ||
'launch_training', '__version__' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Everything is based on config. | ||
`Config` is the set of all configurations. `Config` is is implemented by `dict`, We recommend using `Config`. | ||
Look at the following example: | ||
cfg.py | ||
```python | ||
import os | ||
from easytorch import Config | ||
from my_runner import MyRunner | ||
CFG = {} | ||
CFG.DESC = 'my net' # customized description | ||
CFG.RUNNER = MyRunner | ||
CFG.GPU_NUM = 1 | ||
CFG.MODEL = {} | ||
CFG.MODEL.NAME = 'my_net' | ||
CFG.TRAIN = {} | ||
CFG.TRAIN.NUM_EPOCHS = 100 | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) | ||
) | ||
CFG.TRAIN.CKPT_SAVE_STRATEGY = None | ||
CFG.TRAIN.OPTIM = {} | ||
CFG.TRAIN.OPTIM.TYPE = 'SGD' | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
'lr': 0.002, | ||
'momentum': 0.1, | ||
} | ||
CFG.TRAIN.DATA = {} | ||
CFG.TRAIN.DATA.BATCH_SIZE = 4 | ||
CFG.TRAIN.DATA.DIR = './my_data' | ||
CFG.TRAIN.DATA.SHUFFLE = True | ||
CFG.TRAIN.DATA.PIN_MEMORY = True | ||
CFG.TRAIN.DATA.PREFETCH = True | ||
CFG.VAL = {} | ||
CFG.VAL.INTERVAL = 1 | ||
CFG.VAL.DATA = {} | ||
CFG.VAL.DATA.DIR = 'mnist_data' | ||
CFG._TRAINING_INDEPENDENT` = [ | ||
'OTHER_CONFIG' | ||
] | ||
``` | ||
All configurations consists of two parts: | ||
1. Training dependent configuration: changing this will affect the training results. | ||
2. Training independent configuration: changing this will not affect the training results. | ||
Notes: | ||
All training dependent configurations will be calculated MD5, | ||
this MD5 value will be the sub directory name of checkpoint save directory. | ||
If the MD5 value is `098f6bcd4621d373cade4e832627b4f6`, | ||
real checkpoint save directory is `{CFG.TRAIN.CKPT_SAVE_DIR}/098f6bcd4621d373cade4e832627b4f6` | ||
Notes: | ||
Each configuration default is training dependent, | ||
except the key is in `TRAINING_INDEPENDENT_KEYS` or `CFG._TRAINING_INDEPENDENT` | ||
""" | ||
from .config import Config | ||
from .utils import config_str, config_md5, save_config_str, copy_config_file, import_config, convert_config, \ | ||
get_ckpt_save_dir, init_cfg | ||
|
||
|
||
__all__ = [ | ||
'Config', 'config_str', 'config_md5', 'save_config_str', 'copy_config_file', | ||
'import_config', 'convert_config', 'get_ckpt_save_dir', 'init_cfg' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# Modified from: https://github.com/makinacorpus/easydict/blob/master/easydict/__init__.py | ||
from typing import overload | ||
|
||
|
||
class Config(dict): | ||
""" | ||
Get attributes | ||
>>> d = Config({'foo':3}) | ||
>>> d['foo'] | ||
3 | ||
>>> d.foo | ||
3 | ||
>>> d.bar | ||
Traceback (most recent call last): | ||
... | ||
AttributeError: 'Config' object has no attribute 'bar' | ||
Works recursively | ||
>>> d = Config({'foo':3, 'bar':{'x':1, 'y':2}}) | ||
>>> isinstance(d.bar, dict) | ||
True | ||
>>> d.bar.x | ||
1 | ||
>>> d['bar.x'] | ||
1 | ||
>>> d.get('bar.x') | ||
1 | ||
>>> d.get('bar.z') | ||
None | ||
>>> d.get('bar.z', 3) | ||
3 | ||
>>> d.has('bar.x') | ||
True | ||
>>> d.has('bar.z') | ||
False | ||
Bullet-proof | ||
>>> Config({}) | ||
{} | ||
>>> Config(d={}) | ||
{} | ||
>>> Config(None) | ||
{} | ||
>>> d = {'a': 1} | ||
>>> Config(**d) | ||
{'a': 1} | ||
Set attributes | ||
>>> d = Config() | ||
>>> d.foo = 3 | ||
>>> d.foo | ||
3 | ||
>>> d.bar = {'prop': 'value'} | ||
>>> d.bar.prop | ||
'value' | ||
>>> d | ||
{'foo': 3, 'bar': {'prop': 'value'}} | ||
>>> d.bar.prop = 'newer' | ||
>>> d.bar.prop | ||
'newer' | ||
Values extraction | ||
>>> d = Config({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) | ||
>>> isinstance(d.bar, list) | ||
True | ||
>>> from operator import attrgetter | ||
>>> map(attrgetter('x'), d.bar) | ||
[1, 3] | ||
>>> map(attrgetter('y'), d.bar) | ||
[2, 4] | ||
>>> d = Config() | ||
>>> d.keys() | ||
[] | ||
>>> d = Config(foo=3, bar=dict(x=1, y=2)) | ||
>>> d.foo | ||
3 | ||
>>> d.bar.x | ||
1 | ||
Still like a dict though | ||
>>> o = Config({'clean':True}) | ||
>>> o.items() | ||
[('clean', True)] | ||
And like a class | ||
>>> class Flower(Config): | ||
... power = 1 | ||
... | ||
>>> f = Flower() | ||
>>> f.power | ||
1 | ||
>>> f = Flower({'height': 12}) | ||
>>> f.height | ||
12 | ||
>>> f['power'] | ||
1 | ||
>>> sorted(f.keys()) | ||
['height', 'power'] | ||
update and pop items | ||
>>> d = Config(a=1, b='2') | ||
>>> e = Config(c=3.0, a=9.0) | ||
>>> d.update(e) | ||
>>> d.c | ||
3.0 | ||
>>> d['c'] | ||
3.0 | ||
>>> d.get('c') | ||
3.0 | ||
>>> d.update(a=4, b=4) | ||
>>> d.b | ||
4 | ||
>>> d.pop('a') | ||
4 | ||
>>> d.a | ||
Traceback (most recent call last): | ||
... | ||
AttributeError: 'Config' object has no attribute 'a' | ||
""" | ||
|
||
# pylint: disable=super-init-not-called | ||
def __init__(self, d=None, **kwargs): | ||
if d is None: | ||
d = {} | ||
if kwargs: | ||
d.update(**kwargs) | ||
for k, v in d.items(): | ||
setattr(self, k, v) | ||
# Class attributes | ||
for k in self.__class__.__dict__: | ||
if not (k.startswith('__') and k.endswith('__')) and not k in ('has', 'get', 'update', 'pop'): | ||
setattr(self, k, getattr(self, k)) | ||
|
||
def __setattr__(self, name, value): | ||
if isinstance(value, (list, tuple)): | ||
v = [self.__class__(x) if isinstance(x, dict) else x for x in value] | ||
# Don't repalce tuple with list | ||
if isinstance(value, tuple): | ||
v = tuple(v) | ||
value = v | ||
elif isinstance(value, dict) and not isinstance(value, self.__class__): | ||
value = self.__class__(value) | ||
super().__setattr__(name, value) | ||
super().__setitem__(name, value) | ||
|
||
__setitem__ = __setattr__ | ||
|
||
def __getitem__(self, key): | ||
# Support `cfg['AA.BB.CC']` | ||
if isinstance(key, str): | ||
keys = key.split('.') | ||
else: | ||
keys = key | ||
value = super().__getitem__(keys[0]) | ||
if len(keys) > 1: | ||
return value.__getitem__(keys[1:]) | ||
else: | ||
return value | ||
|
||
def has(self, key): | ||
return self.get(key) is not None | ||
|
||
@overload | ||
def get(self, key): ... | ||
|
||
def get(self, key, default=None): | ||
# Support `cfg.get('AA.BB.CC')` and `cfg.get('AA.BB.CC', default_value)` | ||
try: | ||
return self[key] | ||
except KeyError: | ||
return default | ||
|
||
def update(self, e=None, **f): | ||
d = e or {} | ||
d.update(f) | ||
for k in d: | ||
setattr(self, k, d[k]) | ||
|
||
def pop(self, k, d=None): | ||
# Check for existence | ||
if hasattr(self, k): | ||
delattr(self, k) | ||
return super().pop(k, d) |
Oops, something went wrong.