From f3c7bf2b7e5b8a6e548c808a6de1fd45b30021b5 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Tue, 10 Sep 2024 11:34:17 +0200 Subject: [PATCH 01/23] controlnet on cuda:1 --- src/flux/xflux_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index fcee1f2..1d57f8c 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -133,11 +133,11 @@ def update_model_with_lora(self, checkpoint, lora_weight): def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): self.model.to(self.device) - self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) + self.controlnet = load_controlnet(self.model_type, "cuda:1").to(torch.bfloat16) checkpoint = load_checkpoint(local_path, repo_id, name) self.controlnet.load_state_dict(checkpoint, strict=False) - self.annotator = Annotator(control_type, self.device) + self.annotator = Annotator(control_type, "cuda:1") self.controlnet_loaded = True self.control_type = control_type From 4ec28f5c8273e83782b71f1e3e0791f287ec7b6c Mon Sep 17 00:00:00 2001 From: raci0399 Date: Tue, 10 Sep 2024 12:01:33 +0200 Subject: [PATCH 02/23] controlnet on cuda1 --- .gitignore | 1 + src/flux/sampling.py | 56 +++++++++++++++++++++++++++----------- src/flux/xflux_pipeline.py | 4 ++- 3 files changed, 44 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index b079d6e..84c9dd0 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,4 @@ cython_debug/ #.idea/ .DS_Store +.aider* diff --git a/src/flux/sampling.py b/src/flux/sampling.py index e7a97f7..f106697 100644 --- a/src/flux/sampling.py +++ b/src/flux/sampling.py @@ -179,16 +179,31 @@ def denoise_controlnet( guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # move controlnet params to cuda:1 + img_cuda1 = img.to('cuda:1') + img_ids_cuda1 = img_ids.to('cuda:1') + controlnet_cond_cuda1 = controlnet_cond.to('cuda:1') + txt_cuda1 = txt.to('cuda:1') + txt_ids_cuda1 = txt_ids.to('cuda:1') + vec_cuda1 = vec.to('cuda:1') + t_vec_cuda1 = t_vec.to('cuda:1') + guidance_vec_cuda1 = guidance_vec.to('cuda:1') + block_res_samples = controlnet( - img=img, - img_ids=img_ids, - controlnet_cond=controlnet_cond, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t_vec, - guidance=guidance_vec, + img=img_cuda1, + img_ids=img_ids_cuda1, + controlnet_cond=controlnet_cond_cuda1, + txt=txt_cuda1, + txt_ids=txt_ids_cuda1, + y=vec_cuda1, + timesteps=t_vec_cuda1, + guidance=guidance_vec_cuda1, ) + + # move results back to cuda:0 + block_res_samples = block_res_samples.to('cuda:0') + pred = model( img=img, img_ids=img_ids, @@ -202,16 +217,25 @@ def denoise_controlnet( ip_scale=ip_scale, ) if i >= timestep_to_start_cfg: + # move negative prompt to cuda:1 + neg_txt_cuda1 = neg_txt.to('cuda:1') + neg_txt_ids_cuda1 = neg_txt_ids.to('cuda:1') + neg_vec_cuda1 = neg_vec.to('cuda:1') + neg_block_res_samples = controlnet( - img=img, - img_ids=img_ids, - controlnet_cond=controlnet_cond, - txt=neg_txt, - txt_ids=neg_txt_ids, - y=neg_vec, - timesteps=t_vec, - guidance=guidance_vec, + img=img_cuda1, + img_ids=img_ids_cuda1, + controlnet_cond=controlnet_cond_cuda1, + txt=neg_txt_cuda1, + txt_ids=neg_txt_ids_cuda1, + y=neg_vec_cuda1, + timesteps=t_vec_cuda1, + guidance=guidance_vec_cuda1, ) + + # move results back 
to cuda:0 + neg_block_res_samples = neg_block_res_samples.to('cuda:0') + neg_pred = model( img=img, img_ids=img_ids, diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 1d57f8c..508f49e 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -132,7 +132,9 @@ def update_model_with_lora(self, checkpoint, lora_weight): self.model.set_attn_processor(lora_attn_procs) def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): - self.model.to(self.device) + # the model is already on device or will be moved below in the if self.offload block + # self.model.to(self.device) + self.controlnet = load_controlnet(self.model_type, "cuda:1").to(torch.bfloat16) checkpoint = load_checkpoint(local_path, repo_id, name) From bef9d83d053e6a61111745d96dffca0a14896949 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Tue, 10 Sep 2024 12:07:45 +0200 Subject: [PATCH 03/23] move each entry --- src/flux/sampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/flux/sampling.py b/src/flux/sampling.py index f106697..434d39e 100644 --- a/src/flux/sampling.py +++ b/src/flux/sampling.py @@ -202,7 +202,7 @@ def denoise_controlnet( ) # move results back to cuda:0 - block_res_samples = block_res_samples.to('cuda:0') + block_res_samples = [i.to('cuda:0') for i in block_res_samples] pred = model( img=img, @@ -234,8 +234,8 @@ def denoise_controlnet( ) # move results back to cuda:0 - neg_block_res_samples = neg_block_res_samples.to('cuda:0') - + neg_block_res_samples = [i.to('cuda:0') for i in neg_block_res_samples] + neg_pred = model( img=img, img_ids=img_ids, From 1a830da2d320f5ff1e89f59171fd6cd76e577482 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Fri, 13 Sep 2024 14:33:22 +0200 Subject: [PATCH 04/23] use cuda:1 --- src/flux/xflux_pipeline.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 508f49e..db46887 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -33,12 +33,13 @@ class XFluxPipeline: def __init__(self, model_type, device, offload: bool = False): self.device = torch.device(device) + self.device1 = torch.device("cuda:1") self.offload = offload self.model_type = model_type - self.clip = load_clip(self.device) - self.t5 = load_t5(self.device, max_length=512) - self.ae = load_ae(model_type, device="cpu" if offload else self.device) + self.clip = load_clip(self.device1) + self.t5 = load_t5(self.device1, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.device1) if "fp8" in model_type: self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) else: From 0a21852927ba0ef95239cad701a0c093337b95b4 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Fri, 13 Sep 2024 15:06:49 +0200 Subject: [PATCH 05/23] multigpu --- src/flux/xflux_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index db46887..52bccd1 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -339,7 +339,7 @@ def forward( self.offload_model_to_cpu(self.model) self.ae.decoder.to(x.device) x = unpack(x.float(), height, width) - x = self.ae.decode(x) + x = self.ae.decode(x.to(self.device1)) self.offload_model_to_cpu(self.ae.decoder) x1 = x.clamp(-1, 1) From a736257ec8bb67d5d0874abf149a74605fa7809b Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 17:39:03 +0200 
Subject: [PATCH 06/23] multigpu params --- main.py | 6 +++++- src/flux/xflux_pipeline.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index a4ccc7a..6e50452 100644 --- a/main.py +++ b/main.py @@ -131,6 +131,10 @@ def create_argparser(): parser.add_argument( "--save_path", type=str, default='results', help="Path to save" ) + parser.add_argument( + "--multi_gpu", action='store_true', default=False, + help="Enable multi-GPU support, the transformer will be loaded on the device specified by --device" + ) return parser @@ -140,7 +144,7 @@ def main(args): else: image = None - xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload) + xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, multi_gpu=args.multi_gpu) if args.use_ip: print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name) xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 52bccd1..bba60f0 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -31,7 +31,7 @@ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor class XFluxPipeline: - def __init__(self, model_type, device, offload: bool = False): + def __init__(self, model_type, device, offload: bool = False, multi_gpu: bool = False): self.device = torch.device(device) self.device1 = torch.device("cuda:1") self.offload = offload From 643542a218eebc471ffd0f509dcc8dde3fc68d94 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 18:09:18 +0200 Subject: [PATCH 07/23] renamed 2 gpu pipeline --- main.py | 6 +++--- src/flux/xflux_pipeline.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 6e50452..7cdc9fa 100644 --- a/main.py +++ b/main.py @@ -132,8 +132,8 @@ def create_argparser(): "--save_path", type=str, default='results', help="Path to save" ) parser.add_argument( - "--multi_gpu", action='store_true', default=False, - help="Enable multi-GPU support, the transformer will be loaded on the device specified by --device" + "--two_gpus_pipeline", action='store_true', default=False, + help="Enable two-GPU pipeline (cuda:0 and cuda:1), the transformer will be loaded on the device specified by --device" ) return parser @@ -144,7 +144,7 @@ def main(args): else: image = None - xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, multi_gpu=args.multi_gpu) + xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, two_gpus_pipeline=args.two_gpus_pipeline) if args.use_ip: print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name) xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index bba60f0..1c4de82 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -31,7 +31,7 @@ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor class XFluxPipeline: - def __init__(self, model_type, device, offload: bool = False, multi_gpu: bool = False): + def __init__(self, model_type, device, offload: bool = False, two_gpus_pipeline: bool = False): self.device = torch.device(device) self.device1 = torch.device("cuda:1") self.offload = offload From 31534191faf5b0115a3f370918f556f203d18266 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 18:38:11 +0200 Subject: [PATCH 08/23] add support for two gpus pipeline --- 
src/flux/xflux_pipeline.py | 47 ++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 1c4de82..3a5528d 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -32,18 +32,22 @@ class XFluxPipeline: def __init__(self, model_type, device, offload: bool = False, two_gpus_pipeline: bool = False): - self.device = torch.device(device) - self.device1 = torch.device("cuda:1") + if two_gpus_pipeline: + self.model_device = torch.device(device) + self.other_device = torch.device("cuda:0" if device == "cuda:1" else "cuda:1") + else: + self.model_device = self.other_device = torch.device(device) + self.offload = offload self.model_type = model_type - self.clip = load_clip(self.device1) - self.t5 = load_t5(self.device1, max_length=512) - self.ae = load_ae(model_type, device="cpu" if offload else self.device1) + self.clip = load_clip(self.other_device) + self.t5 = load_t5(self.other_device, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.other_device) if "fp8" in model_type: - self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) + self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.model_device) else: - self.model = load_flow_model(model_type, device="cpu" if offload else self.device) + self.model = load_flow_model(model_type, device="cpu" if offload else self.model_device) self.image_encoder_path = "openai/clip-vit-large-patch14" self.hf_lora_collection = "XLabs-AI/flux-lora-collection" @@ -54,7 +58,7 @@ def __init__(self, model_type, device, offload: bool = False, two_gpus_pipeline: self.ip_loaded = False def set_ip(self, local_path: str = None, repo_id = None, name: str = None): - self.model.to(self.device) + self.model.to(self.model_device) # unpack checkpoint checkpoint = load_checkpoint(local_path, repo_id, name) @@ -70,14 +74,14 @@ def set_ip(self, local_path: str = None, repo_id = None, name: str = None): # load image encoder self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( - self.device, dtype=torch.float16 + self.other_device, dtype=torch.float16 ) self.clip_image_processor = CLIPImageProcessor() # setup image embedding projection model self.improj = ImageProjModel(4096, 768, 4) self.improj.load_state_dict(proj) - self.improj = self.improj.to(self.device, dtype=torch.bfloat16) + self.improj = self.improj.to(self.other_device, dtype=torch.bfloat16) ip_attn_procs = {} @@ -89,7 +93,7 @@ def set_ip(self, local_path: str = None, repo_id = None, name: str = None): if ip_state_dict: ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) ip_attn_procs[name].load_state_dict(ip_state_dict) - ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) + ip_attn_procs[name].to(self.model_device, dtype=torch.bfloat16) else: ip_attn_procs[name] = self.model.attn_processors[name] @@ -123,7 +127,7 @@ def update_model_with_lora(self, checkpoint, lora_weight): else: lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) lora_attn_procs[name].load_state_dict(lora_state_dict) - lora_attn_procs[name].to(self.device) + lora_attn_procs[name].to(self.model_device) else: if name.startswith("single_blocks"): lora_attn_procs[name] = SingleStreamBlockProcessor() @@ -133,10 +137,9 @@ def update_model_with_lora(self, checkpoint, lora_weight): self.model.set_attn_processor(lora_attn_procs) def set_controlnet(self, 
control_type: str, local_path: str = None, repo_id: str = None, name: str = None): - # the model is already on device or will be moved below in the if self.offload block - # self.model.to(self.device) + self.model.to(self.model_device) - self.controlnet = load_controlnet(self.model_type, "cuda:1").to(torch.bfloat16) + self.controlnet = load_controlnet(self.model_type, self.other_device).to(torch.bfloat16) checkpoint = load_checkpoint(local_path, repo_id, name) self.controlnet.load_state_dict(checkpoint, strict=False) @@ -157,7 +160,7 @@ def get_image_proj( image_prompt_embeds = self.image_encoder( image_prompt ).image_embeds.to( - device=self.device, dtype=torch.bfloat16, + device=self.model_device, dtype=torch.bfloat16, ) # encode image image_proj = self.improj(image_prompt_embeds) @@ -199,7 +202,7 @@ def __call__(self, controlnet_image = self.annotator(controlnet_image, width, height) controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) controlnet_image = controlnet_image.permute( - 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.model_device) return self.forward( prompt, @@ -281,7 +284,7 @@ def forward( neg_ip_scale=1.0, ): x = get_noise( - 1, height, width, device=self.device, + 1, height, width, model_device=self.model_device, dtype=torch.bfloat16, seed=seed ) timesteps = get_schedule( @@ -292,13 +295,13 @@ def forward( torch.manual_seed(seed) with torch.no_grad(): if self.offload: - self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + self.t5, self.clip = self.t5.to(self.other_device), self.clip.to(self.other_device) inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt) neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if self.offload: self.offload_model_to_cpu(self.t5, self.clip) - self.model = self.model.to(self.device) + self.model = self.model.to(self.model_device) if self.controlnet_loaded: x = denoise_controlnet( self.model, @@ -337,9 +340,9 @@ def forward( if self.offload: self.offload_model_to_cpu(self.model) - self.ae.decoder.to(x.device) + self.ae.decoder.to(self.other_device) x = unpack(x.float(), height, width) - x = self.ae.decode(x.to(self.device1)) + x = self.ae.decode(x.to(self.other_device)) self.offload_model_to_cpu(self.ae.decoder) x1 = x.clamp(-1, 1) From 1e4cfaf13040ee313a2768b0554c0c9a56245a64 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 18:51:34 +0200 Subject: [PATCH 09/23] move results back to model's device --- src/flux/sampling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/flux/sampling.py b/src/flux/sampling.py index 434d39e..bf77ef3 100644 --- a/src/flux/sampling.py +++ b/src/flux/sampling.py @@ -201,8 +201,8 @@ def denoise_controlnet( guidance=guidance_vec_cuda1, ) - # move results back to cuda:0 - block_res_samples = [i.to('cuda:0') for i in block_res_samples] + # move results back to model's device + block_res_samples = [i.to(model.device) for i in block_res_samples] pred = model( img=img, @@ -233,8 +233,8 @@ def denoise_controlnet( guidance=guidance_vec_cuda1, ) - # move results back to cuda:0 - neg_block_res_samples = [i.to('cuda:0') for i in neg_block_res_samples] + # move results back to model's device + neg_block_res_samples = [i.to(model.device) for i in neg_block_res_samples] neg_pred = model( img=img, From 20bd5315b10599a9ca1eaa3353439d56bf0f35d1 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 18:54:04 +0200 Subject: [PATCH 10/23] 
move controlnet args to controlnet's device --- src/flux/sampling.py | 58 ++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/flux/sampling.py b/src/flux/sampling.py index bf77ef3..4957502 100644 --- a/src/flux/sampling.py +++ b/src/flux/sampling.py @@ -180,25 +180,25 @@ def denoise_controlnet( for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - # move controlnet params to cuda:1 - img_cuda1 = img.to('cuda:1') - img_ids_cuda1 = img_ids.to('cuda:1') - controlnet_cond_cuda1 = controlnet_cond.to('cuda:1') - txt_cuda1 = txt.to('cuda:1') - txt_ids_cuda1 = txt_ids.to('cuda:1') - vec_cuda1 = vec.to('cuda:1') - t_vec_cuda1 = t_vec.to('cuda:1') - guidance_vec_cuda1 = guidance_vec.to('cuda:1') + # move controlnet params to controlnet's device + img_controlnet_device = img.to(controlnet.device) + img_ids_controlnet_device = img_ids.to(controlnet.device) + controlnet_cond_controlnet_device = controlnet_cond.to(controlnet.device) + txt_controlnet_device = txt.to(controlnet.device) + txt_ids_controlnet_device = txt_ids.to(controlnet.device) + vec_controlnet_device = vec.to(controlnet.device) + t_vec_controlnet_device = t_vec.to(controlnet.device) + guidance_vec_controlnet_device = guidance_vec.to(controlnet.device) block_res_samples = controlnet( - img=img_cuda1, - img_ids=img_ids_cuda1, - controlnet_cond=controlnet_cond_cuda1, - txt=txt_cuda1, - txt_ids=txt_ids_cuda1, - y=vec_cuda1, - timesteps=t_vec_cuda1, - guidance=guidance_vec_cuda1, + img=img_controlnet_device, + img_ids=img_ids_controlnet_device, + controlnet_cond=controlnet_cond_controlnet_device, + txt=txt_controlnet_device, + txt_ids=txt_ids_controlnet_device, + y=vec_controlnet_device, + timesteps=t_vec_controlnet_device, + guidance=guidance_vec_controlnet_device, ) # move results back to model's device @@ -217,20 +217,20 @@ def denoise_controlnet( ip_scale=ip_scale, ) if i >= timestep_to_start_cfg: - # move negative prompt to cuda:1 - neg_txt_cuda1 = neg_txt.to('cuda:1') - neg_txt_ids_cuda1 = neg_txt_ids.to('cuda:1') - neg_vec_cuda1 = neg_vec.to('cuda:1') + # move negative prompt to controlnet's device + neg_txt_controlnet_device = neg_txt.to(controlnet.device) + neg_txt_ids_controlnet_device = neg_txt_ids.to(controlnet.device) + neg_vec_controlnet_device = neg_vec.to(controlnet.device) neg_block_res_samples = controlnet( - img=img_cuda1, - img_ids=img_ids_cuda1, - controlnet_cond=controlnet_cond_cuda1, - txt=neg_txt_cuda1, - txt_ids=neg_txt_ids_cuda1, - y=neg_vec_cuda1, - timesteps=t_vec_cuda1, - guidance=guidance_vec_cuda1, + img=img_controlnet_device, + img_ids=img_ids_controlnet_device, + controlnet_cond=controlnet_cond_controlnet_device, + txt=neg_txt_controlnet_device, + txt_ids=neg_txt_ids_controlnet_device, + y=neg_vec_controlnet_device, + timesteps=t_vec_controlnet_device, + guidance=guidance_vec_controlnet_device, ) # move results back to model's device From a8b60498cc4cc8d29c9ca7491df22e83a34be7d0 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 22:57:40 +0200 Subject: [PATCH 11/23] fix param --- src/flux/xflux_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 3a5528d..ad5af62 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -284,7 +284,7 @@ def forward( neg_ip_scale=1.0, ): x = get_noise( - 1, height, width, model_device=self.model_device, + 1, height, 
width, device=self.model_device, dtype=torch.bfloat16, seed=seed ) timesteps = get_schedule( From e08ab95e4a00e04fac13ccc13263ac3de2be1160 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 23:03:15 +0200 Subject: [PATCH 12/23] info when the diffusion process has started --- src/flux/xflux_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index ad5af62..5b0ef5a 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -283,6 +283,8 @@ def forward( ip_scale=1.0, neg_ip_scale=1.0, ): + print("Starting the diffusion process...") + x = get_noise( 1, height, width, device=self.model_device, dtype=torch.bfloat16, seed=seed @@ -292,6 +294,9 @@ def forward( (width // 8) * (height // 8) // (16 * 16), shift=True, ) + + print("the schedule is:", timesteps) + torch.manual_seed(seed) with torch.no_grad(): if self.offload: From ffb5a6052cb0561654cb4b693519c39567f5b40a Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 23:06:33 +0200 Subject: [PATCH 13/23] rm print --- src/flux/xflux_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 5b0ef5a..1fa6d32 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -295,8 +295,6 @@ def forward( shift=True, ) - print("the schedule is:", timesteps) - torch.manual_seed(seed) with torch.no_grad(): if self.offload: From e81aa053553f19bb312c14d512ba1181e00cfcb7 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sat, 14 Sep 2024 23:38:03 +0200 Subject: [PATCH 14/23] devices as params --- src/flux/sampling.py | 30 ++++++++++++++++-------------- src/flux/xflux_pipeline.py | 2 ++ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/flux/sampling.py b/src/flux/sampling.py index 4957502..37a8cdc 100644 --- a/src/flux/sampling.py +++ b/src/flux/sampling.py @@ -172,7 +172,9 @@ def denoise_controlnet( image_proj: Tensor=None, neg_image_proj: Tensor=None, ip_scale: Tensor | float = 1, - neg_ip_scale: Tensor | float = 1, + neg_ip_scale: Tensor | float = 1, + controlnet_device: torch.device = "cuda:0", + model_device: torch.device = "cuda:0" ): # this is ignored for schnell i = 0 @@ -181,14 +183,14 @@ def denoise_controlnet( t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # move controlnet params to controlnet's device - img_controlnet_device = img.to(controlnet.device) - img_ids_controlnet_device = img_ids.to(controlnet.device) - controlnet_cond_controlnet_device = controlnet_cond.to(controlnet.device) - txt_controlnet_device = txt.to(controlnet.device) - txt_ids_controlnet_device = txt_ids.to(controlnet.device) - vec_controlnet_device = vec.to(controlnet.device) - t_vec_controlnet_device = t_vec.to(controlnet.device) - guidance_vec_controlnet_device = guidance_vec.to(controlnet.device) + img_controlnet_device = img.to(controlnet_device) + img_ids_controlnet_device = img_ids.to(controlnet_device) + controlnet_cond_controlnet_device = controlnet_cond.to(controlnet_device) + txt_controlnet_device = txt.to(controlnet_device) + txt_ids_controlnet_device = txt_ids.to(controlnet_device) + vec_controlnet_device = vec.to(controlnet_device) + t_vec_controlnet_device = t_vec.to(controlnet_device) + guidance_vec_controlnet_device = guidance_vec.to(controlnet_device) block_res_samples = controlnet( img=img_controlnet_device, @@ -202,7 +204,7 @@ def denoise_controlnet( ) # move results back to model's device - block_res_samples = 
[i.to(model.device) for i in block_res_samples] + block_res_samples = [i.to(model_device) for i in block_res_samples] pred = model( img=img, @@ -218,9 +220,9 @@ def denoise_controlnet( ) if i >= timestep_to_start_cfg: # move negative prompt to controlnet's device - neg_txt_controlnet_device = neg_txt.to(controlnet.device) - neg_txt_ids_controlnet_device = neg_txt_ids.to(controlnet.device) - neg_vec_controlnet_device = neg_vec.to(controlnet.device) + neg_txt_controlnet_device = neg_txt.to(controlnet_device) + neg_txt_ids_controlnet_device = neg_txt_ids.to(controlnet_device) + neg_vec_controlnet_device = neg_vec.to(controlnet_device) neg_block_res_samples = controlnet( img=img_controlnet_device, @@ -234,7 +236,7 @@ def denoise_controlnet( ) # move results back to model's device - neg_block_res_samples = [i.to(model.device) for i in neg_block_res_samples] + neg_block_res_samples = [i.to(model_device) for i in neg_block_res_samples] neg_pred = model( img=img, diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index 1fa6d32..f21cf9a 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -323,6 +323,8 @@ def forward( neg_image_proj=neg_image_proj, ip_scale=ip_scale, neg_ip_scale=neg_ip_scale, + controlnet_device=self.other_device, + model_device=self.model_device, ) else: x = denoise( From 237552596cd458afed09626c27e7ad52f0cb05fa Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 11:30:44 +0200 Subject: [PATCH 15/23] openpose dataset and config --- image_datasets/openpose_dataset.py | 58 +++++++++++++++++++++ train_configs/test_openpose_controlnet.yaml | 25 +++++++++ 2 files changed, 83 insertions(+) create mode 100644 image_datasets/openpose_dataset.py create mode 100644 train_configs/test_openpose_controlnet.yaml diff --git a/image_datasets/openpose_dataset.py b/image_datasets/openpose_dataset.py new file mode 100644 index 0000000..d0497b9 --- /dev/null +++ b/image_datasets/openpose_dataset.py @@ -0,0 +1,58 @@ +import os +import pandas as pd +import numpy as np +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +import json +import random +import cv2 + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +class OpenPoseImageDataset(Dataset): + def __init__(self, img_dir, img_size=512): + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + self.images.sort() + self.img_size = img_size + + print('OpenPoseImageDataset: ', len(self.images)) + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + try: + json_path = self.images[idx].split('.')[0] + '.json' + json_data = json.load(open(json_path)) + + img = Image.open(self.images[idx]) + img = c_crop(img) + img = img.resize((self.img_size, self.img_size)) + img = torch.from_numpy((np.array(img) / 127.5) - 1) + img = img.permute(2, 0, 1) + + json_path = self.images[idx].split('.')[0] + '.json' + json_data = json.load(open(json_path)) + hint_path = json_data['conditioning_image'] + hint = Image.open(hint_path) + hint = torch.from_numpy((np.array(hint) / 127.5) - 1) + hint = hint.permute(2, 0, 1) + + prompt = json_data['caption'] + return img, hint, prompt + except Exception as e: + print(e) + return self.__getitem__(random.randint(0, len(self.images) - 1)) + + +def 
openpose_dataset_loader(train_batch_size, num_workers, **args): + dataset = OpenPoseImageDataset(**args) + return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True) diff --git a/train_configs/test_openpose_controlnet.yaml b/train_configs/test_openpose_controlnet.yaml new file mode 100644 index 0000000..7759d6a --- /dev/null +++ b/train_configs/test_openpose_controlnet.yaml @@ -0,0 +1,25 @@ +model_name: "flux-dev" +data_config: + train_batch_size: 4 + num_workers: 4 + img_size: 512 + img_dir: images/ +report_to: wandb +train_batch_size: 2 +output_dir: saves_openpose/ +max_train_steps: 100000 +learning_rate: 2e-5 +lr_scheduler: constant +lr_warmup_steps: 10 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 0.01 +adam_epsilon: 1e-8 +max_grad_norm: 1.0 +logging_dir: logs +mixed_precision: "bf16" +checkpointing_steps: 2500 +checkpoints_total_limit: 10 +tracker_project_name: openpose_training +resume_from_checkpoint: latest +gradient_accumulation_steps: 2 From 73fa7686bf3ca2d2ee1f6619b97ddc5d8eb8c86a Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 11:48:25 +0200 Subject: [PATCH 16/23] read openpose controlnet --- image_datasets/openpose_dataset.py | 9 +++++---- train_configs/test_canny_controlnet.yaml | 1 + train_flux_deepspeed_controlnet.py | 7 ++++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/image_datasets/openpose_dataset.py b/image_datasets/openpose_dataset.py index d0497b9..1a68114 100644 --- a/image_datasets/openpose_dataset.py +++ b/image_datasets/openpose_dataset.py @@ -6,7 +6,6 @@ from torch.utils.data import Dataset, DataLoader import json import random -import cv2 def c_crop(image): width, height = image.size @@ -19,6 +18,7 @@ def c_crop(image): class OpenPoseImageDataset(Dataset): def __init__(self, img_dir, img_size=512): + self.img_dir = img_dir self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] self.images.sort() self.img_size = img_size @@ -39,15 +39,16 @@ def __getitem__(self, idx): img = torch.from_numpy((np.array(img) / 127.5) - 1) img = img.permute(2, 0, 1) - json_path = self.images[idx].split('.')[0] + '.json' - json_data = json.load(open(json_path)) - hint_path = json_data['conditioning_image'] + hint_path = os.path.join(self.img_dir, json_data['conditioning_image']) hint = Image.open(hint_path) + hint = c_crop(hint) + hint = hint.resize((self.img_size, self.img_size)) hint = torch.from_numpy((np.array(hint) / 127.5) - 1) hint = hint.permute(2, 0, 1) prompt = json_data['caption'] return img, hint, prompt + except Exception as e: print(e) return self.__getitem__(random.randint(0, len(self.images) - 1)) diff --git a/train_configs/test_canny_controlnet.yaml b/train_configs/test_canny_controlnet.yaml index 1b4c7ee..3b7da36 100644 --- a/train_configs/test_canny_controlnet.yaml +++ b/train_configs/test_canny_controlnet.yaml @@ -1,4 +1,5 @@ model_name: "flux-dev" +is_openpose: true data_config: train_batch_size: 4 num_workers: 4 diff --git a/train_flux_deepspeed_controlnet.py b/train_flux_deepspeed_controlnet.py index fa8cc69..a793c22 100644 --- a/train_flux_deepspeed_controlnet.py +++ b/train_flux_deepspeed_controlnet.py @@ -38,6 +38,7 @@ from src.flux.util import (configs, load_ae, load_clip, load_flow_model2, load_controlnet, load_t5) from image_datasets.canny_dataset import loader +from image_datasets.openpose_dataset import openpose_dataset_loader if is_wandb_available(): import wandb logger = get_logger(__name__, log_level="INFO") @@ -122,7 +123,11 
@@ def main(): eps=args.adam_epsilon, ) - train_dataloader = loader(**args.data_config) + if args.is_openpose: + train_dataloader = openpose_dataset_loader(**args.data_config) + else: + train_dataloader = loader(**args.data_config) + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 251465827a01886d95648477e8e33b40445097ea Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 11:54:49 +0200 Subject: [PATCH 17/23] ignore _pose files --- image_datasets/canny_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/image_datasets/canny_dataset.py b/image_datasets/canny_dataset.py index 728c756..35d3076 100644 --- a/image_datasets/canny_dataset.py +++ b/image_datasets/canny_dataset.py @@ -29,7 +29,7 @@ def c_crop(image): class CustomImageDataset(Dataset): def __init__(self, img_dir, img_size=512): - self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if ('.jpg' in i or '.png' in i) and not i.endswith('_pose.jpg') and not i.endswith('_pose.png')] self.images.sort() self.img_size = img_size From f0bbc5f084d2a926470c6ee6d1f5554a5ea7f886 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 11:59:57 +0200 Subject: [PATCH 18/23] correct file --- train_configs/test_canny_controlnet.yaml | 1 - train_configs/test_openpose_controlnet.yaml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/train_configs/test_canny_controlnet.yaml b/train_configs/test_canny_controlnet.yaml index 3b7da36..1b4c7ee 100644 --- a/train_configs/test_canny_controlnet.yaml +++ b/train_configs/test_canny_controlnet.yaml @@ -1,5 +1,4 @@ model_name: "flux-dev" -is_openpose: true data_config: train_batch_size: 4 num_workers: 4 diff --git a/train_configs/test_openpose_controlnet.yaml b/train_configs/test_openpose_controlnet.yaml index 7759d6a..40a1a4e 100644 --- a/train_configs/test_openpose_controlnet.yaml +++ b/train_configs/test_openpose_controlnet.yaml @@ -1,4 +1,5 @@ model_name: "flux-dev" +is_openpose: true data_config: train_batch_size: 4 num_workers: 4 From a21366f5caaa1c42c543278949e8203ca0297126 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 12:01:06 +0200 Subject: [PATCH 19/23] correct file --- image_datasets/canny_dataset.py | 2 +- image_datasets/openpose_dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/image_datasets/canny_dataset.py b/image_datasets/canny_dataset.py index 35d3076..728c756 100644 --- a/image_datasets/canny_dataset.py +++ b/image_datasets/canny_dataset.py @@ -29,7 +29,7 @@ def c_crop(image): class CustomImageDataset(Dataset): def __init__(self, img_dir, img_size=512): - self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if ('.jpg' in i or '.png' in i) and not i.endswith('_pose.jpg') and not i.endswith('_pose.png')] + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] self.images.sort() self.img_size = img_size diff --git a/image_datasets/openpose_dataset.py b/image_datasets/openpose_dataset.py index 1a68114..8e4c6c7 100644 --- a/image_datasets/openpose_dataset.py +++ b/image_datasets/openpose_dataset.py @@ -19,7 +19,7 @@ def c_crop(image): class OpenPoseImageDataset(Dataset): def __init__(self, img_dir, img_size=512): self.img_dir = img_dir - self.images = [os.path.join(img_dir, i) 
for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if ('.jpg' in i or '.png' in i) and not i.endswith('_pose.jpg') and not i.endswith('_pose.png')] self.images.sort() self.img_size = img_size From 3ed30ce3a3b322a64d56293bb5afae7f9312994c Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 12:03:47 +0200 Subject: [PATCH 20/23] training config --- train_configs/test_openpose_controlnet.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_configs/test_openpose_controlnet.yaml b/train_configs/test_openpose_controlnet.yaml index 40a1a4e..efa9d81 100644 --- a/train_configs/test_openpose_controlnet.yaml +++ b/train_configs/test_openpose_controlnet.yaml @@ -1,8 +1,8 @@ model_name: "flux-dev" is_openpose: true data_config: - train_batch_size: 4 - num_workers: 4 + train_batch_size: 2 + num_workers: 2 img_size: 512 img_dir: images/ report_to: wandb From 8524be4649f358ab5f92d6706a539368bd26c060 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 14:21:34 +0200 Subject: [PATCH 21/23] checkpoints limit and rm print --- train_configs/test_openpose_controlnet.yaml | 2 +- train_flux_deepspeed_controlnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train_configs/test_openpose_controlnet.yaml b/train_configs/test_openpose_controlnet.yaml index efa9d81..2d8b4bb 100644 --- a/train_configs/test_openpose_controlnet.yaml +++ b/train_configs/test_openpose_controlnet.yaml @@ -20,7 +20,7 @@ max_grad_norm: 1.0 logging_dir: logs mixed_precision: "bf16" checkpointing_steps: 2500 -checkpoints_total_limit: 10 +checkpoints_total_limit: 50 tracker_project_name: openpose_training resume_from_checkpoint: latest gradient_accumulation_steps: 2 diff --git a/train_flux_deepspeed_controlnet.py b/train_flux_deepspeed_controlnet.py index a793c22..44b5d79 100644 --- a/train_flux_deepspeed_controlnet.py +++ b/train_flux_deepspeed_controlnet.py @@ -224,7 +224,7 @@ def main(): t = torch.sigmoid(torch.randn((bs,), device=accelerator.device)) x_0 = torch.randn_like(x_1).to(accelerator.device) - print(t.shape, x_1.shape, x_0.shape) + # print(t.shape, x_1.shape, x_0.shape) x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2]) * x_0 bsz = x_1.shape[0] guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype) From b4b339dc89d75934120e71f37b1375e24c2c289d Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 23 Sep 2024 14:42:22 +0200 Subject: [PATCH 22/23] convert to rgb to support grayscale images as well --- image_datasets/openpose_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/image_datasets/openpose_dataset.py b/image_datasets/openpose_dataset.py index 8e4c6c7..356db6f 100644 --- a/image_datasets/openpose_dataset.py +++ b/image_datasets/openpose_dataset.py @@ -36,6 +36,9 @@ def __getitem__(self, idx): img = Image.open(self.images[idx]) img = c_crop(img) img = img.resize((self.img_size, self.img_size)) + # support gray scale images as well + if img.mode != 'RGB': + img = img.convert('RGB') img = torch.from_numpy((np.array(img) / 127.5) - 1) img = img.permute(2, 0, 1) From f37134040dbab68feb11b1bdb1f6fcf6aace588b Mon Sep 17 00:00:00 2001 From: raci0399 Date: Wed, 25 Sep 2024 08:49:57 +0200 Subject: [PATCH 23/23] device for annotator --- src/flux/xflux_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py index f21cf9a..f402281 100644 --- a/src/flux/xflux_pipeline.py +++ b/src/flux/xflux_pipeline.py @@ -143,7 +143,7 @@ def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str checkpoint = load_checkpoint(local_path, repo_id, name) self.controlnet.load_state_dict(checkpoint, strict=False) - self.annotator = Annotator(control_type, "cuda:1") + self.annotator = Annotator(control_type, self.other_device) self.controlnet_loaded = True self.control_type = control_type
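
Reviewer note (not part of the patch series): the sketch below shows how the two-GPU pipeline introduced in patches 06-14 is assumed to be driven from Python. The XFluxPipeline constructor and set_controlnet signatures are taken from the diffs above; the import path simply mirrors the file layout (src/flux/xflux_pipeline.py), and the repo_id/name arguments are placeholders, not real checkpoint names.

    from src.flux.xflux_pipeline import XFluxPipeline

    # The transformer stays on the device passed as `device` ("cuda:0" here).
    # With two_gpus_pipeline=True the CLIP/T5/VAE encoders, the ControlNet and
    # the annotator are placed on the other GPU ("cuda:1"), as wired up in patch 08.
    pipe = XFluxPipeline(
        model_type="flux-dev",
        device="cuda:0",
        offload=False,
        two_gpus_pipeline=True,
    )

    # ControlNet weights are loaded onto the "other" device; repo/file names below
    # are placeholders for illustration only.
    pipe.set_controlnet(
        control_type="canny",
        repo_id="<hf-controlnet-repo>",       # placeholder
        name="<controlnet-checkpoint-file>",  # placeholder
    )

For the OpenPose training side (patches 15-22), each training image foo.jpg in img_dir is expected to carry a sidecar foo.json whose 'conditioning_image' entry names the pose map (resolved relative to img_dir) and whose 'caption' entry holds the prompt; files ending in _pose.jpg or _pose.png are excluded from the training-image listing, and the OpenPose loader is selected by setting is_openpose: true in the training config.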