pre-commit
Signed-off-by: lawrence-cj <[email protected]>
lawrence-cj committed Nov 24, 2024
1 parent cc5991b commit 90d3727
Showing 9 changed files with 62 additions and 49 deletions.
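
Most of this commit is mechanical pre-commit formatting: the import lists are re-sorted and the long raise ValueError(...) messages are wrapped onto multiple lines. The one recurring functional change is how the optional Triton kernels are imported: the per-file try/except ImportError blocks are dropped, and each net module instead consults the shared probe in diffusion.utils.import_utils. A condensed before/after sketch of that pattern, assuming the repository is on the import path (the "before" half is shown as comments):

# Before (removed in this commit): each module guarded the import itself and
# only warned when `triton` was missing.
#
#     try:
#         from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
#     except ImportError:
#         import warnings
#         warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
#
# After: a single probe decides whether to import at all, and a module-level
# flag records the outcome so __init__ can raise a clear error later on.
from diffusion.utils.import_utils import is_triton_module_available

_triton_modules_available = False
if is_triton_module_available():
    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

    _triton_modules_available = True
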
20 changes: 10 additions & 10 deletions diffusion/model/nets/__init__.py
@@ -1,43 +1,43 @@
from .sana import (
Sana,
SanaBlock,
get_2d_sincos_pos_embed,
get_2d_sincos_pos_embed_from_grid,
Sana,
SanaBlock,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
get_2d_sincos_pos_embed_from_grid,
)
from .sana_multi_scale import (
SanaMSBlock,
SanaMS,
SanaMS_600M_P1_D28,
SanaMS,
SanaMS_600M_P1_D28,
SanaMS_600M_P2_D28,
SanaMS_600M_P4_D28,
SanaMS_1600M_P1_D20,
SanaMS_1600M_P2_D20,
SanaMSBlock,
)
from .sana_multi_scale_adaln import (
SanaMSAdaLNBlock,
SanaMSAdaLN,
SanaMSAdaLN_600M_P1_D28,
SanaMSAdaLN_600M_P2_D28,
SanaMSAdaLN_600M_P4_D28,
SanaMSAdaLN_1600M_P1_D20,
SanaMSAdaLN_1600M_P2_D20,
SanaMSAdaLNBlock,
)
from .sana_U_shape import (
SanaUBlock,
SanaU,
SanaU_600M_P1_D28,
SanaU_600M_P2_D28,
SanaU_600M_P4_D28,
SanaU_1600M_P1_D20,
SanaU_1600M_P2_D20,
SanaUBlock,
)
from .sana_U_shape_multi_scale import (
SanaUMSBlock,
SanaUMS,
SanaUMS_600M_P1_D28,
SanaUMS_600M_P2_D28,
SanaUMS_600M_P4_D28,
SanaUMS_1600M_P1_D20,
SanaUMS_1600M_P2_D20,
SanaUMSBlock,
)
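
The reshuffling above appears to follow isort's default ordering of names inside a from-import: with order_by_type enabled, CamelCase classes come before lowercase functions, comparisons are case-insensitive, and numeric chunks sort naturally (so the _600M_ variants land before _1600M_). A quick way to see the effect, assuming a recent isort is installed; the input below is just an unsorted arrangement of the same names, and the repository's own isort profile may wrap the result differently:

import isort

unsorted = (
    "from .sana import (\n"
    "    get_2d_sincos_pos_embed,\n"
    "    get_2d_sincos_pos_embed_from_grid,\n"
    "    get_1d_sincos_pos_embed_from_grid,\n"
    "    SanaBlock,\n"
    "    Sana,\n"
    ")\n"
)
# Classes first, then the get_* helpers, in the same order as the new __init__.py.
print(isort.code(unsorted))
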
11 changes: 8 additions & 3 deletions diffusion/model/nets/sana.py
@@ -43,7 +43,8 @@

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU

_triton_modules_available = True


@@ -84,7 +85,9 @@ def __init__(
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
# TODO: Here the num_heads set to 36 for tmp used
self_num_heads = hidden_size // linear_head_dim
@@ -131,7 +134,9 @@ def __init__(
)
elif ffn_type == "triton_mbconvpreglu":
if not _triton_modules_available:
raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
self.mlp = TritonMBConvPreGLU(
in_dim=hidden_size,
out_dim=hidden_size,
16 changes: 7 additions & 9 deletions diffusion/model/nets/sana_U_shape.py
@@ -23,11 +23,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
try:
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
except ImportError:
import warnings
warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -41,13 +36,14 @@
t2i_modulate,
)
from diffusion.model.norms import RMSNorm
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.utils.logger import get_root_logger
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available
from diffusion.utils.logger import get_root_logger

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


@@ -88,7 +84,9 @@ def __init__(
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
# TODO: Here the num_heads set to 36 for tmp used
self_num_heads = hidden_size // 32
11 changes: 7 additions & 4 deletions diffusion/model/nets/sana_U_shape_multi_scale.py
@@ -29,18 +29,19 @@
LiteLA,
MultiHeadCrossAttention,
PatchEmbedMS,
SizeEmbedder,
T2IFinalLayer,
t2i_modulate,
)
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


class SanaUMSBlock(nn.Module):
"""
A SanaU block with global shared adaptive layer norm (adaLN-single) conditioning and U-shaped model.
@@ -79,7 +80,9 @@ def __init__(
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // 32
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
4 changes: 2 additions & 2 deletions diffusion/model/nets/sana_blocks.py
@@ -22,19 +22,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusion.utils.import_utils import is_xformers_available
from einops import rearrange
from timm.models.vision_transformer import Attention as Attention_
from timm.models.vision_transformer import Mlp
from transformers import AutoModelForCausalLM

from diffusion.model.norms import RMSNorm
from diffusion.model.utils import get_same_padding, to_2tuple

from diffusion.utils.import_utils import is_xformers_available

_xformers_available = False
if is_xformers_available():
import xformers.ops

_xformers_available = True


15 changes: 10 additions & 5 deletions diffusion/model/nets/sana_multi_scale.py
@@ -37,14 +37,15 @@

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU

_triton_modules_available = True

_xformers_available = False
if is_xformers_available():
import xformers.ops
_xformers_available = True


class SanaMSBlock(nn.Module):
"""
A Sana block with global shared adaptive layer norm zero (adaLN-Zero) conditioning.
@@ -84,7 +85,9 @@ def __init__(
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // linear_head_dim
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -120,7 +123,9 @@ def __init__(
)
elif ffn_type == "triton_mbconvpreglu":
if not _triton_modules_available:
raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
self.mlp = TritonMBConvPreGLU(
in_dim=hidden_size,
out_dim=hidden_size,
@@ -316,7 +321,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
else:
raise ValueError(f"{attn_type} type is not available due to _xformers_available={_xformers_available}.")
raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.")

for block in self.blocks:
x = auto_grad_checkpoint(
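
Besides the formatting, the last hunk of sana_multi_scale.py above also repairs the error path in forward: the old message interpolated attn_type, which is an __init__ argument and not in scope inside forward, so reaching that branch would most likely fail with a NameError while building the f-string instead of raising the intended ValueError. A tiny standalone reproduction of that failure mode (the class and names below are illustrative, not the repository's):

class Net:
    def __init__(self, attn_type: str = "flash"):
        self.attn_type = attn_type  # the bare name `attn_type` exists only in __init__

    def forward(self, xformers_available: bool = False):
        if not xformers_available:
            # Old behavior: `attn_type` is undefined in this scope, so evaluating
            # the f-string raises NameError before ValueError is ever reached.
            raise ValueError(f"{attn_type} type is not available ...")  # noqa: F821


try:
    Net().forward()
except NameError as err:
    print("the error path itself was broken:", err)

The commit drops the out-of-scope name and reports the generic "Attention type is not available due to _xformers_available=..." message instead.
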
14 changes: 6 additions & 8 deletions diffusion/model/nets/sana_multi_scale_adaln.py
@@ -21,11 +21,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
try:
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
except ImportError:
import warnings
warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -38,12 +33,13 @@
T2IFinalLayer,
modulate,
)
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


Expand Down Expand Up @@ -84,7 +80,9 @@ def __init__(
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.")
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // 32
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
13 changes: 7 additions & 6 deletions diffusion/model/nets/sana_others.py
@@ -20,13 +20,14 @@
from timm.models.layers import DropPath

from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
try:
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
except ImportError:
import warnings
warnings.warn("TritonLiteMLA with `triton` is not available on your platform.")
from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate
from diffusion.utils.import_utils import is_triton_module_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


class SanaMSPABlock(nn.Module):
7 changes: 5 additions & 2 deletions diffusion/utils/import_utils.py
@@ -1,9 +1,10 @@
import importlib.util
import importlib_metadata
from packaging import version
import logging
import warnings

import importlib_metadata
from packaging import version

logger = logging.getLogger(__name__)

_xformers_available = importlib.util.find_spec("xformers") is not None
@@ -28,8 +29,10 @@
_triton_modules_available = False
warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.")


def is_xformers_available():
return _xformers_available


def is_triton_module_available():
return _triton_modules_available
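
With these helpers in place, a quick sanity check of what a given machine reports looks like this (an illustrative snippet, assuming the repository is importable; it is not part of the commit):

from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available

# Both flags are computed once at import time, so these calls are cheap and
# safe to use wherever a Triton- or xformers-only code path is selected.
print("xformers available:", is_xformers_available())
print("triton modules available:", is_triton_module_available())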
