FIX #2295: Warn when user reloads modified model #2306
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
When modifying a model with `get_peft_model` that was already modified in the same way, even specifying a different config may not change the trainable parameter count, e.g. when specifying target modules that are only a subset of the previous target modules. With this patch a warning will be issued with a hint to `.unload()` when calling `get_peft_model` on an already modified model.
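To make the scenario concrete, here is a minimal sketch of the double-modification case the warning targets. The GPT-2 model name and the target modules are illustrative assumptions, not taken from this PR:

```python
# Minimal sketch of the scenario this PR warns about; model name and target
# modules are illustrative assumptions.
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# The first call injects LoRA layers into the base model in place.
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn", "c_proj"]))

# A second call on the same, already modified base model with a config that
# targets only a subset of the previous modules may leave the trainable
# parameter count unchanged. With this patch, a UserWarning is emitted that
# suggests calling `.unload()` on the existing PEFT model first.
peft_model_again = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))
```

Calling `peft_model.unload()` before the second `get_peft_model` call removes the injected adapter layers and restores the original base model, which is the workflow the new warning hints at.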
Force-pushed from d8231de to f32e517 (compare).
Thanks for adding this warning. I have a couple of small comments, but nothing major. Please check.
src/peft/mapping.py
Outdated
@@ -181,6 +182,21 @@ def get_peft_model(
    new_name = model.__dict__.get("name_or_path", None)
    peft_config.base_model_name_or_path = new_name

    # Especially in notebook environments there could be a case that a user
We have a 120-char line limit in the project, could you please configure your formatter accordingly?
src/peft/mapping.py
Outdated
    # is likely that there won't be any changes for new configs on an already
    # initialized PEFT model. The best we can do is warn the user about it.
    try:
        if len(get_layer_status(model)) > 0:
Nice idea to use `get_layer_status`. I wonder if, in this case, a simple `if any(isinstance(module, BaseTunerLayer) for module in model.modules())` would not serve the purpose better. This check would stop once the first PEFT layer is found, while `get_layer_status` would do a bunch of unnecessary extra work.
For some reason I thought that p-tuning/prompt-tuning are also layer tuners (which they aren't), so I thought it was worthwhile to use the more complex `get_layer_status`. But you're correct, a simple check suffices.
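For illustration, a rough sketch of how the simpler check could sit inside `get_peft_model`. The helper name and warning text are hypothetical, not the merged implementation:

```python
# Hypothetical sketch of the simpler check discussed above; the helper name and
# warning message are illustrative, not the merged code.
import warnings

from peft.tuners.tuners_utils import BaseTunerLayer


def _warn_if_model_already_tuned(model) -> None:
    # Short-circuits at the first injected PEFT layer, instead of collecting
    # full per-layer status information the way get_layer_status does.
    if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
        warnings.warn(
            "The model already contains PEFT layers; calling get_peft_model on it again "
            "may not change the trainable parameters. Call `.unload()` on the existing "
            "PEFT model first if you want to start from the base model.",
            UserWarning,
        )
```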
tests/test_mapping.py
Outdated
    get_peft_model(base_model, lora_config)

    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(base_model, lora_config)
How about also adding a check where the user calls `get_peft_model` on the `PeftModel` instance itself?
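A sketch of what such an additional test could look like; the `base_model` and `lora_config` names and the `RELOAD_WARNING_EXPECTED_MATCH` attribute are assumed to mirror the existing test setup in tests/test_mapping.py:

```python
# Illustrative sketch of the suggested extra test; fixture names and the
# warning-match attribute are assumed to match the existing test class.
def test_get_peft_model_warns_when_called_on_peft_model(self, base_model, lora_config):
    peft_model = get_peft_model(base_model, lora_config)

    # Passing the PeftModel instance itself should also trigger the warning,
    # since its modules already contain injected tuner layers.
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(peft_model, lora_config)
```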
    for warning in recwarn:
        if warning_checker.matches(warning):
            pytest.fail("Warning raised even though model was unloaded.")
Maybe a bit of an edge case, but currently we would not detect if a user tries to create, say, a LoRA model based on a base model that was already modified with a prompt learning method, like prefix tuning. This is because those methods don't add any `BaseTunerLayer`s (which is what `get_layer_status` relies on). Implementing such a check is probably not easy and out of scope for this PR. We could still add an xfail-ing test, though.
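One possible shape for such an xfail-ing test; the config values and fixture names are assumptions, and the point is only to document the currently undetected case:

```python
# Illustrative xfail test for the prompt-learning edge case discussed above;
# fixture names and config values are assumptions.
@pytest.mark.xfail(
    reason="Prompt learning methods do not add BaseTunerLayer modules, so the reload warning is not raised.",
    strict=True,
)
def test_get_peft_model_warns_after_prompt_learning(self, base_model):
    prefix_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
    get_peft_model(base_model, prefix_config)

    # Prefix tuning leaves the base model's modules untouched, so no
    # BaseTunerLayer is found and no warning is emitted (hence the xfail).
    lora_config = LoraConfig(task_type="CAUSAL_LM")
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(base_model, lora_config)
```

With `strict=True`, the test would start failing as an unexpected pass if the warning is ever extended to cover prompt learning methods, which is a cheap reminder to update the test and implementation together.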
But that case is also not error-prone, or is it?
Honestly, I'm not quite sure if it would just work or not. Coincidentally, there is a recent issue asking a similar question (#2307).
For our sanity, I'd suggest going forward with the current state of things, and once we know more about the interplay between soft-prompting vs. soft-prompting and soft-prompting vs. LoRA, we can adapt the tests and/or implementation. WDYT?
@BenjaminBossan please review again
Thanks for the changes, LGTM. You need to run `make style` for the linter; after that, if CI is green (or as green as it gets these days), feel free to merge.
This is a fix for #2295.