feat: added lora, fixed inference #15

Open
wants to merge 1 commit into base: main
40 changes: 37 additions & 3 deletions run_inference.py
@@ -196,6 +196,30 @@ def decode_and_postprocess(
)[0]
return image

from peft import get_peft_model, LoraConfig # Add this import

def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
if target_modules is None:
target_modules = [
"to_q",
"to_k",
"to_v",
"to_out.0",
"conv1",
"conv2",
"conv_shortcut",
"conv3",
"conv4",
]

config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=0.0,
bias="none",
)
return get_peft_model(unet, config)
Comment on lines +201 to +222
Owner: move this to sd3/model.py
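If it moves there, run_inference.py would only need an import (a sketch, assuming the same sd3.model import path already used elsewhere in this PR):

```python
# Hypothetical follow-up to this comment: get_lora_model lives in sd3/model.py
# and run_inference.py only imports it instead of defining it inline.
from sd3.model import get_lora_model
```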


def run_inference_img_conditioning_with_params(
unet,
@@ -210,11 +234,22 @@ def run_inference_img_conditioning_with_params(
do_classifier_free_guidance=True,
guidance_scale=7.5,
skip_action_conditioning=False,
is_lora=False, # Changed default to False
) -> Image:
"""
Run inference with the model. If is_lora is True, assumes unet is already wrapped with LoRA.
"""
assert batch["pixel_values"].shape[0] == 1, "Batch size must be 1"
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
batch_size = batch["pixel_values"].shape[0]

# If using LoRA, we need to use the base model for inference
if is_lora:
inference_unet = unet.base_model.model
Comment on lines +248 to +249
Owner: Can't you directly pass that as the argument to the inference function? A priori, the inference doesn't need to be aware of whether the model was trained with LoRA or not.
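One reading of this suggestion, as a sketch only (the remaining keyword arguments of the validation call are elided in this diff, so only the unwrapping is spelled out): resolve the peft wrapper once where args.use_lora is known and keep run_inference_img_conditioning_with_params LoRA-agnostic.

```python
# Sketch: unwrap the peft wrapper at the call site instead of inside inference.
# `base_model.model` is how a peft-wrapped module exposes the underlying model.
inference_unet = unet.base_model.model if args.use_lora else unet

generated_image = run_inference_img_conditioning_with_params(
    unet=inference_unet,
    # ...remaining keyword arguments exactly as in the existing validation call,
    # with the is_lora flag removed.
    skip_action_conditioning=args.skip_action_conditioning,
)
```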

else:
inference_unet = unet

with torch.no_grad(), autocast(device_type="cuda", dtype=torch.float32):
actions = batch["input_ids"]
latent_height = HEIGHT // vae_scale_factor
@@ -228,7 +263,7 @@ def run_inference_img_conditioning_with_params(
dtype=torch.float32,
)
new_frame = next_latent(
unet=unet,
unet=inference_unet, # Use the inference_unet here
vae=vae,
noise_scheduler=noise_scheduler,
action_embedding=action_embedding,
@@ -242,8 +277,7 @@ def run_inference_img_conditioning_with_params(
image = decode_and_postprocess(
vae=vae, image_processor=image_processor, latents=new_frame
)
return image[0]

return image

def main(model_folder: str) -> None:
device = torch.device(
59 changes: 30 additions & 29 deletions sd3/model.py
@@ -26,77 +26,78 @@ def get_ft_vae_decoder():
decoder_state_dict = torch.load(file_path, weights_only=True)
return decoder_state_dict


def get_model(
action_embedding_dim: int, skip_image_conditioning: bool = False
) -> tuple[
UNet2DConditionModel,
AutoencoderKL,
torch.nn.Embedding,
DDIMScheduler,
CLIPTokenizer,
CLIPTextModel,
]:
action_embedding_dim: int,
skip_image_conditioning: bool = False,
device: torch.device | None = None
) -> tuple[UNet2DConditionModel, AutoencoderKL, torch.nn.Embedding, DDIMScheduler, CLIPTokenizer, CLIPTextModel]:
Comment on lines +30 to +33
Owner: Can you run ruff format on that file?

"""
Args:
action_embedding_dim: the dimension of the action embedding, i.e the number of possible actions + 1 (do nothing action)
action_embedding_dim: the dimension of the action embedding
skip_image_conditioning: whether to skip image conditioning
device: the device to load the models to
"""

# This will be used to encode the actions
# Create action embedding
action_embedding = torch.nn.Embedding(
num_embeddings=action_embedding_dim + 1, embedding_dim=768
num_embeddings=action_embedding_dim + 1,
embedding_dim=768
)
torch.nn.init.normal_(action_embedding.weight, mean=0.0, std=0.02)

# DDIM scheduler allows for v-prediction and less sampling steps
# Load models with device placement
Owner: Comments should primarily provide additional context for code that's non-trivial to understand. In this case, the comment simply describes what is done below, which is already explicit.

noise_scheduler = DDIMScheduler.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH, subfolder="scheduler"
PRETRAINED_MODEL_NAME_OR_PATH,
subfolder="scheduler"
)
# This is what the paper uses
noise_scheduler.register_to_config(prediction_type="v_prediction")

vae = AutoencoderKL.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, subfolder="vae")
# Load VAE with custom decoder directly
vae = AutoencoderKL.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH,
subfolder="vae",
device_map=device if device else "auto"
)
decoder_state_dict = get_ft_vae_decoder()
vae.decoder.load_state_dict(decoder_state_dict)

unet = UNet2DConditionModel.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH, subfolder="unet"
PRETRAINED_MODEL_NAME_OR_PATH,
subfolder="unet",
device_map=device if device else "auto"
)
# There are 10 noise buckets total
unet.register_to_config(num_class_embeds=NUM_BUCKETS)
# We do not use .add_module() because the class_embedding is already initialized as None
unet.class_embedding = torch.nn.Embedding(
NUM_BUCKETS, unet.time_embedding.linear_2.out_features
NUM_BUCKETS,
unet.time_embedding.linear_2.out_features
)

# Load text models
tokenizer = CLIPTokenizer.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH, subfolder="tokenizer"
PRETRAINED_MODEL_NAME_OR_PATH,
subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
PRETRAINED_MODEL_NAME_OR_PATH, subfolder="text_encoder"
PRETRAINED_MODEL_NAME_OR_PATH,
subfolder="text_encoder",
device_map=device if device else "auto"
)

if not skip_image_conditioning:
# This is to accomodate concatenating previous frames in the channels dimension
# Modify UNet input channels
Owner: Same thing: the new comment describes what we're doing and not why. The previous comment gave that additional context.

new_in_channels = 4 * (BUFFER_SIZE + 1)
new_conv_in = torch.nn.Conv2d(
new_in_channels, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
)
torch.nn.init.xavier_uniform_(new_conv_in.weight)
torch.nn.init.zeros_(new_conv_in.bias)

# Replace the conv_in layer
unet.conv_in = new_conv_in
# Have to account for BUFFER SIZE conditioning frames + 1 for the noise
Comment on lines -88 to -91
Owner: And +1 for removing those comments that are useless.

unet.config["in_channels"] = new_in_channels

unet.requires_grad_(True)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
return unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder


def load_embedding_info_dict(model_folder: str) -> dict:
if os.path.exists(model_folder):
with open(os.path.join(model_folder, "embedding_info.json"), "r") as f:
106 changes: 101 additions & 5 deletions train_text_to_image.py
@@ -123,6 +123,24 @@ def parse_args():
default='arnaudstiegler/game-n-gen-finetuned-sd',
help="The name of the model to use as a base model.",
)
parser.add_argument(
"--use_lora",
action="store_true",
help="Whether to use LoRA for training.",
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="The alpha parameter for LoRA scaling",
)
parser.add_argument(
"--target_modules",
type=str,
nargs="+",
default=None,
help="List of module names to apply LoRA to",
)
Comment on lines +137 to +143
Owner: Is this used anywhere? Also, this parameter is typically hardcoded.

parser.add_argument(
"--dataset_name",
type=str,
@@ -510,10 +528,17 @@ def main():
exist_ok=True,
token=args.hub_token).repo_id

# This is a bit wasteful
# # This is a bit wasteful
dataset = load_dataset(args.dataset_name)
action_dim = max(max(actions) for actions in dataset['train']['actions'])

# # from sd3.model import load_model
# if args.pretrained_model_name_or_path:
# print(f"Loading pretrained model from {args.pretrained_model_name_or_path}")
# unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder = load_model(
# args.pretrained_model_name_or_path, device=accelerator.device
# )
# # unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder = load_model(args.pretrained_model_name_or_path)
# else:
Comment on lines +534 to +541
Owner: remove

unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder = get_model(
action_dim, skip_image_conditioning=args.skip_image_conditioning)

@@ -640,10 +665,70 @@ def main():
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)

from peft import LoraConfig, get_peft_model
def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
if target_modules is None:
# Default target modules for SD UNet
target_modules = [
"to_q",
"to_k",
"to_v",
"to_out.0",
"conv1",
"conv2",
"conv_shortcut",
"conv3",
"conv4",
]

config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=0.0,
bias="none",
)
return get_peft_model(unet, config)
Comment on lines +669 to +692
Owner: this should go to sd3/model.py


# Only apply LoRA if the flag is set
if args.use_lora:
logger.info("Using LoRA for training")
unet = get_lora_model(
unet,
rank=args.rank,
lora_alpha=args.lora_alpha,
)
# Only train LoRA parameters
params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters())
Owner: Did you make sure that only the LoRA parameters are marked with requires_grad? I'm not seeing this logic anywhere.
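For reference, peft's get_peft_model does freeze the base weights when it injects the adapters; a minimal, self-contained check (a toy module stands in for the real UNet, and peft/torch are assumed installed) would look like this:

```python
import torch
from peft import LoraConfig, get_peft_model


class TinyAttention(torch.nn.Module):
    """Toy stand-in for the UNet: two projections named like its attention layers."""

    def __init__(self) -> None:
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)
        self.to_k = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.to_q(x) + self.to_k(x)


model = get_peft_model(
    TinyAttention(),
    LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k"], bias="none"),
)

# Only the injected lora_A / lora_B tensors should still require gradients.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert all("lora_" in name for name in trainable), trainable
model.print_trainable_parameters()
```

If the same holds once the real UNet is wrapped, the requires_grad filter above does end up selecting only the adapter weights.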

else:
logger.info("Training full model (no LoRA)")
params_to_optimize = unet.parameters()

if args.skip_action_conditioning:
optimizer = optimizer_cls(
unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
else:
optimizer = optimizer_cls(
[
{"params": params_to_optimize},
{"params": action_embedding.parameters()},
],
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
Comment on lines +708 to +726
Owner: This is copied from lines 624 to 639. Update the existing code rather than duplicating it.
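One way to fold this back into the existing block, as a sketch (args, optimizer_cls, params_to_optimize and action_embedding are the names already used in this file). Note that the skip_action_conditioning branch above still passes unet.parameters() directly, which would bypass the LoRA-only filter:

```python
# Single optimizer construction shared by the LoRA and full fine-tuning paths.
param_groups = [{"params": params_to_optimize}]
if not args.skip_action_conditioning:
    param_groups.append({"params": action_embedding.parameters()})

optimizer = optimizer_cls(
    param_groups,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
```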


# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler)


# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
@@ -953,8 +1038,8 @@ def main():
num_inference_steps=50,
do_classifier_free_guidance=args.use_cfg,
guidance_scale=CFG_GUIDANCE_SCALE,
skip_action_conditioning=args.
skip_action_conditioning,
skip_action_conditioning=args.skip_action_conditioning,
is_lora=args.use_lora, # Pass the LoRA flag
)
validation_images.append(generated_image)

@@ -995,11 +1080,22 @@ def main():
# Save the model
accelerator.wait_for_everyone()

# Modify save function to save LoRA weights
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)

if args.use_lora:
# Save LoRA weights
unet.save_pretrained(os.path.join(args.output_dir, "lora"))
Comment on lines +1088 to +1089
Owner: update the save function rather than saving it separately
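A possible shape for that, as a hypothetical sketch only (the real body and signature of save_and_maybe_upload_to_hub are not shown in this diff, and the helper name below is illustrative):

```python
import os


def save_unet_weights(unet, output_dir: str, use_lora: bool) -> None:
    """Hypothetical helper: let the save path own the LoRA special case."""
    if use_lora:
        # Save only the adapter weights, then unwrap before the full-model save.
        unet.save_pretrained(os.path.join(output_dir, "lora"))
        unet = unet.base_model.model
    unet.save_pretrained(os.path.join(output_dir, "unet"))
```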

base_unet = unet.base_model.model
else:
base_unet = unet

# Save the rest of the model components
save_and_maybe_upload_to_hub(
repo_id=REPO_NAME,
output_dir=args.output_dir,
unet=unet,
unet=base_unet,
vae=vae,
noise_scheduler=noise_scheduler,
action_embedding=action_embedding,