From 4c5557e787d023cbdddeb75896e52a972e780596 Mon Sep 17 00:00:00 2001
From: apehex
Date: Thu, 23 Jan 2025 16:36:46 +0100
Subject: [PATCH] Turn the attribute `_return_attention_scores` into an argument

---
 keras/src/layers/attention/multi_head_attention.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/keras/src/layers/attention/multi_head_attention.py b/keras/src/layers/attention/multi_head_attention.py
index ad4d55d3a14b..8dda823e3cb1 100644
--- a/keras/src/layers/attention/multi_head_attention.py
+++ b/keras/src/layers/attention/multi_head_attention.py
@@ -158,7 +158,6 @@ def __init__(
         self.seed = seed
 
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
-        self._return_attention_scores = False
 
         # Check for flash attention constraints
         if self._flash_attention and self._dropout > 0.0:
@@ -419,6 +418,7 @@ def _compute_attention(
         value,
         attention_mask=None,
         training=None,
+        return_attention_scores=False,
     ):
         """Applies Dot-product attention with query, key, value tensors.
 
@@ -442,7 +442,7 @@ def _compute_attention(
             attention_scores: Multi-headed attention weights.
         """
         # Check for flash attention constraints
-        if self._flash_attention and self._return_attention_scores:
+        if self._flash_attention and return_attention_scores:
             raise ValueError(
                 "Returning attention scores is not supported when flash "
                 "attention is enabled. Please disable flash attention to access"
@@ -452,7 +452,7 @@ def _compute_attention(
         # Determine whether to use dot-product attention
         use_dot_product_attention = not (
             self._dropout > 0.0
-            or self._return_attention_scores
+            or return_attention_scores
             or (len(query.shape) != 4)
         )
 
@@ -525,7 +525,6 @@ def call(
         training=None,
         use_causal_mask=False,
     ):
-        self._return_attention_scores = return_attention_scores
         if key is None:
             key = value
 
@@ -562,6 +561,7 @@ def call(
             value,
             attention_mask,
             training,
+            return_attention_scores,
         )
 
         attention_output = self._output_dense(attention_output)
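
Below is a minimal usage sketch (not part of the patch) that exercises the changed
code path: after this change, `return_attention_scores` is forwarded from `call()`
into `_compute_attention()` as an argument instead of being cached on the layer as
`self._return_attention_scores`. The constructor arguments, input shapes, and
returned score shape are assumptions taken from the public
`keras.layers.MultiHeadAttention` API, not from the diff; requesting scores remains
incompatible with flash attention, per the ValueError kept in `_compute_attention`.

    import numpy as np
    from keras import layers

    # Build a small layer; num_heads / key_dim are arbitrary example values.
    mha = layers.MultiHeadAttention(num_heads=2, key_dim=16)

    query = np.random.normal(size=(1, 8, 16)).astype("float32")
    value = np.random.normal(size=(1, 8, 16)).astype("float32")

    # The flag is now a plain call argument, threaded internally into
    # _compute_attention(); no state is stored on the layer between calls.
    output, scores = mha(query, value, return_attention_scores=True)
    print(output.shape)  # (1, 8, 16)
    print(scores.shape)  # (1, 2, 8, 8) -> (batch, num_heads, query_len, key_len)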