-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flux dreambooth lora training] make LoRA target modules configurable + small bug fix #9646
base: main
Are you sure you want to change the base?
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Minor notes.
@@ -161,7 +161,7 @@ def log_validation( | |||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" | |||
f" {args.validation_prompt}." | |||
) | |||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's provide the author courtesy here.
parser.add_argument( | ||
"--lora_layers", | ||
type=str, | ||
default=None, | ||
help=( | ||
'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' | ||
), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could do this with nargs
. Better no?
target_modules = [ | ||
"attn.to_k", | ||
"attn.to_q", | ||
"attn.to_v", | ||
"attn.to_out.0", | ||
"attn.add_k_proj", | ||
"attn.add_q_proj", | ||
"attn.add_v_proj", | ||
"attn.to_add_out", | ||
"ff.net.0.proj", | ||
"ff.net.2", | ||
"ff_context.net.0.proj", | ||
"ff_context.net.2", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like a bit breaking no? Better to not do it and instead make a note from the README?
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Breaking or just changing default behavior? I think it's geared more towards the latter, but I think it's in line with the other trainers & makes sense for Transformer based models, so maybe a Warning
note and a guide on how to train it the old way for e.g.?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah maybe a warning note at the beginning of the README should cut it.
With this change, we're likely also increasing the total training wall-clock time in the default setting, so, that is worth noting.
new feature for the Flux dreambooth lora training script:
make LoRA target modules configurable through
--lora_blocks
change the current default target modules to not be attention layers only (?)
& small fix to mixed precision training for dreambooth script, as proposed in #9565