Misc changes for transformers tests (#581)
ankurneog authored Dec 5, 2023
1 parent 5dc7577 commit b9f1485
Showing 2 changed files with 10 additions and 1 deletion.
tests/transformers/tests/generation/test_utils.py (9 additions, 1 deletion)
@@ -332,6 +332,7 @@ def _sample_generate(
     ):
         torch.manual_seed(0)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+        self._update_default_model_kwargs(model_kwargs)
         output_generate = model.generate(
             input_ids,
             do_sample=True,
@@ -368,6 +369,7 @@ def _sample_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
             output_sample = model.sample(
                 input_ids.repeat_interleave(num_return_sequences, dim=0),
                 max_length=max_length,
@@ -399,6 +401,7 @@ def _beam_search_generate(
         return_dict_in_generate=False,
     ):
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+        self._update_default_model_kwargs(model_kwargs)
         output_generate = model.generate(
             input_ids,
             do_sample=False,
@@ -462,6 +465,7 @@ def _beam_sample_generate(
     ):
         torch.manual_seed(0)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+        self._update_default_model_kwargs(model_kwargs)
         output_generate = model.generate(
             input_ids,
             do_sample=True,
@@ -497,6 +501,7 @@ def _beam_sample_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
             output_beam_sample = model.beam_sample(
                 input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
@@ -529,6 +534,7 @@ def _group_beam_search_generate(
         return_dict_in_generate=False,
     ):
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+        self._update_default_model_kwargs(model_kwargs)
         output_generate = model.generate(
             input_ids,
             do_sample=False,
@@ -561,7 +567,6 @@ def _group_beam_search_generate(
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             self._update_default_model_kwargs(model_kwargs)
-
             output_group_beam_search = model.group_beam_search(
                 input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
@@ -593,6 +598,7 @@ def _constrained_beam_search_generate(
         return_dict_in_generate=False,
     ):
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+        self._update_default_model_kwargs(model_kwargs)
         output_generate = model.generate(
             input_ids,
             do_sample=False,
@@ -668,6 +674,7 @@ def _contrastive_generate(
 
         kwargs = {}
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+        self._update_default_model_kwargs(model_kwargs)
         output_generate = model.generate(
             input_ids,
             do_sample=False,
@@ -695,6 +702,7 @@ def _contrastive_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
             stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
             output_contrastive = model.contrastive_search(
                 input_ids,
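Every hunk in this file makes the same one-line change: after the test builds model_kwargs, it calls self._update_default_model_kwargs(model_kwargs) so that shared default keyword arguments are merged in before generate, sample, beam_search, beam_sample, group_beam_search, constrained beam search, or contrastive_search runs. The helper's body is not part of this diff; the sketch below is only an illustration of that pattern, and the injected keys are made-up placeholders, not taken from this commit.

class GenerationTesterMixin:
    # Hypothetical defaults; the real helper may add entirely different keys.
    _default_generate_kwargs = {"lazy_mode": True, "hpu_graphs": False}

    def _update_default_model_kwargs(self, model_kwargs: dict) -> None:
        # Merge the defaults without overriding anything the test set explicitly.
        for key, value in self._default_generate_kwargs.items():
            model_kwargs.setdefault(key, value)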
tests/transformers/tests/test_modeling_common.py (1 addition, 0 deletions)
@@ -1713,6 +1713,7 @@ def test_model_weights_reload_no_missing_tied_weights(self):
                 " `persistent=False`",
             )
 
+    @mark.skip("skip - test is slow")
     def test_model_outputs_equivalence(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
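The only change to the second file is a pytest skip marker on test_model_outputs_equivalence: the test is still collected, but pytest reports it as skipped with the given reason instead of running it. A minimal, self-contained example of the same mechanism (the test name and body below are made up for illustration):

from pytest import mark

@mark.skip("skip - test is slow")
def test_example_outputs_equivalence():
    # Never executed; pytest reports this test as skipped with the reason above.
    assert False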
