feat: added lora, fixed inference #15
base: main
Conversation
P-H-B-D commented on Nov 5, 2024
- Re-added support for LoRA via the --use_lora flag and the associated --rank, --lora_alpha, and --target_modules flags.
- Fixed inference to return the image directly rather than wrapped in a list.
Left a few comments. As a general rule, avoid duplicating logic and try to reuse existing code as much as possible.
def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
    if target_modules is None:
        target_modules = [
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "conv1",
            "conv2",
            "conv_shortcut",
            "conv3",
            "conv4",
        ]

    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)
move this to sd3/model.py
if is_lora:
    inference_unet = unet.base_model.model
Can't you directly pass that as the argument to the inference function? A priori, the inference doesn't need to be aware of whether the model was trained with LoRA or not.
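For instance (a rough sketch only; run_inference and the other arguments here are placeholder names, not this repo's actual inference API), the unwrapping could happen at the call site so the inference code always receives a plain UNet2DConditionModel:

# Sketch: unwrap the PEFT wrapper where inference is invoked, so no
# is_lora flag ever reaches the inference function. `run_inference`,
# `vae`, `noise_scheduler`, and `actions` are assumed names.
plain_unet = unet.base_model.model if args.use_lora else unet
images = run_inference(plain_unet, vae, noise_scheduler, actions)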
    action_embedding_dim: int,
    skip_image_conditioning: bool = False,
    device: torch.device | None = None,
) -> tuple[UNet2DConditionModel, AutoencoderKL, torch.nn.Embedding, DDIMScheduler, CLIPTokenizer, CLIPTextModel]:
Can you run ruff format on that file?
)
torch.nn.init.normal_(action_embedding.weight, mean=0.0, std=0.02)

# DDIM scheduler allows for v-prediction and less sampling steps
# Load models with device placement
Comments should primarily provide additional context for code that's non-trivial to understand. In this case, the comment is simply describing what is done below, which is already explicit.
)

if not skip_image_conditioning:
    # This is to accommodate concatenating previous frames in the channels dimension
    # Modify UNet input channels
Same thing: the new comment describes what we're doing and not why.
The previous comment gave that additional context
from peft import LoraConfig, get_peft_model

def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
    if target_modules is None:
        # Default target modules for SD UNet
        target_modules = [
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "conv1",
            "conv2",
            "conv_shortcut",
            "conv3",
            "conv4",
        ]

    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)
this should go to sd3/model.py
# Save LoRA weights
unet.save_pretrained(os.path.join(args.output_dir, "lora"))
update the save function rather than saving it separately
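Roughly what that might look like (a sketch; save_checkpoint and its arguments are assumptions, not the project's actual helper):

import os

import torch

# Hypothetical: extend the existing save helper instead of adding a separate call.
def save_checkpoint(unet, action_embedding, output_dir, use_lora=False):
    if use_lora:
        # For a PEFT-wrapped UNet, save_pretrained writes only the LoRA adapter weights.
        unet.save_pretrained(os.path.join(output_dir, "lora"))
    else:
        unet.save_pretrained(os.path.join(output_dir, "unet"))
    torch.save(action_embedding.state_dict(), os.path.join(output_dir, "action_embedding.pt"))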
    lora_alpha=args.lora_alpha,
)
# Only train LoRA parameters
params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters())
Did you make sure that only the LoRA parameters are marked with requires_grad? I'm not seeing this logic anywhere.
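One way to check (a sketch; get_peft_model freezes the base weights itself, so this is a sanity check rather than required logic):

# After wrapping with get_peft_model, only adapter weights (named lora_A/lora_B
# by PEFT) should still require gradients.
trainable = [name for name, p in unet.named_parameters() if p.requires_grad]
assert trainable and all("lora_" in name for name in trainable), trainable[:5]
unet.print_trainable_parameters()  # PeftModel helper: prints trainable vs. total parameter counts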
if args.skip_action_conditioning:
    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
else:
    optimizer = optimizer_cls(
        [
            {"params": params_to_optimize},
            {"params": action_embedding.parameters()},
        ],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
This is copied from lines 624-639. Update the existing code rather than duplicating it.
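A possible shape for the merged version (a sketch using the variable names visible in the diff):

# Build the parameter groups once, then create a single optimizer. The
# requires_grad filter is a no-op for full fine-tuning and restricts the
# group to the adapter weights when LoRA is enabled.
param_groups = [{"params": [p for p in unet.parameters() if p.requires_grad]}]
if not args.skip_action_conditioning:
    param_groups.append({"params": action_embedding.parameters()})

optimizer = optimizer_cls(
    param_groups,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)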
parser.add_argument(
    "--target_modules",
    type=str,
    nargs="+",
    default=None,
    help="List of module names to apply LoRA to",
)
Is this used anywhere? Also, this parameter is typically hardcoded.
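If the flag is dropped, the list can simply live next to the LoRA setup in sd3/model.py (sketch; the constant name is illustrative):

# Hardcoded attention/conv projections for the SD UNet, instead of a --target_modules CLI flag.
LORA_TARGET_MODULES = [
    "to_q", "to_k", "to_v", "to_out.0",
    "conv1", "conv2", "conv_shortcut", "conv3", "conv4",
]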