Skip to content

Commit

Permalink
fix: best metrics bug when resuming training. (#58)
Browse files Browse the repository at this point in the history
* fix: best metrics bug when resuming training.

* fix: compatibility

* update version
  • Loading branch information
cnstark authored Jun 15, 2022
1 parent 987ebee commit ee1d841
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 36 deletions.
30 changes: 0 additions & 30 deletions easytorch/core/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,6 @@
DEFAULT_LOGGER = get_logger('easytorch-checkpoint')


def get_ckpt_dict(model: nn.Module, optimizer: optim.Optimizer, epoch: int) -> Dict:
    """Build a serializable checkpoint dict for `model` and `optimizer`.

    The resulting dict has the layout::

        {
            'epoch': current epoch ([1, num_epochs]),
            'model_state_dict': state_dict of model,
            'optim_state_dict': state_dict of optimizer
        }

    When `model` is wrapped in `DistributedDataParallel`, the underlying
    `model.module` is unwrapped first so the saved state dict can be
    loaded without a DDP wrapper.

    Args:
        model (nn.Module): the model to be saved
        optimizer (optim.Optimizer): the optimizer to be saved
        epoch: current epoch

    Returns:
        checkpoint dict (Dict): generated checkpoint dict
    """

    unwrapped = model.module if isinstance(model, DDP) else model
    return {
        'epoch': epoch,
        'model_state_dict': unwrapped.state_dict(),
        'optim_state_dict': optimizer.state_dict()
    }


def get_last_ckpt_path(ckpt_save_dir: str, name_pattern: str = r'^.+_[\d]*.pt$') -> str:
"""Get last checkpoint path in `ckpt_save_dir`
checkpoint files will be sorted by name
Expand Down
24 changes: 19 additions & 5 deletions easytorch/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.utils.tensorboard import SummaryWriter

from .meter_pool import MeterPool
from .checkpoint import get_ckpt_dict, load_ckpt, save_ckpt, backup_last_ckpt, clear_ckpt
from .checkpoint import load_ckpt, save_ckpt, backup_last_ckpt, clear_ckpt
from .data_loader import build_data_loader, build_data_loader_ddp
from .optimizer_builder import build_optim, build_lr_scheduler
from ..config import get_ckpt_save_dir
Expand Down Expand Up @@ -213,6 +213,7 @@ def get_ckpt_path(self, epoch: int) -> str:
ckpt_name = '{}_{}.pt'.format(self.model_name, epoch_str)
return os.path.join(self.ckpt_save_dir, ckpt_name)

@master_only
def save_model(self, epoch: int):
"""Save checkpoint every epoch.
Expand All @@ -228,7 +229,13 @@ def save_model(self, epoch: int):
epoch (int): current epoch.
"""

ckpt_dict = get_ckpt_dict(self.model, self.optim, epoch)
model = self.model.module if isinstance(self.model, DDP) else self.model
ckpt_dict = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optim_state_dict': self.optim.state_dict(),
'best_metrics': self.best_metrics
}

# backup last epoch
last_ckpt_path = self.get_ckpt_path(epoch - 1)
Expand Down Expand Up @@ -263,6 +270,8 @@ def load_model_resume(self, strict: bool = True):
self.model.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
self.optim.load_state_dict(checkpoint_dict['optim_state_dict'])
self.start_epoch = checkpoint_dict['epoch']
if checkpoint_dict.get('best_metrics') is not None:
self.best_metrics = checkpoint_dict['best_metrics']
if self.scheduler is not None:
self.scheduler.last_epoch = checkpoint_dict['epoch']
self.logger.info('resume training')
Expand Down Expand Up @@ -435,8 +444,7 @@ def on_epoch_end(self, epoch: int):
if self.val_data_loader is not None and epoch % self.val_interval == 0:
self.validate(train_epoch=epoch)
# save model
if is_master():
self.save_model(epoch)
self.save_model(epoch)
# reset meters
self.reset_epoch_meters()

Expand Down Expand Up @@ -571,7 +579,13 @@ def save_best_model(self, epoch: int, metric_name: str, greater_best: bool = Tru
best_metric = self.best_metrics.get(metric_name)
if best_metric is None or (metric > best_metric if greater_best else metric < best_metric):
self.best_metrics[metric_name] = metric
ckpt_dict = get_ckpt_dict(self.model, self.optim, epoch)
model = self.model.module if isinstance(self.model, DDP) else self.model
ckpt_dict = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optim_state_dict': self.optim.state_dict(),
'best_metrics': self.best_metrics
}
ckpt_path = os.path.join(
self.ckpt_save_dir,
'{}_best_{}.pt'.format(self.model_name, metric_name.replace('/', '_'))
Expand Down
2 changes: 1 addition & 1 deletion easytorch/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '1.2.6'
__version__ = '1.2.7'
__all__ = ['__version__']

0 comments on commit ee1d841

Please sign in to comment.