[flux dreambooth lora training] make LoRA target modules configurable + small bug fix #9646
base: main
Changes from all commits: 51b0194, beb11ea, ad37cdf, 31d8576, ff5511c, f611e5f, faa95af, b17f9bf, 0ca6950, 8c95792, e912ff8, 29152db, 7276da7
```diff
@@ -554,6 +554,15 @@ def parse_args(input_args=None):
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )

+    parser.add_argument(
+        "--lora_layers",
+        type=str,
+        default=None,
+        help=(
+            "The transformer modules to apply LoRA training on. Please specify the layers in a comma-separated "
+            'string. E.g. - "to_k,to_q,to_v,to_out.0" will result in LoRA training of attention layers only.'
+        ),
+    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
```

> **Review comment on lines +557 to +564:** Could do this with
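The comma-separated flag above can be exercised in isolation. The following is a minimal sketch, assuming only the flag's name and default from the diff; the standalone parser and the sample argument string are illustrative, not part of the training script.

```python
import argparse

# Hypothetical standalone parser reproducing just the new --lora_layers flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora_layers",
    type=str,
    default=None,
    help='Comma-separated transformer modules to apply LoRA to, e.g. "to_k,to_q,to_v,to_out.0"',
)

# Stray whitespace around names is tolerated because each entry is stripped,
# matching the parsing logic the PR adds in main().
args = parser.parse_args(["--lora_layers", "to_k, to_q ,to_v,to_out.0"])
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
print(target_modules)  # ['to_k', 'to_q', 'to_v', 'to_out.0']
```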
```diff
@@ -1186,12 +1195,30 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.gradient_checkpointing_enable()

+    if args.lora_layers is not None:
+        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+    else:
+        target_modules = [
+            "attn.to_k",
+            "attn.to_q",
+            "attn.to_v",
+            "attn.to_out.0",
+            "attn.add_k_proj",
+            "attn.add_q_proj",
+            "attn.add_v_proj",
+            "attn.to_add_out",
+            "ff.net.0.proj",
+            "ff.net.2",
+            "ff_context.net.0.proj",
+            "ff_context.net.2",
+        ]
+
-    # now we will add new LoRA weights to the attention layers
+    # now we will add new LoRA weights to the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
         init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
     if args.train_text_encoder:
```

> **Review comment on lines +1201 to +1214:** Seems like a bit breaking, no? Better to not do it and instead make a note in the README? WDYT?

> **Reply:** Breaking or just changing default behavior? I think it's geared more towards the latter, but I think it's in line with the other trainers & makes sense for Transformer-based models, so maybe a

> **Reply:** Yeah, maybe a warning note at the beginning of the README should cut it. With this change, we're likely also increasing the total training wall-clock time in the default setting, so that is worth noting.
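The default-versus-override selection above can be factored into a pure function and tested without `peft` or a model. This is a sketch: `resolve_target_modules` is a hypothetical helper name, while the default module list is copied verbatim from the diff.

```python
# Default target modules for the Flux transformer, copied from the PR's diff:
# all attention projections (including the add_* context projections) plus
# both feed-forward blocks — the source of the "wall-clock time" concern above.
FLUX_DEFAULT_TARGETS = [
    "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
    "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
    "ff.net.0.proj", "ff.net.2",
    "ff_context.net.0.proj", "ff_context.net.2",
]

def resolve_target_modules(lora_layers):
    """Return user-specified modules if --lora_layers was given, else the new defaults."""
    if lora_layers is not None:
        return [layer.strip() for layer in lora_layers.split(",")]
    return FLUX_DEFAULT_TARGETS

print(resolve_target_modules(None)[:2])     # first two entries of the default set
print(resolve_target_modules("to_k,to_q"))  # ['to_k', 'to_q']
```

The resulting list would then be passed as `target_modules` to `peft.LoraConfig`, exactly as the diff does; `peft` matches each string as a suffix of module names inside the transformer.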
```diff
@@ -1367,10 +1394,9 @@ def load_model_hook(models, input_dir):
                 f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
                 f"When using prodigy only learning_rate is used as the initial learning rate."
             )
-        # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+        # changes the learning rate of text_encoder_parameters_one to be
         # --learning_rate
         params_to_optimize[1]["lr"] = args.learning_rate
-        params_to_optimize[2]["lr"] = args.learning_rate

     optimizer = optimizer_class(
         params_to_optimize,
```
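The small bug this last hunk fixes is easiest to see with the parameter groups mocked out. In this sketch the group layout is an assumption modeled on the script (Flux trains one text encoder, so `params_to_optimize` has two groups, not three); the placeholder strings stand in for actual LoRA parameters.

```python
# Prodigy ignores per-group lr and uses learning_rate as the initial value,
# so the script overrides the text-encoder group's lr before building the optimizer.
learning_rate = 1.0      # Prodigy's conventional initial learning rate
text_encoder_lr = 5e-6   # would be ignored by Prodigy; hence the override

# Assumed layout: group 0 = transformer LoRA params, group 1 = text_encoder_one params.
# There is no group 2 — the removed params_to_optimize[2]["lr"] line would raise IndexError.
params_to_optimize = [
    {"params": ["<transformer lora params>"], "lr": learning_rate},
    {"params": ["<text encoder lora params>"], "lr": text_encoder_lr},
]

# changes the learning rate of text_encoder_parameters_one to be --learning_rate
params_to_optimize[1]["lr"] = learning_rate
print(params_to_optimize[1]["lr"])  # 1.0
```

This mirrors why the `params_to_optimize[2]` line had to go: it was copied from the SDXL-style scripts that train two text encoders.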
> **Review comment:** Let's provide the author courtesy here.