
FIX #2295: Warn when user reloads modified model #2306

Merged: 6 commits from feature/warn-on-reload into huggingface:main on Jan 7, 2025

Conversation

@githubnemo (Collaborator) commented on Jan 6, 2025

This is a fix for #2295.

When calling `get_peft_model` on a model that was already modified in the same way, even a different config may not change the trainable parameter count, e.g. when the new target modules are only a subset of the previous target modules.

With this patch, a warning with a hint to call `.unload()` is issued when `get_peft_model` is called on an already modified model.
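
For illustration, a minimal sketch of the situation the warning targets; the model name, target modules, and the warning handling are assumptions for the example, not code from this PR:

import warnings

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Illustrative base model and target modules (assumptions, not taken from the PR).
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# First call: injects LoRA layers into the base model as usual.
get_peft_model(base_model, LoraConfig(target_modules=["q_proj", "v_proj"]))

# Second call on the same, already modified base model: the new config targets only a
# subset of the previous modules, so the trainable parameter count does not change.
# With this patch, a UserWarning hinting at .unload() is emitted here.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    get_peft_model(base_model, LoraConfig(target_modules=["q_proj"]))
print([str(w.message) for w in caught])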

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@githubnemo force-pushed the feature/warn-on-reload branch from d8231de to f32e517 on Jan 6, 2025, 16:45

@BenjaminBossan (Member) left a comment

Thanks for adding this warning. I have a couple of small comments, but nothing major. Please check.

@@ -181,6 +182,21 @@ def get_peft_model(
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name

# Especially in notebook environments there could be a case that a user

@BenjaminBossan (Member):

We have a 120 char limit in the project; could you please configure your formatter accordingly?

# is likely that there won't be any changes for new configs on an already
# initialized PEFT model. The best we can do is warn the user about it.
try:
    if len(get_layer_status(model)) > 0:

@BenjaminBossan (Member):

Nice idea to use get_layer_status. I wonder if in this case, a simple

if any(isinstance(module, BaseTunerLayer) for module in model.modules())

would not serve the purpose better. This check would stop as soon as the first PEFT layer is found, while get_layer_status would do a bunch of unnecessary extra work.

@githubnemo (Collaborator, Author):

For some reason I thought that p-tuning/prompt-tuning are also layer tuners (which they aren't), so I thought it was worthwhile to use the more complex get_layer_status. But you're correct, a simple check suffices.
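
A minimal sketch of the simpler check in context; the helper name and warning text are assumptions, not the exact code merged in this PR:

import warnings

from peft.tuners.tuners_utils import BaseTunerLayer

def _warn_if_model_already_has_peft_layers(model):
    # Stops as soon as the first PEFT layer is found, instead of building the full
    # report that get_layer_status would produce.
    if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
        warnings.warn(
            "The model already contains PEFT layers; calling get_peft_model on it again may "
            "not change the trainable parameters. Call .unload() first to restore the base model.",
            UserWarning,
        )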

tests/test_mapping.py: 2 resolved review threads

get_peft_model(base_model, lora_config)

with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
    get_peft_model(base_model, lora_config)

@BenjaminBossan (Member):

How about also adding a check where the user calls get_peft_model on the PeftModel instance itself?
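
For instance, a hypothetical test along these lines (the fixtures and the match attribute mirror the snippet above; the test name is an assumption):

def test_warns_when_called_on_peft_model_instance(self, base_model, lora_config):
    peft_model = get_peft_model(base_model, lora_config)

    # Passing the PeftModel itself (not the base model) should also trigger the warning.
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(peft_model, lora_config)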


for warning in recwarn:
    if warning_checker.matches(warning):
        pytest.fail("Warning raised even though model was unloaded.")

@BenjaminBossan (Member):

Maybe a bit of an edge case, but currently we would not detect if a user tries to create, say, a LoRA model based on a base model that was already modified with a prompt learning method, like prefix tuning. This is because those methods don't add any BaseTunerLayers (which is what get_layer_status relies on). Implementing such a check is probably not easy and out of scope for this PR. We could still add an xfail-ing test though.

@githubnemo (Collaborator, Author):

But that case is also not error-prone, or is it?

@BenjaminBossan (Member):

Honestly, I'm not quite sure if it would just work or not. Coincidentally, there is a recent issue asking a similar question (#2307).

@githubnemo (Collaborator, Author):

For our sanity, I'd suggest going forward with the current state of things; once we know more about the interplay of soft-prompting vs. soft-prompting and soft-prompting vs. LoRA, we can adapt the tests and/or implementation. WDYT?
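
For reference, a rough sketch of the xfail-ing test suggested above; the fixture, config values, and match pattern are assumptions:

@pytest.mark.xfail(reason="Prompt learning methods do not add BaseTunerLayer modules, so the check cannot detect them.")
def test_warns_when_base_model_already_has_prompt_learning(self, base_model):
    prefix_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
    get_peft_model(base_model, prefix_config)

    # No BaseTunerLayer was injected into base_model by prefix tuning, so currently
    # no warning is raised when it is modified again with LoRA.
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(base_model, LoraConfig())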

@githubnemo (Collaborator, Author) commented

@BenjaminBossan please review again

@BenjaminBossan (Member) left a comment

Thanks for the changes, LGTM.

You need to run `make style` for the linter; after that, if CI is green (or as green as it gets these days), feel free to merge.

@githubnemo githubnemo merged commit 3d2bf9a into huggingface:main Jan 7, 2025
2 checks passed