train.py
import argparse
import torch
import os
from tqdm import tqdm
from torch import optim
from torch.utils.data import DataLoader
import torch.utils.tensorboard as tensorboard
from torch.cuda.amp import GradScaler
from llama import LlamaTokenizer, LlamaForCausalLM
from utils import world_info_from_env, init_distributed_device, ImageTextDataSet, is_master, get_autocast
from model import MultimodalLlama
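
# Example launch (a sketch, not part of the original script): with the default
# --dist-url env://, rank and world-size information is read from environment
# variables, so a torchrun-style launcher is one common way to start training, e.g.
#   torchrun --nproc_per_node=8 train.py --train_file train.pkl \
#       --model_name_or_path ./ckpt --precision amp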

# [boi] and [eoi] presumably delimit the image-embedding span (begin/end-of-image markers)
special_tokens_dict = {'additional_special_tokens': ['[boi]', '[eoi]']}


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a llama model on a causal language modeling task")
    parser.add_argument(
        "--train_file", type=str, default='train.pkl', help="A pkl file containing the training data."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default='./ckpt',
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--output_dir", type=str, default='out', help="Where to store the final model.")
    parser.add_argument(
        "--tensorboard_path", type=str, default="./tensorboard",
    )
    parser.add_argument(
        "--image_length",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=4e-3,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--norm_gradient_clip", type=float, default=1.0, help="Gradient clip."
    )
    parser.add_argument("--beta1", type=float, default=0.98, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=0.999, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=1e-6, help="Adam epsilon.")
    parser.add_argument("--weight_decay", type=float, default=0.2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=10, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--local_rank", type=int, default=0, help="Local rank for distributed training.")
    parser.add_argument(
        "--precision",
        choices=["amp", "amp_bfloat16", "fp16", "fp32"],
        default="fp16",
        help="Floating point precision."
    )
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="URL used to set up distributed training.",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="Distributed backend."
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training."
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Run in debug mode (break after the first batch and epoch).",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    print(args)
    if torch.cuda.is_available():
        # This enables tf32 on Ampere GPUs which is only 8% slower than
        # float16 and almost as accurate as float32
        # This was a default in pytorch until 1.12
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # discover initial world args early so we can log properly
    args.distributed = False
    args.local_rank, args.rank, args.world_size = world_info_from_env()

    # fully initialize distributed device environment
    device = init_distributed_device(args)

    # only the master process writes TensorBoard logs
    if is_master(args):
        if not os.path.exists(args.tensorboard_path):
            os.makedirs(args.tensorboard_path)
        writer = tensorboard.SummaryWriter(args.tensorboard_path)
    else:
        writer = None

    # load the tokenizer, register the [boi]/[eoi] tokens, and resize the embedding table to match
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    print(num_added_tokens)
    token_ids = tokenizer.convert_tokens_to_ids(['[boi]', '[eoi]'])
    print(token_ids)
    llama_model = LlamaForCausalLM.from_pretrained(args.model_name_or_path)
    llama_model.resize_token_embeddings(len(tokenizer))
    model = MultimodalLlama(image_length=args.image_length, llama=llama_model)
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps)
    scaler = GradScaler() if args.precision == "amp" else None

    train_dataset = ImageTextDataSet(args.train_file, tokenizer=tokenizer, image_length=args.image_length)
    train_loader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size)
    for epoch in range(args.num_train_epochs):
        model.train()
        device = torch.device(args.device)
        autocast = get_autocast(args.precision)
        num_batches_per_epoch = len(train_loader)
        loss_cum = 0.0
        progress = tqdm(total=len(train_loader), desc='llama fine-tuning')
        for i, batch in enumerate(train_loader):
            step = num_batches_per_epoch * epoch + i
            image_embedding, tokens, mask = batch
            image_embedding, tokens, mask = image_embedding.to(device), tokens.to(device), mask.to(device)
            optimizer.zero_grad()
            with autocast():
                loss = model(tokens=tokens, labels=tokens, image_embedding=image_embedding, mask=mask).loss
            if scaler is not None:
                scaler.scale(loss).backward()
                if args.norm_gradient_clip is not None:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
                # Zero out the gradients for all token embeddings except the newly added embeddings
                grads = model.llm.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = torch.arange(len(tokenizer)) != token_ids[0]
                index_grads_to_zero *= torch.arange(len(tokenizer)) != token_ids[1]
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                # Same masking as above: only the [boi]/[eoi] embeddings keep their gradients
                grads = model.llm.get_input_embeddings().weight.grad
                index_grads_to_zero = torch.arange(len(tokenizer)) != token_ids[0]
                index_grads_to_zero *= torch.arange(len(tokenizer)) != token_ids[1]
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
                optimizer.step()
            loss_cum += loss.item()
            progress.set_postfix({"loss": loss_cum / (i + 1)})
            progress.update()
            if is_master(args) and i % 10 == 0:
                writer.add_scalar("train/loss", loss.item(), step)
            if args.debug:
                break
        progress.close()
        if args.debug:
            break
        if is_master(args):
            print('saving model')
            os.makedirs(args.output_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(args.output_dir, f'epoch_{epoch}.pt'))
        torch.cuda.synchronize()


if __name__ == "__main__":
    main()
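
# Loading a saved checkpoint later (a sketch, not part of the original script; the
# file name mirrors the torch.save call in main(), and the embedding table must be
# resized exactly as during training before the state dict is loaded):
#   tokenizer = LlamaTokenizer.from_pretrained('./ckpt')
#   tokenizer.add_special_tokens(special_tokens_dict)
#   llama = LlamaForCausalLM.from_pretrained('./ckpt')
#   llama.resize_token_embeddings(len(tokenizer))
#   model = MultimodalLlama(image_length=10, llama=llama)
#   model.load_state_dict(torch.load(os.path.join('out', 'epoch_0.pt'), map_location='cpu'))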