Skip to content

Commit

Permalink
feature: dynamic ckpt save dir (#57)
Browse files Browse the repository at this point in the history
* feature: dynamic ckpt save dir

* update version
  • Loading branch information
cnstark authored Jun 14, 2022
1 parent 62206cf commit 987ebee
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 12 deletions.
29 changes: 24 additions & 5 deletions easytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@
import hashlib
from typing import Dict, Set, List, Union

# Public names exported by this module on `from <module> import *`.
__all__ = [
'config_str', 'config_md5', 'save_config_str', 'copy_config_file',
'import_config', 'convert_config', 'get_ckpt_save_dir'
]

TRAINING_INDEPENDENT_FLAG = '_TRAINING_INDEPENDENT'

TRAINING_INDEPENDENT_KEYS = {
Expand Down Expand Up @@ -189,7 +194,7 @@ def config_md5(cfg: Dict) -> str:
return m.hexdigest()


def save_config(cfg: Dict, file_path: str):
def save_config_str(cfg: Dict, file_path: str):
"""Save config
Args:
Expand Down Expand Up @@ -245,10 +250,24 @@ def import_config(path: str, verbose: bool = True) -> Dict:


def convert_config(cfg: Dict):
"""Add MD5 to cfg and convert `CKPT_SAVE_DIR` in `CFG.TRAIN`.
"""Add MD5 to cfg.
Args:
cfg (Dict): config
cfg (Dict): config.
"""
cfg['MD5'] = config_md5(cfg)
cfg['TRAIN']['CKPT_SAVE_DIR'] = os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], cfg['MD5'])

if cfg.get('MD5') is None:
cfg['MD5'] = config_md5(cfg)


def get_ckpt_save_dir(cfg: Dict) -> str:
    """Get the real checkpoint save dir, i.e. the configured dir suffixed with the MD5.

    The result is ``cfg['TRAIN']['CKPT_SAVE_DIR']`` joined with ``cfg['MD5']``.
    It is computed on each call and returned; ``cfg`` itself is not modified.

    Args:
        cfg (Dict): config. Must contain ``cfg['TRAIN']['CKPT_SAVE_DIR']`` and
            ``cfg['MD5']``.

    Returns:
        str: Real ckpt save dir.

    Raises:
        KeyError: if ``TRAIN``, ``CKPT_SAVE_DIR`` or ``MD5`` is missing from ``cfg``.
    """

    return os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], cfg['MD5'])
3 changes: 2 additions & 1 deletion easytorch/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .checkpoint import get_ckpt_dict, load_ckpt, save_ckpt, backup_last_ckpt, clear_ckpt
from .data_loader import build_data_loader, build_data_loader_ddp
from .optimizer_builder import build_optim, build_lr_scheduler
from ..config import get_ckpt_save_dir
from ..utils import TimePredictor, get_logger, get_local_rank, is_master, master_only, set_env


Expand All @@ -30,7 +31,7 @@ def __init__(self, cfg: Dict):
# param
self.use_gpu = cfg.get('GPU_NUM', 0) != 0
self.model_name = cfg['MODEL']['NAME']
self.ckpt_save_dir = cfg['TRAIN']['CKPT_SAVE_DIR']
self.ckpt_save_dir = get_ckpt_save_dir(cfg)
self.logger.info('ckpt save dir: \'{}\''.format(self.ckpt_save_dir))
self.ckpt_save_strategy = None
self.num_epochs = None
Expand Down
11 changes: 6 additions & 5 deletions easytorch/launcher/launcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Callable, Dict, Union, Tuple

from ..config import import_config, save_config, copy_config_file, convert_config
from ..config import import_config, save_config_str, copy_config_file, convert_config, get_ckpt_save_dir
from ..utils import set_gpus
from .dist_wrap import dist_wrap

Expand All @@ -17,11 +17,12 @@ def init_cfg(cfg: Union[Dict, str], save: bool = False):
convert_config(cfg)

# save config
if save and not os.path.isdir(cfg['TRAIN']['CKPT_SAVE_DIR']):
os.makedirs(cfg['TRAIN']['CKPT_SAVE_DIR'])
save_config(cfg, os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], 'cfg.txt'))
ckpt_save_dir = get_ckpt_save_dir(cfg)
# Check whether the directory exists on disk. Testing the truthiness of the
# path string (`not ckpt_save_dir`) is always False for a non-empty path, so
# the directory would never be created and cfg.txt would never be written.
if save and not os.path.isdir(ckpt_save_dir):
    os.makedirs(ckpt_save_dir)
    save_config_str(cfg, os.path.join(ckpt_save_dir, 'cfg.txt'))
if cfg_path is not None:
copy_config_file(cfg_path, cfg['TRAIN']['CKPT_SAVE_DIR'])
copy_config_file(cfg_path, ckpt_save_dir)

return cfg

Expand Down
2 changes: 1 addition & 1 deletion easytorch/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '1.2.5'
__version__ = '1.2.6'
__all__ = ['__version__']

0 comments on commit 987ebee

Please sign in to comment.