
feat: added lora, fixed inference #15

Open · wants to merge 1 commit into main

Conversation

@P-H-B-D (Collaborator) commented Nov 5, 2024

  • Re-added support for LoRA via the --use_lora flag and the associated flags --rank, --lora_alpha, and --target_modules.
  • Fixed inference to return the image directly rather than wrapped in a list.
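
For context, wiring the new flags into the LoRA helper added in this PR might look roughly like this (a sketch; the attribute names on args are assumed from the flag names, not copied from the diff):

if args.use_lora:
    unet = get_lora_model(
        unet,
        rank=args.rank,
        lora_alpha=args.lora_alpha,
        target_modules=args.target_modules,
    )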

@arnaudstiegler (Owner) left a comment

Left a few comments. As a general rule, avoid duplicating logic and try to reuse existing code as much as possible

Comment on lines +201 to +222
def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
    if target_modules is None:
        target_modules = [
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "conv1",
            "conv2",
            "conv_shortcut",
            "conv3",
            "conv4",
        ]

    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)
@arnaudstiegler (Owner):

move this to sd3/model.py

Comment on lines +248 to +249
if is_lora:
    inference_unet = unet.base_model.model
@arnaudstiegler (Owner):

Can't you directly pass that as the argument to the inference function?
A priori, the inference doesn't need to be aware of whether the model was trained with LORA or not
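
A sketch of what that could look like at the call site (run_inference here is a stand-in for the actual inference entry point, not a name from the PR):

# Unwrap the PEFT wrapper before calling inference, so the inference code
# only ever sees a plain UNet2DConditionModel and needs no is_lora flag.
inference_unet = unet.base_model.model if is_lora else unet
run_inference(inference_unet)  # hypothetical entry point and signature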

Comment on lines +30 to +33
    action_embedding_dim: int,
    skip_image_conditioning: bool = False,
    device: torch.device | None = None
) -> tuple[UNet2DConditionModel, AutoencoderKL, torch.nn.Embedding, DDIMScheduler, CLIPTokenizer, CLIPTextModel]:
@arnaudstiegler (Owner):

Can you run ruff format on that file?

)
torch.nn.init.normal_(action_embedding.weight, mean=0.0, std=0.02)

# DDIM scheduler allows for v-prediction and less sampling steps
# Load models with device placement
@arnaudstiegler (Owner):

Comments should primarily provide additional context for code that's non trivial to understand. In this case, the comment is simply describing what is done below, which is already explicit.

)

if not skip_image_conditioning:
    # This is to accomodate concatenating previous frames in the channels dimension
    # Modify UNet input channels
@arnaudstiegler (Owner):

Same thing: the new comment describes what we're doing and not why.
The previous comment gave that additional context

Comment on lines +669 to +692
from peft import LoraConfig, get_peft_model
def get_lora_model(unet, rank=4, lora_alpha=4, target_modules=None):
    if target_modules is None:
        # Default target modules for SD UNet
        target_modules = [
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "conv1",
            "conv2",
            "conv_shortcut",
            "conv3",
            "conv4",
        ]

    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )
    return get_peft_model(unet, config)
@arnaudstiegler (Owner):

this should go to sd3/model.py
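
A sketch of the deduplicated version, assuming the helper is moved to sd3/model.py as suggested:

# Import the single shared helper instead of redefining it in the training script.
from sd3.model import get_lora_model

unet = get_lora_model(unet, rank=args.rank, lora_alpha=args.lora_alpha)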

Comment on lines +1088 to +1089
# Save LoRA weights
unet.save_pretrained(os.path.join(args.output_dir, "lora"))
@arnaudstiegler (Owner):

update the save function rather than saving it separately
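
One way to fold this into the existing checkpointing path rather than adding a separate call (a sketch; save_checkpoint stands in for whatever save function the repo already has):

import os

from peft import PeftModel


def save_checkpoint(unet, output_dir):
    # A PEFT-wrapped UNet saves only the adapter weights; a plain UNet saves as before.
    if isinstance(unet, PeftModel):
        unet.save_pretrained(os.path.join(output_dir, "lora"))
    else:
        unet.save_pretrained(os.path.join(output_dir, "unet"))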

    lora_alpha=args.lora_alpha,
)
# Only train LoRA parameters
params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters())
@arnaudstiegler (Owner):

Did you make sure that only the LORA parameters are marked with requires_grad? I'm not seeing this logic anywhere
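
For what it's worth, get_peft_model freezes the base model's weights when it wraps the UNet, so only the adapter layers should remain trainable; a quick sanity check could make that explicit (a sketch):

# Only LoRA adapter parameters should report requires_grad=True after wrapping.
trainable = [name for name, p in unet.named_parameters() if p.requires_grad]
assert all("lora" in name for name in trainable), trainable
unet.print_trainable_parameters()  # prints trainable vs. total parameter counts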

Comment on lines +708 to +726
if args.skip_action_conditioning:
    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
else:
    optimizer = optimizer_cls(
        [
            {"params": params_to_optimize},
            {"params": action_embedding.parameters()},
        ],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
@arnaudstiegler (Owner):

This is copied from line 624 -> 639. Update the existing code rather than duplicating code
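
One way to collapse the two branches into a single optimizer construction (a sketch, reusing the names from the diff and assuming params_to_optimize covers the UNet parameters in both cases):

# Build the parameter groups once, then create the optimizer a single time.
param_groups = [{"params": params_to_optimize}]
if not args.skip_action_conditioning:
    param_groups.append({"params": action_embedding.parameters()})

optimizer = optimizer_cls(
    param_groups,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)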

Comment on lines +137 to +143
parser.add_argument(
    "--target_modules",
    type=str,
    nargs="+",
    default=None,
    help="List of module names to apply LoRA to",
)
@arnaudstiegler (Owner):

Is this used anywhere? Also, this parameter is typically hardcoded
