feat: added lora, fixed inference #15
base: main
@@ -196,6 +196,30 @@ def decode_and_postprocess(
    )[0]
    return image


from peft import get_peft_model, LoraConfig  # Add this import


def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
    if target_modules is None:
        target_modules = [
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "conv1",
            "conv2",
            "conv_shortcut",
            "conv3",
            "conv4",
        ]

    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)


def run_inference_img_conditioning_with_params(
    unet,
@@ -210,11 +234,22 @@ def run_inference_img_conditioning_with_params(
    do_classifier_free_guidance=True,
    guidance_scale=7.5,
    skip_action_conditioning=False,
    is_lora=False,  # Changed default to False
) -> Image:
    """
    Run inference with the model. If is_lora is True, assumes unet is already wrapped with LoRA.
    """
    assert batch["pixel_values"].shape[0] == 1, "Batch size must be 1"
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    batch_size = batch["pixel_values"].shape[0]

    # If using LoRA, we need to use the base model for inference
    if is_lora:
        inference_unet = unet.base_model.model
Comment on lines +248 to +249

Can't you directly pass that as the argument to the inference function?
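A minimal sketch of what the reviewer is suggesting, assuming get_lora_model returns a PEFT PeftModel wrapper: unwrap at the call site so the function keeps a single unet argument and needs no is_lora flag. The helper name resolve_inference_unet is illustrative, not part of the PR.

from peft import PeftModel


def resolve_inference_unet(unet):
    # Unwrap a PEFT/LoRA-wrapped UNet; a plain UNet passes through unchanged.
    return unet.base_model.model if isinstance(unet, PeftModel) else unet


# Call-site sketch: pass the already-unwrapped model instead of a flag.
# image = run_inference_img_conditioning_with_params(
#     unet=resolve_inference_unet(unet), vae=vae, ...
# )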
    else:
        inference_unet = unet

    with torch.no_grad(), autocast(device_type="cuda", dtype=torch.float32):
        actions = batch["input_ids"]
        latent_height = HEIGHT // vae_scale_factor
@@ -228,7 +263,7 @@ def run_inference_img_conditioning_with_params(
            dtype=torch.float32,
        )
        new_frame = next_latent(
            unet=unet,
            unet=inference_unet,  # Use the inference_unet here
            vae=vae,
            noise_scheduler=noise_scheduler,
            action_embedding=action_embedding,
@@ -242,8 +277,7 @@ def run_inference_img_conditioning_with_params(
        image = decode_and_postprocess(
            vae=vae, image_processor=image_processor, latents=new_frame
        )
        return image[0]
        return image


def main(model_folder: str) -> None:
    device = torch.device(
@@ -26,77 +26,78 @@ def get_ft_vae_decoder():
    decoder_state_dict = torch.load(file_path, weights_only=True)
    return decoder_state_dict


def get_model(
    action_embedding_dim: int, skip_image_conditioning: bool = False
) -> tuple[
    UNet2DConditionModel,
    AutoencoderKL,
    torch.nn.Embedding,
    DDIMScheduler,
    CLIPTokenizer,
    CLIPTextModel,
]:
    action_embedding_dim: int,
    skip_image_conditioning: bool = False,
    device: torch.device | None = None
) -> tuple[UNet2DConditionModel, AutoencoderKL, torch.nn.Embedding, DDIMScheduler, CLIPTokenizer, CLIPTextModel]:
Comment on lines +30 to +33

Can you run
    """
    Args:
        action_embedding_dim: the dimension of the action embedding, i.e the number of possible actions + 1 (do nothing action)
        action_embedding_dim: the dimension of the action embedding
        skip_image_conditioning: whether to skip image conditioning
        device: the device to load the models to
    """

    # This will be used to encode the actions
    # Create action embedding
    action_embedding = torch.nn.Embedding(
        num_embeddings=action_embedding_dim + 1, embedding_dim=768
        num_embeddings=action_embedding_dim + 1,
        embedding_dim=768
    )
    torch.nn.init.normal_(action_embedding.weight, mean=0.0, std=0.02)

    # DDIM scheduler allows for v-prediction and less sampling steps
    # Load models with device placement
Comments should primarily provide additional context for code that's non-trivial to understand. In this case, the comment is simply describing what is done below, which is already explicit.
    noise_scheduler = DDIMScheduler.from_pretrained(
        PRETRAINED_MODEL_NAME_OR_PATH, subfolder="scheduler"
        PRETRAINED_MODEL_NAME_OR_PATH,
        subfolder="scheduler"
    )
    # This is what the paper uses
    noise_scheduler.register_to_config(prediction_type="v_prediction")

    vae = AutoencoderKL.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, subfolder="vae")
    # Load VAE with custom decoder directly
    vae = AutoencoderKL.from_pretrained(
        PRETRAINED_MODEL_NAME_OR_PATH,
        subfolder="vae",
        device_map=device if device else "auto"
    )
    decoder_state_dict = get_ft_vae_decoder()
    vae.decoder.load_state_dict(decoder_state_dict)

    unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_MODEL_NAME_OR_PATH, subfolder="unet"
        PRETRAINED_MODEL_NAME_OR_PATH,
        subfolder="unet",
        device_map=device if device else "auto"
    )
    # There are 10 noise buckets total
    unet.register_to_config(num_class_embeds=NUM_BUCKETS)
    # We do not use .add_module() because the class_embedding is already initialized as None
    unet.class_embedding = torch.nn.Embedding(
        NUM_BUCKETS, unet.time_embedding.linear_2.out_features
        NUM_BUCKETS,
        unet.time_embedding.linear_2.out_features
    )

    # Load text models
    tokenizer = CLIPTokenizer.from_pretrained(
        PRETRAINED_MODEL_NAME_OR_PATH, subfolder="tokenizer"
        PRETRAINED_MODEL_NAME_OR_PATH,
        subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_MODEL_NAME_OR_PATH, subfolder="text_encoder"
        PRETRAINED_MODEL_NAME_OR_PATH,
        subfolder="text_encoder",
        device_map=device if device else "auto"
    )

    if not skip_image_conditioning:
        # This is to accomodate concatenating previous frames in the channels dimension
        # Modify UNet input channels
Same thing: the new comment describes what we're doing and not why.
        new_in_channels = 4 * (BUFFER_SIZE + 1)
        new_conv_in = torch.nn.Conv2d(
            new_in_channels, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        torch.nn.init.xavier_uniform_(new_conv_in.weight)
        torch.nn.init.zeros_(new_conv_in.bias)

        # Replace the conv_in layer
        unet.conv_in = new_conv_in
        # Have to account for BUFFER SIZE conditioning frames + 1 for the noise
Comment on lines -88 to -91

And +1 for removing those comments that are useless
        unet.config["in_channels"] = new_in_channels

    unet.requires_grad_(True)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    return unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder


def load_embedding_info_dict(model_folder: str) -> dict:
    if os.path.exists(model_folder):
        with open(os.path.join(model_folder, "embedding_info.json"), "r") as f:
@@ -123,6 +123,24 @@ def parse_args():
        default='arnaudstiegler/game-n-gen-finetuned-sd',
        help="The name of the model to use as a base model.",
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        help="Whether to use LoRA for training.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=4,
        help="The alpha parameter for LoRA scaling",
    )
    parser.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        default=None,
        help="List of module names to apply LoRA to",
    )
Comment on lines +137 to +143

Is this used anywhere? Also, this parameter is typically hardcoded
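If the flag is indeed unused, one option in line with this comment is to drop --target_modules from the CLI and keep the list as a constant next to the LoRA helper. A hedged sketch under that assumption (the constant name is illustrative; the module list is the PR's own default):

from peft import LoraConfig, get_peft_model

# Hardcoded LoRA targets (same list the PR currently passes as the default).
LORA_TARGET_MODULES = [
    "to_q", "to_k", "to_v", "to_out.0",
    "conv1", "conv2", "conv_shortcut", "conv3", "conv4",
]


def get_lora_model(unet, rank: int = 4, lora_alpha: int = 4):
    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=LORA_TARGET_MODULES,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)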

    parser.add_argument(
        "--dataset_name",
        type=str,
@@ -510,10 +528,17 @@ def main():
            exist_ok=True,
            token=args.hub_token).repo_id

    # This is a bit wasteful
    # # This is a bit wasteful
    dataset = load_dataset(args.dataset_name)
    action_dim = max(max(actions) for actions in dataset['train']['actions'])

    # # from sd3.model import load_model
    # if args.pretrained_model_name_or_path:
    #     print(f"Loading pretrained model from {args.pretrained_model_name_or_path}")
    #     unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder = load_model(
    #         args.pretrained_model_name_or_path, device=accelerator.device
    #     )
    #     # unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder = load_model(args.pretrained_model_name_or_path)
    # else:
Comment on lines +534 to +541

remove

    unet, vae, action_embedding, noise_scheduler, tokenizer, text_encoder = get_model(
        action_dim, skip_image_conditioning=args.skip_image_conditioning)

@@ -640,10 +665,70 @@ def main():
        num_warmup_steps=num_warmup_steps_for_scheduler,
        num_training_steps=num_training_steps_for_scheduler,
    )

    from peft import LoraConfig, get_peft_model

    def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
        if target_modules is None:
            # Default target modules for SD UNet
            target_modules = [
                "to_q",
                "to_k",
                "to_v",
                "to_out.0",
                "conv1",
                "conv2",
                "conv_shortcut",
                "conv3",
                "conv4",
            ]

        config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=0.0,
            bias="none",
        )
        return get_peft_model(unet, config)
Comment on lines +669 to +692

this should go to sd3/model.py
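A sketch of the relocation being requested, assuming sd3/model.py is the module shown earlier in this diff: define get_lora_model once there (the inference file's diff already adds an identical copy) and import it in the training script instead of nesting a second definition inside main().

# sd3/model.py -- single shared definition
from peft import LoraConfig, get_peft_model


def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
    if target_modules is None:
        target_modules = [
            "to_q", "to_k", "to_v", "to_out.0",
            "conv1", "conv2", "conv_shortcut", "conv3", "conv4",
        ]
    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)


# Training script: import the shared helper instead of redefining it.
# from sd3.model import get_lora_model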


    # Only apply LoRA if the flag is set
    if args.use_lora:
        logger.info("Using LoRA for training")
        unet = get_lora_model(
            unet,
            rank=args.rank,
            lora_alpha=args.lora_alpha,
        )
        # Only train LoRA parameters
        params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters())
Did you make sure that only the LoRA parameters are marked with requires_grad? I'm not seeing this logic anywhere.
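For reference, get_peft_model is normally what freezes the base weights, but since get_model calls unet.requires_grad_(True) before wrapping, an explicit check would make the answer to this question visible. A sketch of one possible assertion (the helper name is illustrative):

def assert_only_lora_is_trainable(model) -> None:
    # Collect every parameter that would receive gradients.
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    non_lora = [name for name in trainable if "lora_" not in name]
    assert trainable, "expected at least one trainable (LoRA) parameter"
    assert not non_lora, f"non-LoRA parameters are still trainable: {non_lora[:5]}"


# After wrapping:
# unet = get_lora_model(unet, rank=args.rank, lora_alpha=args.lora_alpha)
# unet.print_trainable_parameters()  # PEFT's built-in summary of trainable vs. total params
# assert_only_lora_is_trainable(unet)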

    else:
        logger.info("Training full model (no LoRA)")
        params_to_optimize = unet.parameters()

    if args.skip_action_conditioning:
        optimizer = optimizer_cls(
            unet.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
    else:
        optimizer = optimizer_cls(
            [
                {"params": params_to_optimize},
                {"params": action_embedding.parameters()},
            ],
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
Comment on lines +708 to +726

This is copied from line 624 -> 639. Update the existing code rather than duplicating code
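One way to address this without a second copy, sketched under the assumption that the block at lines 624-639 is the only other place the optimizer is built: pick the parameter set once, then keep a single construction site.

# Choose what to optimize once; the single optimizer construction below stays unchanged.
if args.use_lora:
    unet_params = [p for p in unet.parameters() if p.requires_grad]  # LoRA weights only
else:
    unet_params = list(unet.parameters())

param_groups = [{"params": unet_params}]
if not args.skip_action_conditioning:
    param_groups.append({"params": action_embedding.parameters()})

optimizer = optimizer_cls(
    param_groups,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)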


    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler)


    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
@@ -953,8 +1038,8 @@ def main():
                        num_inference_steps=50,
                        do_classifier_free_guidance=args.use_cfg,
                        guidance_scale=CFG_GUIDANCE_SCALE,
                        skip_action_conditioning=args.
                        skip_action_conditioning,
                        skip_action_conditioning=args.skip_action_conditioning,
                        is_lora=args.use_lora,  # Pass the LoRA flag
                    )
                    validation_images.append(generated_image)

@@ -995,11 +1080,22 @@ def main():
    # Save the model
    accelerator.wait_for_everyone()

    # Modify save function to save LoRA weights
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)

        if args.use_lora:
            # Save LoRA weights
            unet.save_pretrained(os.path.join(args.output_dir, "lora"))
Comment on lines +1088 to +1089

update the save function rather than saving it separately
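A sketch of folding the LoRA case into the save path instead of the separate save_pretrained call. The body of save_and_maybe_upload_to_hub is not part of this diff, so the wrapper below is hypothetical and only shows where the branching could live.

import os

from peft import PeftModel


def save_unet_with_optional_lora(unet, output_dir: str, **hub_kwargs) -> None:
    # If the UNet is PEFT-wrapped, write the adapter weights, then save the base UNet.
    if isinstance(unet, PeftModel):
        unet.save_pretrained(os.path.join(output_dir, "lora"))
        unet = unet.base_model.model
    save_and_maybe_upload_to_hub(output_dir=output_dir, unet=unet, **hub_kwargs)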

            base_unet = unet.base_model.model
        else:
            base_unet = unet

        # Save the rest of the model components
        save_and_maybe_upload_to_hub(
            repo_id=REPO_NAME,
            output_dir=args.output_dir,
            unet=unet,
            unet=base_unet,
            vae=vae,
            noise_scheduler=noise_scheduler,
            action_embedding=action_embedding,
move this to sd3/model.py