Skip to content

Commit

Permalink
feature: awesome config (#68)
Browse files Browse the repository at this point in the history
* feature: awesome config

* fix

* rm easydict in requirements

* update version

* fix

* rename

* fix
  • Loading branch information
cnstark authored Sep 22, 2022
1 parent 587ee7d commit 828a3b2
Show file tree
Hide file tree
Showing 18 changed files with 409 additions and 208 deletions.
4 changes: 2 additions & 2 deletions easytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .config import import_config
from .config import Config, import_config
from .core import Runner, AvgMeter, MeterPool
from .launcher import launch_runner, launch_training
from .version import __version__

__all__ = [
'import_config', 'Runner', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner',
'Config', 'import_config', 'Runner', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner',
'launch_training', '__version__'
]
82 changes: 82 additions & 0 deletions easytorch/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Everything is based on config.
`Config` is the set of all configurations. `Config` is is implemented by `dict`, We recommend using `Config`.
Look at the following example:
cfg.py
```python
import os
from easytorch import Config
from my_runner import MyRunner
CFG = {}
CFG.DESC = 'my net' # customized description
CFG.RUNNER = MyRunner
CFG.GPU_NUM = 1
CFG.MODEL = {}
CFG.MODEL.NAME = 'my_net'
CFG.TRAIN = {}
CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
'checkpoints',
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
CFG.TRAIN.CKPT_SAVE_STRATEGY = None
CFG.TRAIN.OPTIM = {}
CFG.TRAIN.OPTIM.TYPE = 'SGD'
CFG.TRAIN.OPTIM.PARAM = {
'lr': 0.002,
'momentum': 0.1,
}
CFG.TRAIN.DATA = {}
CFG.TRAIN.DATA.BATCH_SIZE = 4
CFG.TRAIN.DATA.DIR = './my_data'
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.PIN_MEMORY = True
CFG.TRAIN.DATA.PREFETCH = True
CFG.VAL = {}
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = {}
CFG.VAL.DATA.DIR = 'mnist_data'
CFG._TRAINING_INDEPENDENT` = [
'OTHER_CONFIG'
]
```
All configurations consists of two parts:
1. Training dependent configuration: changing this will affect the training results.
2. Training independent configuration: changing this will not affect the training results.
Notes:
All training dependent configurations will be calculated MD5,
this MD5 value will be the sub directory name of checkpoint save directory.
If the MD5 value is `098f6bcd4621d373cade4e832627b4f6`,
real checkpoint save directory is `{CFG.TRAIN.CKPT_SAVE_DIR}/098f6bcd4621d373cade4e832627b4f6`
Notes:
Each configuration default is training dependent,
except the key is in `TRAINING_INDEPENDENT_KEYS` or `CFG._TRAINING_INDEPENDENT`
"""
from .config import Config
from .utils import config_str, config_md5, save_config_str, copy_config_file, import_config, convert_config, \
get_ckpt_save_dir, init_cfg


__all__ = [
'Config', 'config_str', 'config_md5', 'save_config_str', 'copy_config_file',
'import_config', 'convert_config', 'get_ckpt_save_dir', 'init_cfg'
]
191 changes: 191 additions & 0 deletions easytorch/config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Modified from: https://github.com/makinacorpus/easydict/blob/master/easydict/__init__.py
from typing import overload


class Config(dict):
"""
Get attributes
>>> d = Config({'foo':3})
>>> d['foo']
3
>>> d.foo
3
>>> d.bar
Traceback (most recent call last):
...
AttributeError: 'Config' object has no attribute 'bar'
Works recursively
>>> d = Config({'foo':3, 'bar':{'x':1, 'y':2}})
>>> isinstance(d.bar, dict)
True
>>> d.bar.x
1
>>> d['bar.x']
1
>>> d.get('bar.x')
1
>>> d.get('bar.z')
None
>>> d.get('bar.z', 3)
3
>>> d.has('bar.x')
True
>>> d.has('bar.z')
False
Bullet-proof
>>> Config({})
{}
>>> Config(d={})
{}
>>> Config(None)
{}
>>> d = {'a': 1}
>>> Config(**d)
{'a': 1}
Set attributes
>>> d = Config()
>>> d.foo = 3
>>> d.foo
3
>>> d.bar = {'prop': 'value'}
>>> d.bar.prop
'value'
>>> d
{'foo': 3, 'bar': {'prop': 'value'}}
>>> d.bar.prop = 'newer'
>>> d.bar.prop
'newer'
Values extraction
>>> d = Config({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
>>> isinstance(d.bar, list)
True
>>> from operator import attrgetter
>>> map(attrgetter('x'), d.bar)
[1, 3]
>>> map(attrgetter('y'), d.bar)
[2, 4]
>>> d = Config()
>>> d.keys()
[]
>>> d = Config(foo=3, bar=dict(x=1, y=2))
>>> d.foo
3
>>> d.bar.x
1
Still like a dict though
>>> o = Config({'clean':True})
>>> o.items()
[('clean', True)]
And like a class
>>> class Flower(Config):
... power = 1
...
>>> f = Flower()
>>> f.power
1
>>> f = Flower({'height': 12})
>>> f.height
12
>>> f['power']
1
>>> sorted(f.keys())
['height', 'power']
update and pop items
>>> d = Config(a=1, b='2')
>>> e = Config(c=3.0, a=9.0)
>>> d.update(e)
>>> d.c
3.0
>>> d['c']
3.0
>>> d.get('c')
3.0
>>> d.update(a=4, b=4)
>>> d.b
4
>>> d.pop('a')
4
>>> d.a
Traceback (most recent call last):
...
AttributeError: 'Config' object has no attribute 'a'
"""

# pylint: disable=super-init-not-called
def __init__(self, d=None, **kwargs):
if d is None:
d = {}
if kwargs:
d.update(**kwargs)
for k, v in d.items():
setattr(self, k, v)
# Class attributes
for k in self.__class__.__dict__:
if not (k.startswith('__') and k.endswith('__')) and not k in ('has', 'get', 'update', 'pop'):
setattr(self, k, getattr(self, k))

def __setattr__(self, name, value):
if isinstance(value, (list, tuple)):
v = [self.__class__(x) if isinstance(x, dict) else x for x in value]
# Don't repalce tuple with list
if isinstance(value, tuple):
v = tuple(v)
value = v
elif isinstance(value, dict) and not isinstance(value, self.__class__):
value = self.__class__(value)
super().__setattr__(name, value)
super().__setitem__(name, value)

__setitem__ = __setattr__

def __getitem__(self, key):
# Support `cfg['AA.BB.CC']`
if isinstance(key, str):
keys = key.split('.')
else:
keys = key
value = super().__getitem__(keys[0])
if len(keys) > 1:
return value.__getitem__(keys[1:])
else:
return value

def has(self, key):
return self.get(key) is not None

@overload
def get(self, key): ...

def get(self, key, default=None):
# Support `cfg.get('AA.BB.CC')` and `cfg.get('AA.BB.CC', default_value)`
try:
return self[key]
except KeyError:
return default

def update(self, e=None, **f):
d = e or {}
d.update(f)
for k in d:
setattr(self, k, d[k])

def pop(self, k, d=None):
# Check for existence
if hasattr(self, k):
delattr(self, k)
return super().pop(k, d)
Loading

0 comments on commit 828a3b2

Please sign in to comment.