Support CPU inference - use PyTorch implementation when xformers isn't installed #48

Open · wants to merge 1 commit into main
24 changes: 23 additions & 1 deletion deepseek_vl2/models/siglip_vit.py
@@ -13,8 +13,30 @@
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from transformers.modeling_utils import is_flash_attn_2_available
-from xformers.ops import memory_efficient_attention
 from functools import partial
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError:
+    warnings.warn(
+        "xformers not installed, using slow PyTorch implementation of memory_efficient_attention",
+        stacklevel=2,
+    )
+
+    def memory_efficient_attention(query, key, value, p):
+        """ This code is taken from https://facebookresearch.github.io/xformers/components/ops.html """

Review comment on the fallback's docstring line:

"""
A simplified PyTorch implementation of memory-efficient attention.
Assumes inputs are already reshaped to 4D [batch, heads, seq_len, head_dim].
This implementation is SLOWER and LESS MEMORY-EFFICIENT than xformers,
and does NOT SUPPORT attn_bias.

    Args:
        query (Tensor): shape (batch, heads, seq_len, head_dim)
        key (Tensor): shape (batch, heads, seq_len, head_dim)
        value (Tensor): shape (batch, heads, seq_len, head_dim)
        p (float): dropout probability. Default: 0.0 (no dropout)
        attn_bias:  NOT SUPPORTED IN THIS FALLBACK.
        scale (float, optional): scaling factor for the dot product. If None,
            defaults to 1 / sqrt(head_dim).
    Returns:
        Tensor: shape (batch, heads, seq_len, head_dim)
    """
    
It's better to replace the one-line comment with this docstring; it will improve the developer experience.

+        attn_bias = None
+        scale = 1.0 / query.shape[-1] ** 0.5
+        query = query * scale
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        attn = query @ key.transpose(-2, -1)
+        if attn_bias is not None:
+            attn = attn + attn_bias
+        attn = attn.softmax(-1)
+        attn = F.dropout(attn, p)
+        attn = attn @ value
+        return attn.transpose(1, 2).contiguous()


 if is_flash_attn_2_available():
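For reviewers who want to sanity-check the fallback numerically, below is a minimal sketch (not part of this PR) that compares its output against torch.nn.functional.scaled_dot_product_attention, which applies the same 1 / sqrt(head_dim) scaling but expects a (batch, heads, seq_len, head_dim) layout. It assumes an environment without xformers, so that importing memory_efficient_attention from deepseek_vl2.models.siglip_vit resolves to the fallback above; the tensor sizes are arbitrary example values.

# Sanity check for the PyTorch fallback (sketch only; assumes xformers is NOT installed,
# so the import below picks up the fallback defined in this PR).
import torch
import torch.nn.functional as F

from deepseek_vl2.models.siglip_vit import memory_efficient_attention

B, M, H, K = 2, 16, 4, 32  # batch, seq_len, heads, head_dim (arbitrary example sizes)
q = torch.randn(B, M, H, K)
k = torch.randn(B, M, H, K)
v = torch.randn(B, M, H, K)

# The fallback works on the xformers layout (batch, seq_len, heads, head_dim).
out_fallback = memory_efficient_attention(q, k, v, p=0.0)

# scaled_dot_product_attention expects (batch, heads, seq_len, head_dim),
# so transpose in and back for a like-for-like comparison.
out_reference = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

torch.testing.assert_close(out_fallback, out_reference, rtol=1e-5, atol=1e-5)
print("fallback matches F.scaled_dot_product_attention")

An alternative design would be for the fallback to call F.scaled_dot_product_attention directly (it runs on CPU and can dispatch to fused kernels), but this PR keeps the plain reference implementation from the xformers documentation.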