-
Notifications
You must be signed in to change notification settings - Fork 25
/
train.py
272 lines (231 loc) · 11.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import os
import math
import time
import logging
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from torchvision import transforms
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
from dataset.font_dataset import FontDataset
from dataset.collate_fn import CollateFN
from configs.fontdiffuser import get_parser
from src import (FontDiffuserModel,
ContentPerceptualLoss,
build_unet,
build_style_encoder,
build_content_encoder,
build_ddpm_scheduler,
build_scr)
from utils import (save_args_to_yaml,
x0_from_epsilon,
reNormalize_img,
normalize_mean_std)
# Module-level logger from accelerate.logging.get_logger.
# NOTE(review): not referenced by the training code in this file, which logs
# through the stdlib `logging` module instead.
logger = get_logger(__name__)
def get_args():
    """Parse the FontDiffuser training command line.

    If the ``LOCAL_RANK`` environment variable is present (and differs from
    ``--local_rank``), its value overrides the parsed ``local_rank``.
    The scalar ``style_image_size`` and ``content_image_size`` options are
    expanded into square ``(size, size)`` tuples.

    Returns:
        The parsed argument namespace, with the adjustments described above.
    """
    args = get_parser().parse_args()

    rank_from_env = int(os.environ.get("LOCAL_RANK", -1))
    if rank_from_env not in (-1, args.local_rank):
        args.local_rank = rank_from_env

    # transforms.Resize receives these directly, so make them (h, w) pairs.
    args.style_image_size = (args.style_image_size,) * 2
    args.content_image_size = (args.content_image_size,) * 2
    return args
def _timestamp():
    """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))


def _make_image_transform(size):
    """Build the resize -> tensor -> normalize pipeline used for every image input.

    Args:
        size: Target size handed to ``transforms.Resize`` (an ``(h, w)`` tuple here).

    Returns:
        A ``transforms.Compose`` that bilinearly resizes to ``size``, converts to a
        tensor and normalizes with mean 0.5 / std 0.5 (mapping [0, 1] to [-1, 1]).
    """
    return transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])])


def _save_checkpoint(accelerator, model, save_dir):
    """Save the UNet, style encoder, content encoder and the whole model to ``save_dir``.

    The model is unwrapped first. After ``accelerator.prepare`` it may be wrapped
    (e.g. by DistributedDataParallel), and the wrapper does not expose the
    ``unet`` / ``style_encoder`` / ``content_encoder`` attributes.
    """
    os.makedirs(save_dir, exist_ok=True)
    unwrapped = accelerator.unwrap_model(model)
    torch.save(unwrapped.unet.state_dict(), f"{save_dir}/unet.pth")
    torch.save(unwrapped.style_encoder.state_dict(), f"{save_dir}/style_encoder.pth")
    torch.save(unwrapped.content_encoder.state_dict(), f"{save_dir}/content_encoder.pth")
    torch.save(unwrapped, f"{save_dir}/total_model.pth")


def main():
    """Train FontDiffuser.

    Without ``--phase_2`` the model is trained from scratch with diffusion,
    perceptual and offset losses. With ``--phase_2`` it starts from the
    phase-1 checkpoints and adds an SC loss computed by a frozen SCR module.
    Checkpoints are written every ``ckpt_interval`` optimizer steps, and
    training stops after ``max_train_steps`` optimizer steps.
    """
    args = get_args()

    logging_dir = f"{args.output_dir}/{args.logging_dir}"
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir)

    # Every process opens the log file below, so every process must make sure
    # the output directory exists first. exist_ok keeps concurrent creation safe.
    os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(
        filename=f"{args.output_dir}/fontdiffuser_training.log",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)

    # Set the training seed.
    if args.seed is not None:
        set_seed(args.seed)

    # Build the model components and the noise scheduler.
    unet = build_unet(args=args)
    style_encoder = build_style_encoder(args=args)
    content_encoder = build_content_encoder(args=args)
    noise_scheduler = build_ddpm_scheduler(args)
    if args.phase_2:
        # Phase 2 starts from the phase-1 weights. Loading to CPU avoids every
        # process materialising the tensors on the device they were saved from;
        # load_state_dict copies them into the modules' own parameters.
        ckpt_dir = args.phase_1_ckpt_dir
        unet.load_state_dict(torch.load(f"{ckpt_dir}/unet.pth", map_location="cpu"))
        style_encoder.load_state_dict(torch.load(f"{ckpt_dir}/style_encoder.pth", map_location="cpu"))
        content_encoder.load_state_dict(torch.load(f"{ckpt_dir}/content_encoder.pth", map_location="cpu"))
    model = FontDiffuserModel(
        unet=unet,
        style_encoder=style_encoder,
        content_encoder=content_encoder)

    # Build the content perceptual loss.
    perceptual_loss = ContentPerceptualLoss()

    # Load the SCR module used for phase-2 supervision. It is frozen.
    if args.phase_2:
        scr = build_scr(args=args)
        scr.load_state_dict(torch.load(args.scr_ckpt_path, map_location="cpu"))
        scr.requires_grad_(False)

    # Datasets: content, style and target images share the same pipeline shape.
    train_font_dataset = FontDataset(
        args=args,
        phase='train',
        transforms=[
            _make_image_transform(args.content_image_size),
            _make_image_transform(args.style_image_size),
            _make_image_transform((args.resolution, args.resolution))],
        scr=args.phase_2)
    train_dataloader = torch.utils.data.DataLoader(
        train_font_dataset, shuffle=True, batch_size=args.train_batch_size, collate_fn=CollateFN())

    # Build the optimizer and learning-rate scheduler.
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,)

    # Accelerate preparation.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler)

    # The SCR module is not passed through prepare(), so move it to the device by hand.
    if args.phase_2:
        scr = scr.to(accelerator.device)

    # Trackers are initialized on the main process only.
    if accelerator.is_main_process:
        accelerator.init_trackers(args.experience_name)
        save_args_to_yaml(args=args, output_file=f"{args.output_dir}/{args.experience_name}_config.yaml")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Convert the optimizer-step budget into a number of epochs.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    global_step = 0
    for _ in range(num_train_epochs):
        train_loss = 0.0
        for samples in train_dataloader:
            model.train()
            content_images = samples["content_image"]
            style_images = samples["style_image"]
            target_images = samples["target_image"]
            nonorm_target_images = samples["nonorm_target_image"]

            with accelerator.accumulate(model):
                # Sample the noise that will be added to the targets.
                noise = torch.randn_like(target_images)
                bsz = target_images.shape[0]
                # Sample a random timestep for each image.
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=target_images.device)
                timesteps = timesteps.long()

                # Forward diffusion: add noise according to each sample's timestep.
                noisy_target_images = noise_scheduler.add_noise(target_images, noise, timesteps)

                # Classifier-free training: with probability drop_prob, replace a
                # sample's content and style conditions with an all-ones image.
                context_mask = torch.bernoulli(torch.zeros(bsz) + args.drop_prob)
                drop = context_mask.bool()
                content_images[drop.to(content_images.device)] = 1
                style_images[drop.to(style_images.device)] = 1

                # Predict the noise residual and compute the losses.
                noise_pred, offset_out_sum = model(
                    x_t=noisy_target_images,
                    timesteps=timesteps,
                    style_images=style_images,
                    content_images=content_images,
                    content_encoder_downsample_size=args.content_encoder_downsample_size)
                diff_loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                offset_loss = offset_out_sum / 2

                # Recover the predicted x0 for the content perceptual loss.
                pred_original_sample_norm = x0_from_epsilon(
                    scheduler=noise_scheduler,
                    noise_pred=noise_pred,
                    x_t=noisy_target_images,
                    timesteps=timesteps)
                pred_original_sample = reNormalize_img(pred_original_sample_norm)
                norm_pred_ori = normalize_mean_std(pred_original_sample)
                norm_target_ori = normalize_mean_std(nonorm_target_images)
                percep_loss = perceptual_loss.calculate_loss(
                    generated_images=norm_pred_ori,
                    target_images=norm_target_ori,
                    device=target_images.device)

                loss = diff_loss + \
                    args.perceptual_coefficient * percep_loss + \
                    args.offset_coefficient * offset_loss

                if args.phase_2:
                    neg_images = samples["neg_images"]
                    # SC loss computed by the frozen SCR module.
                    sample_style_embeddings, pos_style_embeddings, neg_style_embeddings = scr(
                        pred_original_sample_norm,
                        target_images,
                        neg_images,
                        nce_layers=args.nce_layers)
                    sc_loss = scr.calculate_nce_loss(
                        sample_s=sample_style_embeddings,
                        pos_s=pos_style_embeddings,
                        neg_s=neg_style_embeddings)
                    loss += args.sc_coefficient * sc_loss

                # Gather the losses across all processes for logging.
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            step_loss = loss.detach().item()

            # sync_gradients is True when an optimizer step actually happened
            # (i.e. at the end of a gradient-accumulation cycle).
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if accelerator.is_main_process and global_step % args.ckpt_interval == 0:
                    _save_checkpoint(accelerator, model, f"{args.output_dir}/global_step_{global_step}")
                    logging.info(f"[{_timestamp()}] Save the checkpoint on global step {global_step}")
                    print("Save the checkpoint on global step {}".format(global_step))

                # Logged once per optimizer step, not once per micro-batch.
                if global_step % args.log_interval == 0:
                    logging.info(f"[{_timestamp()}] Global Step {global_step} => train_loss = {step_loss}")

            logs = {"step_loss": step_loss, "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            # Quit once the optimizer-step budget is exhausted.
            if global_step >= args.max_train_steps:
                break
        # Also leave the epoch loop, so no further micro-batches are trained.
        if global_step >= args.max_train_steps:
            break

    accelerator.end_training()
# Entry point: start training only when this file is executed as a script.
if __name__ == "__main__":
    main()