support F.scaled_dot_product_attention without xformers
lawrence-cj committed Nov 24, 2024
1 parent 6265a23 commit cc5991b
Showing 2 changed files with 20 additions and 10 deletions.
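In outline, the change gates xformers usage behind an availability flag and falls back to PyTorch's built-in F.scaled_dot_product_attention when xformers is not installed. As a rough illustration of such a guard (the importlib-based check below is an assumption about how diffusion.utils.import_utils might implement is_xformers_available, not the repository's actual code):

import importlib.util

# Hypothetical availability check: xformers is optional, so only set the flag
# (and import the ops module) when the package can actually be found.
_xformers_available = False
if importlib.util.find_spec("xformers") is not None:
    import xformers.ops  # noqa: F401
    _xformers_available = True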
9 changes: 5 additions & 4 deletions diffusion/model/nets/sana_blocks.py
@@ -71,13 +71,14 @@ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, qk_norm=Fal
     def forward(self, x, cond, mask=None):
         # query: img tokens; key/value: condition; mask: if padding tokens
         B, N, C = x.shape
+        first_dim = 1 if _xformers_available else B

         q = self.q_linear(x)
-        kv = self.kv_linear(cond).view(1, -1, 2, C)
+        kv = self.kv_linear(cond).view(first_dim, -1, 2, C)
         k, v = kv.unbind(2)
-        q = self.q_norm(q).view(1, -1, self.num_heads, self.head_dim)
-        k = self.k_norm(k).view(1, -1, self.num_heads, self.head_dim)
-        v = v.view(1, -1, self.num_heads, self.head_dim)
+        q = self.q_norm(q).view(first_dim, -1, self.num_heads, self.head_dim)
+        k = self.k_norm(k).view(first_dim, -1, self.num_heads, self.head_dim)
+        v = v.view(first_dim, -1, self.num_heads, self.head_dim)

         if _xformers_available:
             attn_bias = None
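For context, here is a minimal sketch (not part of this commit) of how the non-xformers branch of such a cross-attention forward can call F.scaled_dot_product_attention once first_dim == B keeps q, k, and v batched per sample; the helper name, the attn_mask/dropout arguments, and the reshapes are assumptions, not the repository's actual code.

import torch
import torch.nn.functional as F

# Hypothetical fallback: with first_dim == B the tensors stay (B, N, num_heads, head_dim),
# so PyTorch's built-in SDPA can stand in for xformers.ops.memory_efficient_attention.
def sdpa_cross_attention(q, k, v, attn_mask=None, dropout_p=0.0):
    # (B, N, H, D) -> (B, H, N, D), the layout F.scaled_dot_product_attention expects
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
    B, H, N, D = out.shape
    # back to (B, N, H * D) for the output projection
    return out.transpose(1, 2).reshape(B, N, H * D)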
21 changes: 15 additions & 6 deletions diffusion/model/nets/sana_multi_scale.py
@@ -33,13 +33,17 @@
     t2i_modulate,
 )
 from diffusion.model.utils import auto_grad_checkpoint
-from diffusion.utils.import_utils import is_triton_module_available
+from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available

 _triton_modules_available = False
 if is_triton_module_available():
     from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
     _triton_modules_available = True

+_xformers_available = False
+if is_xformers_available():
+    import xformers.ops
+    _xformers_available = True

 class SanaMSBlock(nn.Module):
     """
@@ -301,14 +305,19 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
         y = self.attention_y_norm(y)

         if mask is not None:
-            if mask.shape[0] != y.shape[0]:
-                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.repeat(y.shape[0] // mask.shape[0], 1) if mask.shape[0] != y.shape[0] else mask
             mask = mask.squeeze(1).squeeze(1)
-            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
-            y_lens = mask.sum(dim=1).tolist()
-        else:
+            if _xformers_available:
+                y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+                y_lens = mask.sum(dim=1).tolist()
+            else:
+                y_lens = mask
+        elif _xformers_available:
             y_lens = [y.shape[2]] * y.shape[0]
             y = y.squeeze(1).view(1, -1, x.shape[-1])
+        else:
+            raise ValueError(f"{attn_type} type is not available due to _xformers_available={_xformers_available}.")

         for block in self.blocks:
             x = auto_grad_checkpoint(
                 block, x, y, t0, y_lens, (self.h, self.w), **kwargs
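When xformers is unavailable, the branch above leaves y_lens holding the raw padding mask rather than a per-sample token-length list, so a downstream SDPA call needs that mask converted into the boolean layout F.scaled_dot_product_attention accepts. A minimal sketch of such a conversion (the helper name and the 1 = keep / 0 = pad convention are assumptions):

import torch

# Hypothetical helper: (B, L) padding mask -> (B, 1, 1, L) boolean attention mask,
# broadcastable over heads and query positions; True marks key positions that may be attended.
def padding_mask_to_attn_mask(mask: torch.Tensor) -> torch.Tensor:
    return (mask != 0)[:, None, None, :]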
