Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring: turn the attribute _return_attention_scores into an argument #20803

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions keras/src/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def __init__(
self.seed = seed

self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
self._return_attention_scores = False

# Check for flash attention constraints
if self._flash_attention and self._dropout > 0.0:
Expand Down Expand Up @@ -419,6 +418,7 @@ def _compute_attention(
value,
attention_mask=None,
training=None,
return_attention_scores=False,
):
"""Applies Dot-product attention with query, key, value tensors.

Expand All @@ -442,7 +442,7 @@ def _compute_attention(
attention_scores: Multi-headed attention weights.
"""
# Check for flash attention constraints
if self._flash_attention and self._return_attention_scores:
if self._flash_attention and return_attention_scores:
raise ValueError(
"Returning attention scores is not supported when flash "
"attention is enabled. Please disable flash attention to access"
Expand All @@ -452,7 +452,7 @@ def _compute_attention(
# Determine whether to use dot-product attention
use_dot_product_attention = not (
self._dropout > 0.0
or self._return_attention_scores
or return_attention_scores
or (len(query.shape) != 4)
)

Expand Down Expand Up @@ -525,7 +525,6 @@ def call(
training=None,
use_causal_mask=False,
):
self._return_attention_scores = return_attention_scores
if key is None:
key = value

Expand Down Expand Up @@ -562,6 +561,7 @@ def call(
value,
attention_mask,
training,
return_attention_scores,
)
attention_output = self._output_dense(attention_output)

Expand Down
Loading