Commit

fix(modeling_base): re-order `model.forward_kwargs` initialization (#566)

* fix(modeling_base): re-order `model.forward_kwargs` initialization

* fix(modeling_base): revert abstract `post_init` deletion
maxreciprocate authored Oct 17, 2023
1 parent bcd237f commit d03fea7
Showing 1 changed file with 15 additions and 18 deletions.
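
In short: previously `__init__` cached the `forward` argument names immediately (and left `forward_kwargs` as `None` for peft models, to be filled in later by `post_init`); this commit moves that caching into `from_pretrained`, after `post_init` has run, so both the peft and non-peft branches end up with a concrete list of accepted argument names. Below is a minimal standalone sketch of the cache-then-filter pattern the wrapper relies on; `ToyModel` and the sample kwargs are illustrative and not part of trlx.

import inspect

# Cache the names of a model's `forward` parameters once, then drop any kwargs
# the concrete architecture does not accept before calling it.
class ToyModel:
    def forward(self, input_ids, attention_mask=None):
        return {"input_ids": input_ids, "attention_mask": attention_mask}

model = ToyModel()

# `getfullargspec(...).args` lists the declared parameter names,
# e.g. ['self', 'input_ids', 'attention_mask'] for ToyModel.forward.
forward_kwargs = inspect.getfullargspec(model.forward).args

call_kwargs = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "position_ids": [0, 1, 2]}
# Mirrors `get_compatible_forward_kwargs`: `position_ids` is dropped because
# ToyModel.forward does not declare it.
compatible = {k: v for k, v in call_kwargs.items() if k in forward_kwargs}
print(compatible)  # {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1]}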
33 changes: 15 additions & 18 deletions trlx/models/modeling_base.py
@@ -69,13 +69,6 @@ class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):
     def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, peft_config=None, **kwargs):
         super().__init__()
         self.base_model = base_model
-        # cache `forward` args for general use (avoids incompatible args across architectures)
-        if peft_config:
-            # keep all kwargs for peft
-            self.forward_kwargs = None
-        else:
-            self.forward_kwargs = inspect.getfullargspec(base_model.forward).args
-
         self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
         if self.is_loaded_in_8bit:
             # TODO(glerzing): Fully test and support loading in 8-bit
@@ -318,6 +311,16 @@ def from_pretrained( # noqa: max-complexity
             state_dict = pretrained_model_name_or_path.state_dict()
 
         model.post_init(state_dict=state_dict)
+
+        # cache `forward` args for general use (avoids incompatible args across architectures)
+        if peft_config:
+            # Don't use the interface of the peft model,
+            # use the interface of the underlying transformer model instead.
+            # (peft adds 2 "base_model" layers)
+            model.forward_kwargs = inspect.getfullargspec(model.base_model.base_model.base_model.forward).args
+        else:
+            model.forward_kwargs = inspect.getfullargspec(model.base_model.forward).args
+
         return model
 
     def save_pretrained(self, *args, **kwargs):
@@ -349,27 +352,21 @@ def save_pretrained(self, *args, **kwargs):
 
         return self.base_model.save_pretrained(*args, **kwargs)
 
-    def state_dict(self, *args, **kwargs):
-        """Return the state_dict of the pretrained model."""
-        raise NotImplementedError
-
     def post_init(self, *args, **kwargs):
         """Post initialization method. This method is called after the model is
         instantiated and loaded from a checkpoint. It can be used to perform
         additional operations such as loading the state_dict.
         """
-        if self.peft_type:
-            # Don't use the interface of the peft model,
-            # use the interface of the underlying transformer model instead.
-            # (peft adds 2 "base_model" layers)
-            self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.forward).args
+        pass
+
+    def state_dict(self, *args, **kwargs):
+        """Return the state_dict of the pretrained model."""
+        raise NotImplementedError
 
     def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
         """Filter out arguments not supported by the specific instance of
         `base_model.transformer.forward`
         """
         # FIXME: This is a hack to get around the fact that the `transformers`
         # architectures we use don't have a consistent API for `forward` parameters.
-        if self.forward_kwargs is None:
-            return kwargs
         return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}
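
For the peft branch, the new code in `from_pretrained` inspects `model.base_model.base_model.base_model.forward`: the wrapper's own `base_model` attribute holds the peft model and, per the comment in the diff, peft adds two more "base_model" layers before the underlying transformer is reached. A standalone toy illustration of that attribute chain (these classes are stand-ins, not real peft or trlx types):

import inspect

class ToyTransformer:
    def forward(self, input_ids, attention_mask=None):
        return input_ids

class ToyPeftInner:  # stands in for peft's inner wrapper layer
    def __init__(self, model):
        self.base_model = model

class ToyPeftModel:  # stands in for the peft model itself
    def __init__(self, model):
        self.base_model = ToyPeftInner(model)

class ToyWrapper:  # stands in for PreTrainedModelWrapper
    def __init__(self, peft_model):
        self.base_model = peft_model

wrapper = ToyWrapper(ToyPeftModel(ToyTransformer()))

# wrapper.base_model                        -> peft model
# wrapper.base_model.base_model             -> peft's inner wrapper
# wrapper.base_model.base_model.base_model  -> the underlying transformer
inner_forward = wrapper.base_model.base_model.base_model.forward
print(inspect.getfullargspec(inner_forward).args)  # ['self', 'input_ids', 'attention_mask']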
