diff --git a/.gitignore b/.gitignore
index b079d6e..84c9dd0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ cython_debug/
 #.idea/
 
 .DS_Store
+.aider*
diff --git a/image_datasets/openpose_dataset.py b/image_datasets/openpose_dataset.py
new file mode 100644
index 0000000..356db6f
--- /dev/null
+++ b/image_datasets/openpose_dataset.py
@@ -0,0 +1,62 @@
+import os
+import pandas as pd
+import numpy as np
+from PIL import Image
+import torch
+from torch.utils.data import Dataset, DataLoader
+import json
+import random
+
+def c_crop(image):
+    width, height = image.size
+    new_size = min(width, height)
+    left = (width - new_size) / 2
+    top = (height - new_size) / 2
+    right = (width + new_size) / 2
+    bottom = (height + new_size) / 2
+    return image.crop((left, top, right, bottom))
+
+class OpenPoseImageDataset(Dataset):
+    def __init__(self, img_dir, img_size=512):
+        self.img_dir = img_dir
+        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if ('.jpg' in i or '.png' in i) and not i.endswith('_pose.jpg') and not i.endswith('_pose.png')]
+        self.images.sort()
+        self.img_size = img_size
+
+        print('OpenPoseImageDataset: ', len(self.images))
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        try:
+            json_path = self.images[idx].split('.')[0] + '.json'
+            json_data = json.load(open(json_path))
+
+            img = Image.open(self.images[idx])
+            img = c_crop(img)
+            img = img.resize((self.img_size, self.img_size))
+            # support gray scale images as well
+            if img.mode != 'RGB':
+                img = img.convert('RGB')
+            img = torch.from_numpy((np.array(img) / 127.5) - 1)
+            img = img.permute(2, 0, 1)
+
+            hint_path = os.path.join(self.img_dir, json_data['conditioning_image'])
+            hint = Image.open(hint_path)
+            hint = c_crop(hint)
+            hint = hint.resize((self.img_size, self.img_size))
+            hint = torch.from_numpy((np.array(hint) / 127.5) - 1)
+            hint = hint.permute(2, 0, 1)
+
+            prompt = json_data['caption']
+            return img, hint, prompt
+
+        except Exception as e:
+            print(e)
+            return self.__getitem__(random.randint(0, len(self.images) - 1))
+
+
+def openpose_dataset_loader(train_batch_size, num_workers, **args):
+    dataset = OpenPoseImageDataset(**args)
+    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
diff --git a/main.py b/main.py
index a4ccc7a..7cdc9fa 100644
--- a/main.py
+++ b/main.py
@@ -131,6 +131,10 @@ def create_argparser():
     parser.add_argument(
         "--save_path", type=str, default='results', help="Path to save"
     )
+    parser.add_argument(
+        "--two_gpus_pipeline", action='store_true', default=False,
+        help="Enable two-GPU pipeline (cuda:0 and cuda:1), the transformer will be loaded on the device specified by --device"
+    )
     return parser
 
 
@@ -140,7 +144,7 @@ def main(args):
     else:
         image = None
 
-    xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload)
+    xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, two_gpus_pipeline=args.two_gpus_pipeline)
     if args.use_ip:
         print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name)
         xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name)
diff --git a/src/flux/sampling.py b/src/flux/sampling.py
index e7a97f7..37a8cdc 100644
--- a/src/flux/sampling.py
+++ b/src/flux/sampling.py
@@ -172,23 +172,40 @@ def denoise_controlnet(
     image_proj: Tensor=None,
     neg_image_proj: Tensor=None,
     ip_scale: Tensor | float = 1,
-    neg_ip_scale: Tensor | float = 1, 
+    neg_ip_scale: Tensor | float = 1,
+    controlnet_device: torch.device = "cuda:0",
+    model_device: torch.device = "cuda:0"
 ):
     # this is ignored for schnell
     i = 0
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+        # move controlnet params to controlnet's device
+        img_controlnet_device = img.to(controlnet_device)
+        img_ids_controlnet_device = img_ids.to(controlnet_device)
+        controlnet_cond_controlnet_device = controlnet_cond.to(controlnet_device)
+        txt_controlnet_device = txt.to(controlnet_device)
+        txt_ids_controlnet_device = txt_ids.to(controlnet_device)
+        vec_controlnet_device = vec.to(controlnet_device)
+        t_vec_controlnet_device = t_vec.to(controlnet_device)
+        guidance_vec_controlnet_device = guidance_vec.to(controlnet_device)
+
         block_res_samples = controlnet(
-            img=img,
-            img_ids=img_ids,
-            controlnet_cond=controlnet_cond,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
-            timesteps=t_vec,
-            guidance=guidance_vec,
+            img=img_controlnet_device,
+            img_ids=img_ids_controlnet_device,
+            controlnet_cond=controlnet_cond_controlnet_device,
+            txt=txt_controlnet_device,
+            txt_ids=txt_ids_controlnet_device,
+            y=vec_controlnet_device,
+            timesteps=t_vec_controlnet_device,
+            guidance=guidance_vec_controlnet_device,
         )
+
+        # move results back to model's device
+        block_res_samples = [i.to(model_device) for i in block_res_samples]
+
         pred = model(
             img=img,
             img_ids=img_ids,
@@ -202,16 +219,25 @@ def denoise_controlnet(
             ip_scale=ip_scale,
         )
         if i >= timestep_to_start_cfg:
+            # move negative prompt to controlnet's device
+            neg_txt_controlnet_device = neg_txt.to(controlnet_device)
+            neg_txt_ids_controlnet_device = neg_txt_ids.to(controlnet_device)
+            neg_vec_controlnet_device = neg_vec.to(controlnet_device)
+
             neg_block_res_samples = controlnet(
-                img=img,
-                img_ids=img_ids,
-                controlnet_cond=controlnet_cond,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
-                timesteps=t_vec,
-                guidance=guidance_vec,
+                img=img_controlnet_device,
+                img_ids=img_ids_controlnet_device,
+                controlnet_cond=controlnet_cond_controlnet_device,
+                txt=neg_txt_controlnet_device,
+                txt_ids=neg_txt_ids_controlnet_device,
+                y=neg_vec_controlnet_device,
+                timesteps=t_vec_controlnet_device,
+                guidance=guidance_vec_controlnet_device,
             )
+
+            # move results back to model's device
+            neg_block_res_samples = [i.to(model_device) for i in neg_block_res_samples]
+
             neg_pred = model(
                 img=img,
                 img_ids=img_ids,
diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py
index f132549..f27f5f1 100644
--- a/src/flux/xflux_pipeline.py
+++ b/src/flux/xflux_pipeline.py
@@ -31,18 +31,23 @@ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
 
 
 class XFluxPipeline:
-    def __init__(self, model_type, device, offload: bool = False):
-        self.device = torch.device(device)
+    def __init__(self, model_type, device, offload: bool = False, two_gpus_pipeline: bool = False):
+        if two_gpus_pipeline:
+            self.model_device = torch.device(device)
+            self.other_device = torch.device("cuda:0" if device == "cuda:1" else "cuda:1")
+        else:
+            self.model_device = self.other_device = torch.device(device)
+
         self.offload = offload
         self.model_type = model_type
 
-        self.clip = load_clip(self.device)
-        self.t5 = load_t5(self.device, max_length=512)
-        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+        self.clip = load_clip(self.other_device)
+        self.t5 = load_t5(self.other_device, max_length=512)
+        self.ae = load_ae(model_type, device="cpu" if offload else self.other_device)
         if "fp8" in model_type:
-            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
+            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.model_device)
         else:
-            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
+            self.model = load_flow_model(model_type, device="cpu" if offload else self.model_device)
 
         self.image_encoder_path = "openai/clip-vit-large-patch14"
         self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
@@ -53,7 +58,7 @@ def __init__(self, model_type, device, offload: bool = False):
         self.ip_loaded = False
 
     def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
-        self.model.to(self.device)
+        self.model.to(self.model_device)
 
         # unpack checkpoint
         checkpoint = load_checkpoint(local_path, repo_id, name)
@@ -69,14 +74,14 @@ def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
 
         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float16
+            self.other_device, dtype=torch.float16
         )
         self.clip_image_processor = CLIPImageProcessor()
 
         # setup image embedding projection model
         self.improj = ImageProjModel(4096, 768, 4)
         self.improj.load_state_dict(proj)
-        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)
+        self.improj = self.improj.to(self.other_device, dtype=torch.bfloat16)
 
         ip_attn_procs = {}
 
@@ -88,7 +93,7 @@ def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
             if ip_state_dict:
                 ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
                 ip_attn_procs[name].load_state_dict(ip_state_dict)
-                ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
+                ip_attn_procs[name].to(self.model_device, dtype=torch.bfloat16)
             else:
                 ip_attn_procs[name] = self.model.attn_processors[name]
 
@@ -122,7 +127,7 @@ def update_model_with_lora(self, checkpoint, lora_weight):
                 else:
                     lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
                 lora_attn_procs[name].load_state_dict(lora_state_dict)
-                lora_attn_procs[name].to(self.device)
+                lora_attn_procs[name].to(self.model_device)
             else:
                 if name.startswith("single_blocks"):
                     lora_attn_procs[name] = SingleStreamBlockProcessor()
@@ -132,12 +137,13 @@ def update_model_with_lora(self, checkpoint, lora_weight):
         self.model.set_attn_processor(lora_attn_procs)
 
     def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
-        self.model.to(self.device)
-        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
+        self.model.to(self.model_device)
+
+        self.controlnet = load_controlnet(self.model_type, self.other_device).to(torch.bfloat16)
 
         checkpoint = load_checkpoint(local_path, repo_id, name)
         self.controlnet.load_state_dict(checkpoint, strict=False)
-        self.annotator = Annotator(control_type, self.device)
+        self.annotator = Annotator(control_type, self.other_device)
         self.controlnet_loaded = True
         self.control_type = control_type
 
@@ -154,7 +160,7 @@ def get_image_proj(
         image_prompt_embeds = self.image_encoder(
             image_prompt
         ).image_embeds.to(
-            device=self.device, dtype=torch.bfloat16,
+            device=self.model_device, dtype=torch.bfloat16,
         )
         # encode image
         image_proj = self.improj(image_prompt_embeds)
@@ -196,7 +202,7 @@ def __call__(self,
             controlnet_image = self.annotator(controlnet_image, width, height)
             controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
             controlnet_image = controlnet_image.permute(
-                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)
+                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.model_device)
 
         return self.forward(
             prompt,
@@ -277,8 +283,10 @@ def forward(
         ip_scale=1.0,
         neg_ip_scale=1.0,
     ):
+        print("Starting the diffusion process...")
+
         x = get_noise(
-            1, height, width, device=self.device,
+            1, height, width, device=self.model_device,
             dtype=torch.bfloat16, seed=seed
         )
         timesteps = get_schedule(
@@ -286,16 +294,17 @@ def forward(
             (width // 8) * (height // 8) // (16 * 16),
             shift=True,
         )
+
         torch.manual_seed(seed)
         with torch.no_grad():
             if self.offload:
-                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+                self.t5, self.clip = self.t5.to(self.other_device), self.clip.to(self.other_device)
             inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
             neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
 
             if self.offload:
                 self.offload_model_to_cpu(self.t5, self.clip)
-                self.model = self.model.to(self.device)
+                self.model = self.model.to(self.model_device)
             if self.controlnet_loaded:
                 x = denoise_controlnet(
                     self.model,
@@ -314,6 +323,8 @@ def forward(
                     neg_image_proj=neg_image_proj,
                     ip_scale=ip_scale,
                     neg_ip_scale=neg_ip_scale,
+                    controlnet_device=self.other_device,
+                    model_device=self.model_device,
                 )
             else:
                 x = denoise(
@@ -334,9 +345,9 @@ def forward(
 
         if self.offload:
             self.offload_model_to_cpu(self.model)
-            self.ae.decoder.to(x.device)
+            self.ae.decoder.to(self.other_device)
         x = unpack(x.float(), height, width)
-        x = self.ae.decode(x)
+        x = self.ae.decode(x.to(self.other_device))
         self.offload_model_to_cpu(self.ae.decoder)
 
         x1 = x.clamp(-1, 1)
diff --git a/train_configs/test_openpose_controlnet.yaml b/train_configs/test_openpose_controlnet.yaml
new file mode 100644
index 0000000..2d8b4bb
--- /dev/null
+++ b/train_configs/test_openpose_controlnet.yaml
@@ -0,0 +1,26 @@
+model_name: "flux-dev"
+is_openpose: true
+data_config:
+  train_batch_size: 2
+  num_workers: 2
+  img_size: 512
+  img_dir: images/
+report_to: wandb
+train_batch_size: 2
+output_dir: saves_openpose/
+max_train_steps: 100000
+learning_rate: 2e-5
+lr_scheduler: constant
+lr_warmup_steps: 10
+adam_beta1: 0.9
+adam_beta2: 0.999
+adam_weight_decay: 0.01
+adam_epsilon: 1e-8
+max_grad_norm: 1.0
+logging_dir: logs
+mixed_precision: "bf16"
+checkpointing_steps: 2500
+checkpoints_total_limit: 50
+tracker_project_name: openpose_training
+resume_from_checkpoint: latest
+gradient_accumulation_steps: 2
diff --git a/train_flux_deepspeed_controlnet.py b/train_flux_deepspeed_controlnet.py
index fa8cc69..44b5d79 100644
--- a/train_flux_deepspeed_controlnet.py
+++ b/train_flux_deepspeed_controlnet.py
@@ -38,6 +38,7 @@ from src.flux.util import (configs, load_ae, load_clip,
                            load_flow_model2, load_controlnet, load_t5)
 from image_datasets.canny_dataset import loader
+from image_datasets.openpose_dataset import openpose_dataset_loader
 if is_wandb_available():
     import wandb
 logger = get_logger(__name__, log_level="INFO")
@@ -122,7 +123,11 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    train_dataloader = loader(**args.data_config)
+    if args.is_openpose:
+        train_dataloader = openpose_dataset_loader(**args.data_config)
+    else:
+        train_dataloader = loader(**args.data_config)
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -219,7 +224,7 @@ def main():
                 t = torch.sigmoid(torch.randn((bs,), device=accelerator.device))
                 x_0 = torch.randn_like(x_1).to(accelerator.device)
-                print(t.shape, x_1.shape, x_0.shape)
+                # print(t.shape, x_1.shape, x_0.shape)
                 x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2]) * x_0
                 bsz = x_1.shape[0]
                 guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype)
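
For reference, a minimal sketch of the data layout that OpenPoseImageDataset reads, inferred from the code in the patch and not part of the patch itself: every foo.jpg or foo.png in img_dir needs a sibling foo.json whose "caption" field holds the text prompt and whose "conditioning_image" field names the pose map inside the same directory (files ending in _pose.jpg or _pose.png are excluded from the main image list). The file names below are hypothetical, and the pose map is assumed to be saved as an RGB image.

# Hypothetical example of preparing one training sample for OpenPoseImageDataset.
# "person_001.jpg" is the target photo and "person_001_pose.jpg" its OpenPose
# rendering; both sit in the img_dir configured in the YAML above ("images/").
import json

annotation = {
    "caption": "a person standing in a park",      # read via json_data['caption']
    "conditioning_image": "person_001_pose.jpg",   # read via json_data['conditioning_image']
}
# The dataset resolves images/person_001.jpg -> images/person_001.json
with open("images/person_001.json", "w") as f:
    json.dump(annotation, f)

# With that layout in place, the loader added in the diff can be built roughly as:
#   from image_datasets.openpose_dataset import openpose_dataset_loader
#   dl = openpose_dataset_loader(train_batch_size=2, num_workers=2,
#                                img_dir="images/", img_size=512)
# Each batch yields (img, hint, prompt), where img and hint are tensors of shape
# (batch, 3, 512, 512) scaled to [-1, 1].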