diff --git a/fairdarts/train_search.py b/fairdarts/train_search.py index 6358193..56c1f29 100755 --- a/fairdarts/train_search.py +++ b/fairdarts/train_search.py @@ -123,7 +123,7 @@ def main(): start_epoch = checkpoint['epoch'] dur_time = checkpoint['dur_time'] model_optimizer.load_state_dict(checkpoint['model_optimizer']) - architect.arch_optimizer.load_state_dict(checkpoint['arch_optimizer']) + architect.optimizer.load_state_dict(checkpoint['arch_optimizer']) model.restore(checkpoint['network_states']) logging.info('=> loaded checkpoint \'{}\'(epoch {})'.format(args.resume, start_epoch)) else: diff --git a/requirements.txt b/requirements.txt index 71abe93..c001fab 100755 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,6 @@ Pillow==6.2.0 pyparsing==2.4.2 python-dateutil==2.8.0 six==1.12.0 -thop==0.0.31.post1910221501 +thop>=0.0.31,<0.0.32 torch==1.1.0 torchvision==0.2.1