Add llama model change
yeonsily committed Jul 14, 2023
1 parent a42056d commit 50be41d
Showing 6 changed files with 400 additions and 1 deletion.
examples/text-generation/checkpoint_utils.py (1 addition, 1 deletion)
@@ -83,7 +83,7 @@ def model_is_bloom(config):


def get_optimized_model_name(config):
-    model_names = ["bloom", "gpt2", "opt", "gptj", "gpt_neox"]
+    model_names = ["bloom", "gpt2", "opt", "gptj", "gpt_neox", "llama"]
    for model_name in model_names:
        if model_name == config.model_type:
            return model_name
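For context, a minimal sketch (not part of this commit) of what the helper above now resolves for a llama checkpoint; it assumes get_optimized_model_name is in scope (it is defined in examples/text-generation/checkpoint_utils.py) and a transformers release that ships the llama config:

from transformers import AutoConfig

# Build a bare llama config; its model_type attribute is "llama".
config = AutoConfig.for_model("llama")

# With "llama" added to model_names, the helper now resolves this config to its optimized path.
assert get_optimized_model_name(config) == "llama"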
optimum/habana/transformers/generation/utils.py (29 additions)
@@ -170,6 +170,35 @@ def _update_model_kwargs_for_generation(

        return model_kwargs

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """
        Copied from Transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        Remove `token_type_ids` from model_kwargs, which is not used by the llama model
        """
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)
        if self.config.model_type == "llama":
            for key in ["token_type_ids"]:
                model_kwargs.pop(key, None)

        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )

    @torch.no_grad()
    def generate(
        self,
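Illustrative only (not from the commit): a standalone distillation of the llama-specific filtering added above. The helper name is hypothetical; in the real code this logic lives inside GaudiGenerationMixin._validate_model_kwargs.

from typing import Any, Dict

def drop_llama_unused_kwargs(model_type: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # llama's prepare_inputs_for_generation does not consume token_type_ids, so they are
    # dropped here instead of tripping the unused-model-kwargs ValueError during generate().
    if model_type == "llama":
        model_kwargs.pop("token_type_ids", None)
    return model_kwargs

kwargs = {"attention_mask": [[1, 1, 1]], "token_type_ids": [[0, 0, 0]]}
print(drop_llama_unused_kwargs("llama", kwargs))  # {'attention_mask': [[1, 1, 1]]}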
optimum/habana/transformers/modeling_utils.py (12 additions)
@@ -17,6 +17,7 @@
import transformers.models.gpt2.modeling_gpt2 as modeling_gpt2
import transformers.models.gpt_neox.modeling_gpt_neox as modeling_gpt_neox
import transformers.models.gptj.modeling_gptj as modeling_gptj
import transformers.models.llama.modeling_llama as modeling_llama
import transformers.models.opt.modeling_opt as modeling_opt
import transformers.models.t5.modeling_t5 as modeling_t5
from transformers import pytorch_utils
@@ -42,6 +43,7 @@
    GaudiGPT2LMHeadModel,
    GaudiGPTJForCausalLM,
    GaudiGPTNeoXForCausalLM,
    GaudiLlamaForCausalLM,
    GaudiOPTForCausalLM,
    GaudiOPTLearnedPositionalEmbedding,
    GaudiT5DenseActDense,
@@ -73,6 +75,9 @@
    gaudi_gptj_block_forward,
    gaudi_gptj_model_forward,
    gaudi_invert_attention_mask,
    gaudi_llama_attention_forward,
    gaudi_llama_decoder_layer_forward,
    gaudi_llama_model_forward,
    gaudi_opt_attention_forward,
    gaudi_opt_decoder_forward,
    gaudi_opt_decoder_layer_forward,
@@ -107,6 +112,7 @@ def adapt_transformers_to_gaudi():
    GenerationMixin.generate = GaudiGenerationMixin.generate
    GenerationMixin._update_model_kwargs_for_generation = GaudiGenerationMixin._update_model_kwargs_for_generation
    GenerationMixin._expand_inputs_for_generation = staticmethod(GaudiGenerationMixin._expand_inputs_for_generation)
    GenerationMixin._validate_model_kwargs = GaudiGenerationMixin._validate_model_kwargs
    GenerationMixin.greedy_search = GaudiGenerationMixin.greedy_search
    GenerationMixin.sample = GaudiGenerationMixin.sample
    GenerationMixin.beam_search = GaudiGenerationMixin.beam_search
@@ -167,6 +173,12 @@ def adapt_transformers_to_gaudi():
    modeling_gpt_neox.GPTNeoXLayer.forward = gaudi_gpt_neox_layer_forward
    modeling_gpt_neox.GPTNeoXAttention.forward = gaudi_gpt_neox_attention_forward

    # Optimization for llama generation on Gaudi
    modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM
    modeling_llama.LlamaModel.forward = gaudi_llama_model_forward
    modeling_llama.LlamaDecoderLayer.forward = gaudi_llama_decoder_layer_forward
    modeling_llama.LlamaAttention.forward = gaudi_llama_attention_forward

    # Dropout kernel improvement for Flan-T5
    modeling_t5.T5Stack = GaudiT5Stack
    modeling_t5.T5DenseGatedActDense = GaudiT5DenseGatedActDense
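A usage sketch of the patching path above (not part of this commit; it assumes a working Habana/Gaudi software stack and uses a placeholder checkpoint name):

from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Monkey-patch transformers in place so llama generation uses the Gaudi-optimized
# attention, decoder-layer, and model forwards registered above.
adapt_transformers_to_gaudi()

model_id = "path/to/llama-checkpoint"  # placeholder, not a real model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("hpu")  # "hpu" requires habana_frameworks.torch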
optimum/habana/transformers/models/__init__.py (6 additions)
@@ -32,6 +32,12 @@
    gaudi_gptj_model_forward,
)
from .modeling_all_models import gaudi_conv1d_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask
from .llama import (
    GaudiLlamaForCausalLM,
    gaudi_llama_attention_forward,
    gaudi_llama_decoder_layer_forward,
    gaudi_llama_model_forward,
)
from .opt import (
    GaudiOPTForCausalLM,
    GaudiOPTLearnedPositionalEmbedding,
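A hedged note (not from the commit): because models/__init__.py re-exports the new symbols, they become importable from the package namespace, e.g.:

from optimum.habana.transformers.models import (
    GaudiLlamaForCausalLM,
    gaudi_llama_model_forward,
)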
optimum/habana/transformers/models/llama/__init__.py (new file, 6 additions)
@@ -0,0 +1,6 @@
from .modeling_llama import (
    GaudiLlamaForCausalLM,
    gaudi_llama_attention_forward,
    gaudi_llama_decoder_layer_forward,
    gaudi_llama_model_forward,
)
optimum/habana/transformers/models/llama/modeling_llama.py (diff not rendered on this page)
