From d427f1fabc1766db3534436ae3b06c4ffa4effc3 Mon Sep 17 00:00:00 2001
From: Yeonsil Yoon
Date: Thu, 22 Aug 2024 09:39:42 -0700
Subject: [PATCH] Revert mark_step in mixtral model from PR #1260 (#1273)

---
 .../transformers/models/mixtral/modeling_mixtral.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
index fc414e6d76..43dfc7e48a 100644
--- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
+++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@@ -471,7 +471,6 @@ def forward(
         reuse_cache: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -481,10 +480,7 @@ def forward(
         - add new args reuse_cache
         - add new args flash_attention_recompute
         - add new args cache_idx
-        - add new args lazy_mode
         """
-        if lazy_mode:
-            htcore.mark_step()
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -504,16 +500,12 @@ def forward(
             cache_idx=cache_idx,
         )
         hidden_states = residual + hidden_states
-        if lazy_mode:
-            htcore.mark_step()
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
         hidden_states = residual + hidden_states
-        if lazy_mode:
-            htcore.mark_step()
 
         outputs = (hidden_states,)
 
@@ -554,7 +546,6 @@ def forward(
         reuse_cache: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         """
         Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069
@@ -684,7 +675,6 @@ def forward(
                     reuse_cache=reuse_cache,
                     flash_attention_recompute=flash_attention_recompute,
                     cache_idx=cache_idx,
-                    lazy_mode=lazy_mode,
                 )
 
         hidden_states = layer_outputs[0]
@@ -759,7 +749,6 @@ def forward(
         reuse_cache: Optional[bool] = None,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -788,7 +777,6 @@ def forward(
             reuse_cache=reuse_cache,
             flash_attention_recompute=flash_attention_recompute,
             cache_idx=cache_idx,
-            lazy_mode=lazy_mode,
         )
 
         hidden_states = outputs[0]
@@ -893,7 +881,6 @@ def prepare_inputs_for_generation(
                 "reuse_cache": reuse_cache,
                 "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                 "cache_idx": kwargs.get("cache_idx"),
-                "lazy_mode": kwargs.get("lazy_mode"),
             }
         )
         return model_inputs