Skip to content

Commit

Permalink
[Training] Add training script
Browse files Browse the repository at this point in the history
  • Loading branch information
yeungchenwa committed Dec 10, 2023
1 parent e25e693 commit 2cca9d0
Show file tree
Hide file tree
Showing 19 changed files with 308 additions and 2 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Initially taken from GitHub's Python gitignore file
outputs/
run_sh/
dataset/
train.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
Binary file added data_examples/train/ContentImage/氮.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data_examples/train/ContentImage/潮.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data_examples/train/ContentImage/舶.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data_examples/train/ContentImage/镀.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
69 changes: 69 additions & 0 deletions dataset/font_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import random
from PIL import Image

from torch.utils.data import Dataset
import torchvision.transforms as transforms

def get_nonorm_transform(resolution):
    """Return a transform that resizes an image to ``(resolution, resolution)``
    (bilinear) and converts it to a tensor, without any normalization.
    """
    resize = transforms.Resize(
        (resolution, resolution),
        interpolation=transforms.InterpolationMode.BILINEAR)
    return transforms.Compose([resize, transforms.ToTensor()])


class FontDataset(Dataset):
    """Dataset for font generation.

    Expects the following layout under ``args.data_root``:

        <data_root>/<phase>/TargetImage/<style>/<style>+<content>.jpg
        <data_root>/<phase>/ContentImage/<content>.jpg

    Each item yields a (content, style, target) image triple, where the style
    reference is randomly sampled from the *other* target images that share
    the same style.
    """
    def __init__(self, args, phase, transforms=None):
        """
        Args:
            args: namespace providing ``data_root`` and ``resolution``.
            phase: dataset split name, e.g. ``'train'``.
            transforms: optional list of three transforms applied to the
                content, style and target images respectively.
        """
        super().__init__()
        self.root = args.data_root
        self.phase = phase

        # Index all target images and group them by style.
        self.get_path()
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        """Collect all target image paths and build the style -> images map."""
        self.target_images = []
        # Maps a style name to every target image of that style.
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        for style in os.listdir(target_image_dir):
            images_related_style = []
            for img in os.listdir(f"{target_image_dir}/{style}"):
                img_path = f"{target_image_dir}/{style}/{img}"
                self.target_images.append(img_path)
                images_related_style.append(img_path)
            self.style_to_images[style] = images_related_style

    def __getitem__(self, index):
        target_image_path = self.target_images[index]
        target_image_name = os.path.basename(target_image_path)
        # File names follow the "<style>+<content>.jpg" convention.
        style, content = target_image_name.split('.')[0].split('+')

        # Read content image
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = Image.open(content_image_path).convert('RGB')

        # Randomly sample a style reference that differs from the target.
        # Fall back to the target itself when it is the only image of its
        # style (the previous code raised IndexError from random.choice on
        # an empty list in that case).
        images_related_style = self.style_to_images[style].copy()
        images_related_style.remove(target_image_path)
        if images_related_style:
            style_image_path = random.choice(images_related_style)
        else:
            style_image_path = target_image_path
        style_image = Image.open(style_image_path).convert("RGB")

        # Read target image (keep an un-normalized copy for perceptual loss)
        target_image = Image.open(target_image_path).convert("RGB")
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path

    def __len__(self):
        return len(self.target_images)
239 changes: 239 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import os
import math
import time
import logging
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler

from dataset.font_dataset import FontDataset
from configs.fontdiffuser import get_parser
from src import (FontDiffuserModel,
ContentPerceptualLoss,
build_unet,
build_style_encoder,
build_content_encoder,
build_ddpm_scheduler)
from utils import (save_args_to_yaml,
x0_from_epsilon,
reNormalize_img,
normalize_mean_std)


logger = get_logger(__name__)

def get_args():
    """Parse command-line arguments, preferring the LOCAL_RANK env var.

    Distributed launchers export ``LOCAL_RANK``; when it is set and disagrees
    with the parsed ``--local_rank`` flag, the environment value wins.
    """
    args = get_parser().parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank not in (-1, args.local_rank):
        args.local_rank = env_local_rank
    return args


def main() -> None:
    """Train FontDiffuser under HuggingFace Accelerate.

    Builds the UNet/style/content-encoder model, a DDPM noise scheduler and
    the font dataset, then runs the diffusion training loop with three loss
    terms (diffusion MSE, offset, content-perceptual) plus classifier-free
    conditioning dropout. Checkpoints and logs go to ``args.output_dir``.
    """

    args = get_args()

    logging_dir = f"{args.output_dir}/{args.logging_dir}"

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): non-main processes may reach this before the main process
    # creates output_dir — logging to a file inside it presumably relies on
    # the directory already existing; confirm for multi-node runs.
    logging.basicConfig(
        filename=f"{args.output_dir}/fontdiffuser_training.log",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)

    # Set training seed
    if args.seed is not None:
        set_seed(args.seed)

    # Load model and noise_scheduler
    unet = build_unet(args=args)
    style_encoder = build_style_encoder(args=args)
    content_encoder = build_content_encoder(args=args)
    noise_scheduler = build_ddpm_scheduler(args)

    model = FontDiffuserModel(
        unet=unet,
        style_encoder=style_encoder,
        content_encoder=content_encoder)

    # Build the content perceptual loss
    perceptual_loss = ContentPerceptualLoss()

    # Load the datasets: each image kind gets its own resize + [-1, 1]
    # normalization pipeline (Normalize([0.5], [0.5]) maps [0, 1] -> [-1, 1]).
    content_transforms = transforms.Compose(
        [
            transforms.Resize(args.content_image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    style_transforms = transforms.Compose(
        [
            transforms.Resize(args.style_image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    target_transforms = transforms.Compose(
        [
            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    train_font_dataset = FontDataset(
        args=args,
        phase='train',
        transforms=[
            content_transforms,
            style_transforms,
            target_transforms])
    train_dataloader = torch.utils.data.DataLoader(
        train_font_dataset, shuffle=True, batch_size=args.train_batch_size)

    # Build optimizer and learning rate
    if args.scale_lr:
        # Scale LR with the effective (global) batch size.
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon)
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,)

    # Accelerate preparation (wraps for DDP / mixed precision / device moves)
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler)

    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("fontdiffuser_training")
        save_args_to_yaml(args=args, output_file=f"{args.output_dir}/{args.experience_name}_config.yaml")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Convert to the training epoch
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    global_step = 0
    for epoch in range(num_train_epochs):
        train_loss = 0.0
        for step, (content_images, style_images, target_images, nonorm_target_images, target_image_paths) in enumerate(train_dataloader):
            model.train()
            with accelerator.accumulate(model):
                # Sample noise that we'll add to the samples
                noise = torch.randn_like(target_images)
                bsz = target_images.shape[0]
                # Sample a random timestep for each image
                # NOTE(review): newer diffusers versions expose this as
                # noise_scheduler.config.num_train_timesteps — the bare
                # attribute access is deprecated; confirm installed version.
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=target_images.device)
                timesteps = timesteps.long()

                # Add noise to the target_images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_target_images = noise_scheduler.add_noise(target_images, noise, timesteps)

                # Classifier-free training strategy: with probability
                # args.drop_prob, replace a sample's content and style images
                # with all-ones so the model also learns the unconditional case.
                context_mask = torch.bernoulli(torch.zeros(bsz) + args.drop_prob)
                for i, mask_value in enumerate(context_mask):
                    if mask_value==1:
                        content_images[i, :, :, :] = 1
                        style_images[i, :, :, :] = 1

                # Predict the noise residual and compute loss
                noise_pred, offset_out_sum = model(
                    x_t=noisy_target_images,
                    timesteps=timesteps,
                    style_images=style_images,
                    content_images=content_images,
                    content_encoder_downsample_size=args.content_encoder_downsample_size)
                diff_loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                # NOTE(review): halving presumably averages two offset outputs
                # summed inside the model — confirm against FontDiffuserModel.
                offset_loss = offset_out_sum / 2

                # Output processing for content perceptual loss: reconstruct
                # x0 from the predicted noise, then renormalize both predicted
                # and target images before the perceptual comparison.
                pred_original_sample = x0_from_epsilon(
                    scheduler=noise_scheduler,
                    noise_pred=noise_pred,
                    x_t=noisy_target_images,
                    timesteps=timesteps)
                pred_original_sample = reNormalize_img(pred_original_sample)
                norm_pred_ori = normalize_mean_std(pred_original_sample)
                norm_target_ori = normalize_mean_std(nonorm_target_images)
                percep_loss = perceptual_loss.calculate_loss(
                    generated_images=norm_pred_ori,
                    target_images=norm_target_ori,
                    device=target_images.device)

                # Total loss: diffusion MSE + weighted perceptual + weighted offset.
                loss = diff_loss + \
                    args.perceptual_coefficient * percep_loss + \
                    args.offset_coefficient * offset_loss

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if accelerator.is_main_process:
                    if global_step % args.ckpt_interval == 0:
                        # NOTE(review): os.makedirs without exist_ok=True will
                        # raise if this step's directory already exists (e.g.
                        # after a restart) — confirm that is intended.
                        save_dir = f"{args.output_dir}/global_step_{global_step}"
                        os.makedirs(save_dir)
                        torch.save(model.unet.state_dict(), f"{save_dir}/unet.pth")
                        torch.save(model.style_encoder.state_dict(), f"{save_dir}/style_encoder.pth")
                        torch.save(model.content_encoder.state_dict(), f"{save_dir}/content_encoder.pth")
                        torch.save(model, f"{save_dir}/total_model.pth")
                        logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))}] Save the checkpoint on global step {global_step}")
                        print("Save the checkpoint on global step {}".format(global_step))

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            if global_step % args.log_interval == 0:
                logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))}] Global Step {global_step} => train_loss = {loss}")
            progress_bar.set_postfix(**logs)

            # Quit
            # NOTE(review): this break only exits the inner (step) loop; the
            # epoch loop re-enters and may run a few extra steps past
            # max_train_steps before num_train_epochs is exhausted — confirm.
            if global_step >= args.max_train_steps:
                break

    accelerator.end_training()

# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()

0 comments on commit 2cca9d0

Please sign in to comment.