Commit 2cca9d0 (parent: e25e693)
Showing 19 changed files with 308 additions and 2 deletions.
dataset/font_dataset.py
@@ -0,0 +1,69 @@
import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


def get_nonorm_transform(resolution):
    nonorm_transform = transforms.Compose(
        [transforms.Resize((resolution, resolution),
                           interpolation=transforms.InterpolationMode.BILINEAR),
         transforms.ToTensor()])
    return nonorm_transform


class FontDataset(Dataset):
    """Dataset for font generation."""

    def __init__(self, args, phase, transforms=None):
        super().__init__()
        self.root = args.data_root
        self.phase = phase

        # Collect data paths
        self.get_path()
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        self.target_images = []
        # Map each style to all of its images
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        for style in os.listdir(target_image_dir):
            images_related_style = []
            for img in os.listdir(f"{target_image_dir}/{style}"):
                img_path = f"{target_image_dir}/{style}/{img}"
                self.target_images.append(img_path)
                images_related_style.append(img_path)
            self.style_to_images[style] = images_related_style

    def __getitem__(self, index):
        target_image_path = self.target_images[index]
        target_image_name = target_image_path.split('/')[-1]
        # File names follow the pattern "{style}+{content}.jpg"
        style, content = target_image_name.split('.')[0].split('+')

        # Read content image
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = Image.open(content_image_path).convert('RGB')

        # Randomly sample another image of the same style as the style reference
        images_related_style = self.style_to_images[style].copy()
        images_related_style.remove(target_image_path)
        style_image_path = random.choice(images_related_style)
        style_image = Image.open(style_image_path).convert("RGB")

        # Read target image
        target_image = Image.open(target_image_path).convert("RGB")
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path

    def __len__(self):
        return len(self.target_images)
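
For orientation, a minimal usage sketch with hypothetical argument values; the directory and file-name patterns follow the paths built in get_path and __getitem__ above. Note that each style folder needs at least two images, since __getitem__ removes the target before sampling a style reference.

# Expected layout (as constructed by the dataset above):
#   {data_root}/train/TargetImage/{style}/{style}+{content}.jpg
#   {data_root}/train/ContentImage/{content}.jpg
from types import SimpleNamespace

from dataset.font_dataset import FontDataset

args = SimpleNamespace(data_root="./data", resolution=96)  # hypothetical values
dataset = FontDataset(args=args, phase="train", transforms=None)
content, style, target, nonorm_target, path = dataset[0]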
FontDiffuser training script
@@ -0,0 +1,239 @@
import os
import math
import time
import logging
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler

from dataset.font_dataset import FontDataset
from configs.fontdiffuser import get_parser
from src import (FontDiffuserModel,
                 ContentPerceptualLoss,
                 build_unet,
                 build_style_encoder,
                 build_content_encoder,
                 build_ddpm_scheduler)
from utils import (save_args_to_yaml,
                   x0_from_epsilon,
                   reNormalize_img,
                   normalize_mean_std)


logger = get_logger(__name__)


def get_args():
    parser = get_parser()
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def main():
    args = get_args()

    logging_dir = f"{args.output_dir}/{args.logging_dir}"

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(
        filename=f"{args.output_dir}/fontdiffuser_training.log",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)

    # Set training seed
    if args.seed is not None:
        set_seed(args.seed)

    # Load model and noise scheduler
    unet = build_unet(args=args)
    style_encoder = build_style_encoder(args=args)
    content_encoder = build_content_encoder(args=args)
    noise_scheduler = build_ddpm_scheduler(args)

    model = FontDiffuserModel(
        unet=unet,
        style_encoder=style_encoder,
        content_encoder=content_encoder)

    # Build content perceptual loss
    perceptual_loss = ContentPerceptualLoss()

    # Load the datasets
    content_transforms = transforms.Compose(
        [
            transforms.Resize(args.content_image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    style_transforms = transforms.Compose(
        [
            transforms.Resize(args.style_image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    target_transforms = transforms.Compose(
        [
            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    train_font_dataset = FontDataset(
        args=args,
        phase='train',
        transforms=[
            content_transforms,
            style_transforms,
            target_transforms])
    train_dataloader = torch.utils.data.DataLoader(
        train_font_dataset, shuffle=True, batch_size=args.train_batch_size)

    # Build optimizer and learning-rate scheduler
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon)
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)

    # Accelerate preparation
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler)

    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("fontdiffuser_training")
        save_args_to_yaml(args=args, output_file=f"{args.output_dir}/{args.experience_name}_config.yaml")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Convert max_train_steps into a number of training epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    global_step = 0
    for epoch in range(num_train_epochs):
        train_loss = 0.0
        for step, (content_images, style_images, target_images, nonorm_target_images, target_image_paths) in enumerate(train_dataloader):
            model.train()
            with accelerator.accumulate(model):
                # Sample noise that we'll add to the target images
                noise = torch.randn_like(target_images)
                bsz = target_images.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=target_images.device)
                timesteps = timesteps.long()

                # Add noise to the target_images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_target_images = noise_scheduler.add_noise(target_images, noise, timesteps)

                # Classifier-free training strategy
                context_mask = torch.bernoulli(torch.zeros(bsz) + args.drop_prob)
                for i, mask_value in enumerate(context_mask):
                    if mask_value == 1:
                        content_images[i, :, :, :] = 1
                        style_images[i, :, :, :] = 1
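                # (Dropping both conditions with probability args.drop_prob trains an
                # unconditional branch alongside the conditional one; at sampling time
                # the two predictions can be combined for classifier-free guidance.)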

                # Predict the noise residual and compute loss
                noise_pred, offset_out_sum = model(
                    x_t=noisy_target_images,
                    timesteps=timesteps,
                    style_images=style_images,
                    content_images=content_images,
                    content_encoder_downsample_size=args.content_encoder_downsample_size)
                diff_loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                offset_loss = offset_out_sum / 2

                # Output processing for the content perceptual loss
                pred_original_sample = x0_from_epsilon(
                    scheduler=noise_scheduler,
                    noise_pred=noise_pred,
                    x_t=noisy_target_images,
                    timesteps=timesteps)
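                # Presumably x0_from_epsilon inverts the DDPM forward process under
                # the epsilon parameterization:
                #   x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)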
                pred_original_sample = reNormalize_img(pred_original_sample)
                norm_pred_ori = normalize_mean_std(pred_original_sample)
                norm_target_ori = normalize_mean_std(nonorm_target_images)
                percep_loss = perceptual_loss.calculate_loss(
                    generated_images=norm_pred_ori,
                    target_images=norm_target_ori,
                    device=target_images.device)

                loss = diff_loss + \
                    args.perceptual_coefficient * percep_loss + \
                    args.offset_coefficient * offset_loss

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Check whether the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if accelerator.is_main_process:
                    if global_step % args.ckpt_interval == 0:
                        save_dir = f"{args.output_dir}/global_step_{global_step}"
                        os.makedirs(save_dir)
                        torch.save(model.unet.state_dict(), f"{save_dir}/unet.pth")
                        torch.save(model.style_encoder.state_dict(), f"{save_dir}/style_encoder.pth")
                        torch.save(model.content_encoder.state_dict(), f"{save_dir}/content_encoder.pth")
                        torch.save(model, f"{save_dir}/total_model.pth")
                        logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}] Saved the checkpoint at global step {global_step}")
                        print("Saved the checkpoint at global step {}".format(global_step))

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            if global_step % args.log_interval == 0:
                logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))}] Global Step {global_step} => train_loss = {loss}")
            progress_bar.set_postfix(**logs)

            # Stop once the maximum number of training steps is reached
            if global_step >= args.max_train_steps:
                break

    accelerator.end_training()


if __name__ == "__main__":
    main()
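
In equation form, the objective assembled in the loop above is (notation mine: $\epsilon_\theta$ is the predicted noise, $x_s$ and $x_c$ the style and content conditions, and the coefficients map to args.perceptual_coefficient and args.offset_coefficient):

$$\mathcal{L} = \mathbb{E}_{x_0,\,\epsilon,\,t}\left[\lVert \epsilon - \epsilon_\theta(x_t, t, x_s, x_c)\rVert_2^2\right] + \lambda_{\text{percep}}\,\mathcal{L}_{\text{percep}} + \lambda_{\text{offset}}\,\mathcal{L}_{\text{offset}}$$

where $\mathcal{L}_{\text{offset}}$ is half the model's summed offset output, matching offset_loss = offset_out_sum / 2 in the code.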