In peft, only the trainable parameters need to be saved (#576)
sywangyi authored Dec 19, 2023
1 parent eea8062 commit 21238af
Showing 1 changed file with 20 additions and 4 deletions.
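Background for the diff below (illustrative, not part of the commit): with PEFT/LoRA only the small adapter matrices are trainable while the base model stays frozen, so a checkpoint that excludes the frozen parameters is far smaller. A minimal sketch, assuming peft and transformers are installed; the base model and LoRA settings are arbitrary placeholders:

    # Illustrative only: show how few parameters are trainable in a LoRA-wrapped model.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
    model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} of {total:,} parameters")  # only the adapters need saving
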
24 changes: 20 additions & 4 deletions optimum/habana/transformers/trainer.py
@@ -15,6 +15,7 @@

 import contextlib
 import copy
+import inspect
 import math
 import os
 import random
@@ -119,6 +120,10 @@
 import optuna


+def _is_peft_model(model):
+    return is_peft_available() and isinstance(model, PeftModel)
+
+
 logger = logging.get_logger(__name__)


@@ -1016,7 +1021,7 @@ def _load_best_model(self):
 or os.path.exists(best_safe_adapter_model_path)
 ):
 has_been_loaded = True
-if is_peft_available() and isinstance(model, PeftModel):
+if _is_peft_model(model):
 # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
 if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
 if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
@@ -1155,14 +1160,19 @@ def _save_checkpoint(self, model, trial, metrics=None):

 if self.hp_search_backend is None and trial is None:
 self.store_flos()

 run_dir = self._get_output_dir(trial=trial)
 output_dir = os.path.join(run_dir, checkpoint_folder)
 self.save_model(output_dir, _internal_call=True)
 if self.is_deepspeed_enabled:
 # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
 # config `stage3_gather_16bit_weights_on_model_save` is True
-self.model_wrapped.save_checkpoint(output_dir)
+accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+    inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+)
+if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+    self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+else:
+    self.model_wrapped.save_checkpoint(output_dir)

 # Save optimizer and scheduler
 if self.args.should_save and not self.is_deepspeed_enabled:
@@ -1396,7 +1406,13 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
 self._save(output_dir, state_dict={})
 # remove the dummy state_dict
 remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
-self.model_wrapped.save_checkpoint(output_dir)
+accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+    inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+)
+if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+    self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+else:
+    self.model_wrapped.save_checkpoint(output_dir)
 elif self.args.should_save:
 self._save(output_dir)

