Skip to content

Commit

Permalink
fix the bug for dreambooth-sana-lora training;
Browse files Browse the repository at this point in the history
Signed-off-by: lawrence-cj <[email protected]>

pre-commit all

Signed-off-by: lawrence-cj <[email protected]>
  • Loading branch information
lawrence-cj committed Dec 18, 2024
1 parent cab0681 commit 3cfdccf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 42 deletions.
9 changes: 3 additions & 6 deletions asset/docs/sana_lora_dreambooth.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Let's first download it locally:
```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
local_dir = "data/dreambooth/dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
Expand All @@ -74,9 +74,7 @@ bash train_scripts/train_lora.sh
or you can run it locally:

```bash
huggingface-cli download diffusers/dog-example --local-dir data/dreambooth/dog --repo-type dataset

export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
export INSTANCE_DIR="data/dreambooth/dog"
export OUTPUT_DIR="trained-sana-lora"

Expand All @@ -87,7 +85,6 @@ accelerate launch --num_processes 8 --main_process_port 29500 --gpu_ids 0,1,2,3
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--mixed_precision="fp16" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
Expand All @@ -97,7 +94,7 @@ accelerate launch --num_processes 8 --main_process_port 29500 --gpu_ids 0,1,2,3
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_prompt="A photo of sks dog in a pond, yarn art style" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
Expand Down
45 changes: 15 additions & 30 deletions train_scripts/train_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -24,13 +23,25 @@
import warnings
from pathlib import Path

import diffusers
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
Expand All @@ -43,29 +54,6 @@
from tqdm.auto import tqdm
from transformers import AutoTokenizer, Gemma2Model

import diffusers
from diffusers import (
AutoencoderDC,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaTransformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
import wandb

Expand Down Expand Up @@ -365,9 +353,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
Expand Down Expand Up @@ -932,6 +918,7 @@ def main(args):
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
private=True,
).repo_id

# Load the tokenizer
Expand Down Expand Up @@ -1219,9 +1206,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
vae = vae.to("cuda")
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=vae.dtype)
latents_cache.append(vae.encode(batch["pixel_values"]).latent)

if args.validation_prompt is None:
Expand Down
9 changes: 3 additions & 6 deletions train_scripts/train_lora.sh
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
#! /bin/bash

huggingface-cli download diffusers/dog-example --local-dir data/dreambooth/dog --repo-type dataset

export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
export INSTANCE_DIR="data/dreambooth/dog"
export OUTPUT_DIR="trained-sana-lora"

accelerate launch --num_processes 8 --main_process_port 29500 --gpu_ids 0,1,2,3 \
accelerate launch --num_processes 4 --main_process_port 29500 --gpu_ids 0,1,2,3 \
train_scripts/train_dreambooth_lora_sana.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--mixed_precision="fp16" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
Expand All @@ -23,7 +20,7 @@ accelerate launch --num_processes 8 --main_process_port 29500 --gpu_ids 0,1,2,3
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_prompt="A photo of sks dog in a pond, yarn art style" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub

0 comments on commit 3cfdccf

Please sign in to comment.