From 4e3490f79b40248c53ee54365a9662611e880892 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Tue, 9 Apr 2024 12:01:47 +0530
Subject: [PATCH] Fix failing DeepSpeed model zoo tests (#30112)

* fix sequence length errors

* fix label column name error for vit

* fix the lm_head embedding!=linear layer mismatches for Seq2Seq models
---
 src/transformers/modeling_utils.py | 5 ++++-
 tests/deepspeed/test_model_zoo.py  | 3 +++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index fd0afa521a14..9f223338391c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1932,7 +1932,10 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
-            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
+            if isinstance(old_lm_head, torch.nn.Embedding):
+                new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens)
+            else:
+                new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
diff --git a/tests/deepspeed/test_model_zoo.py b/tests/deepspeed/test_model_zoo.py
index 08c8b86dc07e..ea002f5ddf29 100644
--- a/tests/deepspeed/test_model_zoo.py
+++ b/tests/deepspeed/test_model_zoo.py
@@ -236,6 +236,8 @@ def make_task_cmds():
             --train_file {data_dir_wmt}/train.json
             --source_lang en
             --target_lang ro
+            --max_source_length 12
+            --max_target_length 12
         """,
         "sum": f"""
             {scripts_dir}/summarization/run_summarization.py
@@ -269,6 +271,7 @@ def make_task_cmds():
             --remove_unused_columns False
             --max_steps 10
             --image_processor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
+            --label_column_name labels
         """,
     }
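
For context on the modeling_utils.py hunk: some Seq2Seq models expose their untied
output embeddings (lm_head) as an nn.Embedding rather than an nn.Linear, so resizing
must dispatch on the module type. The following is a minimal, self-contained sketch of
that dispatch, not the transformers implementation; resize_untied_lm_head is a
hypothetical helper introduced here only for illustration.

    # Sketch only: illustrates the type dispatch the patch adds to
    # _resize_token_embeddings. resize_untied_lm_head is a hypothetical
    # name, not part of the transformers API.
    import torch
    import torch.nn as nn


    def resize_untied_lm_head(old_lm_head: nn.Module, new_num_tokens: int) -> nn.Module:
        """Return a resized copy of an untied lm_head, preserving existing rows."""
        if isinstance(old_lm_head, nn.Embedding):
            # Embedding-style head: the vocab dimension is num_embeddings (dim 0).
            new_head = nn.Embedding(new_num_tokens, old_lm_head.embedding_dim)
            num_to_copy = min(old_lm_head.num_embeddings, new_num_tokens)
            with torch.no_grad():
                new_head.weight[:num_to_copy] = old_lm_head.weight[:num_to_copy]
        else:
            # Linear-style head: the vocab dimension is out_features, so both the
            # weight rows and (if present) the bias are copied along dim 0.
            new_head = nn.Linear(
                old_lm_head.in_features, new_num_tokens, bias=old_lm_head.bias is not None
            )
            num_to_copy = min(old_lm_head.out_features, new_num_tokens)
            with torch.no_grad():
                new_head.weight[:num_to_copy] = old_lm_head.weight[:num_to_copy]
                if old_lm_head.bias is not None:
                    new_head.bias[:num_to_copy] = old_lm_head.bias[:num_to_copy]
        return new_head

Before this patch, an nn.Embedding head would have been passed to _get_resized_lm_head,
which expects a linear layer, producing the embedding/linear mismatch the third commit
bullet describes.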