fix(modeling_base): re-order model.forward_kwargs initialization #566

Merged · 2 commits · Oct 17, 2023
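The change moves the caching of the wrapped model's `forward` argument names out of `__init__` and into `from_pretrained`, so the cache is built after the final (optionally peft-wrapped) model exists. A minimal sketch of the inspection idea, assuming a `transformers` causal-LM checkpoint (the model name below is only illustrative, not part of the PR):

```python
import inspect

from transformers import AutoModelForCausalLM

# Load any causal-LM checkpoint; "gpt2" is just an example.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Cache the argument names that this architecture's forward() accepts.
forward_kwargs = inspect.getfullargspec(model.forward).args
# For GPT-2 this includes e.g. 'input_ids', 'attention_mask', 'position_ids';
# other architectures expose a different set, so kwargs that are valid for one
# model can be rejected by another.

# Later, unsupported kwargs can be dropped before calling forward():
candidate = {"input_ids": [0], "attention_mask": [1], "not_a_real_kwarg": 0}
compatible = {k: v for k, v in candidate.items() if k in forward_kwargs}
```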
33 changes: 15 additions & 18 deletions trlx/models/modeling_base.py
@@ -69,13 +69,6 @@ class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):
def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, peft_config=None, **kwargs):
super().__init__()
self.base_model = base_model
# cache `forward` args for general use (avoids incompatible args across architectures)
if peft_config:
# keep all kwargs for peft
self.forward_kwargs = None
else:
self.forward_kwargs = inspect.getfullargspec(base_model.forward).args

self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
if self.is_loaded_in_8bit:
# TODO(glerzing): Fully test and support loading in 8-bit
@@ -318,6 +311,16 @@ def from_pretrained( # noqa: max-complexity
state_dict = pretrained_model_name_or_path.state_dict()

model.post_init(state_dict=state_dict)

# cache `forward` args for general use (avoids incompatible args across architectures)
if peft_config:
# Don't use the interface of the peft model,
# use the interface of the underlying transformer model instead.
# (peft adds 2 "base_model" layers)
model.forward_kwargs = inspect.getfullargspec(model.base_model.base_model.base_model.forward).args
else:
model.forward_kwargs = inspect.getfullargspec(model.base_model.forward).args

return model

def save_pretrained(self, *args, **kwargs):
@@ -349,27 +352,21 @@ def save_pretrained(self, *args, **kwargs):

return self.base_model.save_pretrained(*args, **kwargs)

def state_dict(self, *args, **kwargs):
"""Return the state_dict of the pretrained model."""
raise NotImplementedError

def post_init(self, *args, **kwargs):
"""Post initialization method. This method is called after the model is
instantiated and loaded from a checkpoint. It can be used to perform
additional operations such as loading the state_dict.
"""
if self.peft_type:
# Don't use the interface of the peft model,
# use the interface of the underlying transformer model instead.
# (peft adds 2 "base_model" layers)
self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.forward).args
pass

def state_dict(self, *args, **kwargs):
"""Return the state_dict of the pretrained model."""
raise NotImplementedError

def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
"""Filter out arguments not supported by the specific instance of
`base_model.transformer.forward`
"""
# FIXME: This is a hack to get around the fact that the `transformers`
# architectures we use don't have a consistent API for `forward` parameters.
if self.forward_kwargs is None:
return kwargs
return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}
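For reference, a self-contained sketch of the cache-then-filter pattern that the diff re-orders; the toy classes below are invented for illustration, and only the caching and filtering logic mirrors `forward_kwargs` and `get_compatible_forward_kwargs`:

```python
import inspect
from typing import Any, Dict


class ToyModel:
    """Stand-in for a transformers model whose forward() accepts a limited set of kwargs."""

    def forward(self, input_ids, attention_mask=None):
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ToyWrapper:
    """Mirrors the cache-then-filter pattern, not the real PreTrainedModelWrapper."""

    def __init__(self, base_model):
        self.base_model = base_model
        # Cache the forward() argument names once, as the PR now does in from_pretrained().
        self.forward_kwargs = inspect.getfullargspec(base_model.forward).args

    def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
        # Keep only the kwargs the underlying architecture actually accepts.
        return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}


wrapper = ToyWrapper(ToyModel())
filtered = wrapper.get_compatible_forward_kwargs(
    input_ids=[1, 2, 3],
    attention_mask=[1, 1, 1],
    position_ids=[0, 1, 2],  # not in ToyModel.forward, so it is dropped
)
assert "position_ids" not in filtered
outputs = wrapper.base_model.forward(**filtered)
```

In the real wrapper, the caching now happens at the end of `from_pretrained`, after `post_init` and after any peft wrapping, rather than in `__init__`.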