load_checkpoint.py
import os

import torch


def load_checkpoint(model,
                    load_dir,
                    tag,
                    load_module_strict=True,
                    load_optimizer_states=True,
                    load_lr_scheduler_states=True):
    r"""Load a training checkpoint.

    Arguments:
        load_dir: Required. Directory to load the checkpoint from.
        tag: Required. Checkpoint tag used as a unique identifier for the checkpoint, e.g. the global step.
        load_module_strict: Optional. Boolean to strictly enforce that the keys in the state_dict of the module and the checkpoint match.
        load_optimizer_states: Optional. Boolean to load the training optimizer states from the checkpoint, e.g. ADAM's momentum and variance.
        load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.

    Returns:
        load_path: Path of the loaded checkpoint. None if loading the checkpoint failed.
        client_state: State dictionary used for loading required training states in the client code.
    """
    load_path, client_states = _load_checkpoint(model,
                                                load_dir,
                                                tag,
                                                load_module_strict=load_module_strict,
                                                load_optimizer_states=load_optimizer_states,
                                                load_lr_scheduler_states=load_lr_scheduler_states)

    # ZeRO partitions optimizer state across data-parallel ranks, so those shards
    # are loaded separately once the base checkpoint has been found.
    if load_optimizer_states:
        if model.zero_optimization() and load_path is not None:
            model._load_zero_checkpoint(load_dir,
                                        tag,
                                        load_optimizer_states=load_optimizer_states)

    return load_path, client_states
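
# A minimal usage sketch (not part of this module): assuming `engine` is a
# DeepSpeed-style engine object exposing zero_optimization(), fp16_enabled(),
# load_module_state_dict() and friends, resuming training might look like:
#
#     load_path, client_state = load_checkpoint(engine,
#                                                load_dir='./checkpoints',
#                                                tag='global_step1000')
#     if load_path is None:
#         start_step = 0  # no checkpoint found, train from scratch
#     else:
#         start_step = client_state.get('step', 0)  # 'step' is a hypothetical client-saved key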


def _get_ckpt_name(mpu, checkpoints_path, tag):
    # Each model-parallel rank reads its own shard, e.g.
    # <checkpoints_path>/<tag>/mp_rank_00_model_states.pt for rank 0.
    mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
    ckpt_name = os.path.join(checkpoints_path,
                             str(tag),
                             'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
    return ckpt_name


def pre_load(mpu, load_dir, tag):
    # Load only the raw module weights for this model-parallel rank,
    # mapping all tensors to CPU so no GPU is required.
    load_path = _get_ckpt_name(mpu, load_dir, tag)
    checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
    return checkpoint['module']


def _load_checkpoint(model,
                     load_dir,
                     tag,
                     load_module_strict=True,
                     load_optimizer_states=True,
                     load_lr_scheduler_states=True):
    load_path = model._get_ckpt_name(load_dir, tag)

    if not os.path.exists(load_path):
        return None, None

    checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

    model.load_module_state_dict(state_dict=checkpoint['module'],
                                 strict=load_module_strict)

    # Optimizer states are restored here only when ZeRO is disabled; the ZeRO
    # shards are handled by _load_zero_checkpoint in load_checkpoint above.
    if not model.zero_optimization() and load_optimizer_states:
        if model.fp16_enabled():
            model.optimizer.load_state_dict(
                checkpoint['optimizer'],
                load_optimizer_states=load_optimizer_states)
        elif load_optimizer_states:
            model.optimizer.load_state_dict(checkpoint['optimizer'])

    if load_lr_scheduler_states and model.lr_scheduler is not None:
        model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    model.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
    model.global_steps = checkpoint['global_steps']
    model.global_samples = checkpoint.get('global_samples',
                                          model.global_steps * model.train_batch_size())
    model.skipped_steps = checkpoint['skipped_steps']
    model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
    model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']

    # Everything DeepSpeed wrote for itself is filtered out; the remaining keys
    # are returned to the caller as client_state.
    deepspeed_states = [
        'module',
        'optimizer',
        'lr_scheduler',
        'csr_tensor_module_names',
        'skipped_steps',
        'global_steps',
        'dp_world_size',
        'mp_world_size'
    ]
    client_state = {
        key: value
        for key, value in checkpoint.items() if key not in deepspeed_states
    }

    return load_path, client_state
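
# Minimal sketch of the on-disk layout these helpers assume (inferred from
# _get_ckpt_name above, not something this file creates):
#
#     <load_dir>/<tag>/mp_rank_00_model_states.pt   # model-parallel rank 0
#     <load_dir>/<tag>/mp_rank_01_model_states.pt   # model-parallel rank 1, ...
#
# pre_load() can then fetch just the module weights for one rank, e.g.:
#
#     state_dict = pre_load(mpu=None, load_dir='./checkpoints', tag='global_step1000')
#     bare_model.load_state_dict(state_dict)  # bare_model is a hypothetical nn.Module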