diff --git a/src/peft/mapping.py b/src/peft/mapping.py
index 732aeecf3d..7d65405b1e 100644
--- a/src/peft/mapping.py
+++ b/src/peft/mapping.py
@@ -31,6 +31,7 @@
     PeftModelForSeq2SeqLM,
     PeftModelForSequenceClassification,
     PeftModelForTokenClassification,
+    get_layer_status,
 )
 from .tuners import (
     AdaLoraConfig,
@@ -181,6 +182,21 @@ def get_peft_model(
         new_name = model.__dict__.get("name_or_path", None)
         peft_config.base_model_name_or_path = new_name
 
+    # Especially in notebook environments there could be a case that a user
+    # wants to experiment with different configuration values. However, it
+    # is likely that there won't be any changes for new configs on an already
+    # initialized PEFT model. The best we can do is warn the user about it.
+    try:
+        if len(get_layer_status(model)) > 0:
+            warnings.warn(
+                "You are trying to modify a model with PEFT for a "
+                "second time. If you want to reload the model with a "
+                "different config, make sure to call `.unload()` before."
+            )
+    except ValueError:
+        # not a PEFT model or no adapters in use
+        pass
+
     if (old_name is not None) and (old_name != new_name):
         warnings.warn(
             f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
diff --git a/tests/test_mapping.py b/tests/test_mapping.py
new file mode 100644
index 0000000000..a29f3f4b0e
--- /dev/null
+++ b/tests/test_mapping.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+
+
+class TestGetPeftModel:
+    RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"
+
+    @pytest.fixture
+    def get_peft_model(self):
+        from peft import get_peft_model
+
+        return get_peft_model
+
+    @pytest.fixture
+    def lora_config(self):
+        from peft import LoraConfig
+
+        return LoraConfig(target_modules="0")
+
+    @pytest.fixture
+    def base_model(self):
+        return torch.nn.Sequential(torch.nn.Linear(10, 2))
+
+    def test_get_peft_model_warns_when_reloading_model(self, get_peft_model, lora_config, base_model):
+        get_peft_model(base_model, lora_config)
+
+        with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
+            get_peft_model(base_model, lora_config)
+
+    def test_get_peft_model_proposed_fix_in_warning_help(self, get_peft_model, lora_config, base_model, recwarn):
+        peft_model = get_peft_model(base_model, lora_config)
+        peft_model.unload()
+        get_peft_model(base_model, lora_config)
+
+        warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)
+
+        for warning in recwarn:
+            if warning_checker.matches(warning):
+                pytest.fail("Warning raised even though model was unloaded.")