feat: some generalizations
AshishKumar4 committed Aug 10, 2024
1 parent fe30466 commit ae44f38
Showing 3 changed files with 33 additions and 17 deletions.
25 changes: 17 additions & 8 deletions flaxdiff/models/attention.py
@@ -23,6 +23,7 @@ class EfficientAttention(nn.Module):
     precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -114,6 +115,7 @@ class NormalAttention(nn.Module):
     precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -157,7 +159,7 @@ def __call__(self, x, context=None):
 
         hidden_states = nn.dot_product_attention(
             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
             deterministic=True
         )
         proj = self.proj_attn(hidden_states)
@@ -237,6 +239,7 @@ class BasicTransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         if self.use_flash_attention:
@@ -252,7 +255,8 @@ def setup(self):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
             query_dim=self.query_dim,
@@ -262,7 +266,8 @@ def setup(self):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
         self.ff = FlaxFeedForward(dim=self.query_dim)
@@ -296,6 +301,8 @@ class TransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -306,12 +313,12 @@ def __call__(self, x, context=None):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=kernel_init(1.0),
+                                   kernel_init=self.kernel_init(),
                                    dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(),
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -331,19 +338,21 @@ def __call__(self, x, context=None):
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
             use_cross_only=(not self.use_self_and_cross),
-            only_pure_attention=self.only_pure_attention
+            only_pure_attention=self.only_pure_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-                                       kernel_init=kernel_init(1.0),
+                                       kernel_init=self.kernel_init(),
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(),
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
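In attention.py, the previously hard-coded fp32 softmax in NormalAttention becomes a force_fp32_for_softmax field (default True) that TransformerBlock and BasicTransformerBlock now pass down, and TransformerBlock additionally takes its kernel_init as a field. The sketch below only illustrates what the flag controls at the flax.linen level, which these modules call internally; the shapes and dtypes are arbitrary assumptions, not taken from the repository.

```python
# Illustrative sketch (not repository code): the new field maps onto
# flax.linen.dot_product_attention's force_fp32_for_softmax argument,
# which NormalAttention calls internally. Shapes/dtypes are assumptions.
import jax
import jax.numpy as jnp
import flax.linen as nn

B, L, H, D = 2, 64, 4, 32  # batch, tokens, heads, head dim (arbitrary)
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, L, H, D), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, L, H, D), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, L, H, D), dtype=jnp.bfloat16)

# Previous hard-coded behaviour: softmax promoted to float32 for stability.
out_stable = nn.dot_product_attention(
    q, k, v, dtype=jnp.bfloat16, deterministic=True,
    force_fp32_for_softmax=True)

# Now selectable: keep the softmax in the compute dtype (faster, less stable).
out_fast = nn.dot_product_attention(
    q, k, v, dtype=jnp.bfloat16, deterministic=True,
    force_fp32_for_softmax=False)
```

Since the module defaults stay at True, existing configurations keep the old numerics unless the flag is explicitly turned off.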
23 changes: 15 additions & 8 deletions flaxdiff/models/simple_unet.py
@@ -20,6 +20,7 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
+    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
     def setup(self):
         if self.norm_groups > 0:
@@ -49,7 +50,7 @@ def __call__(self, x, temb, textcontext):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -64,7 +65,7 @@ def __call__(self, x, temb, textcontext):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
@@ -81,6 +82,8 @@ def __call__(self, x, temb, textcontext):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
@@ -103,7 +106,7 @@ def __call__(self, x, temb, textcontext):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -121,12 +124,14 @@ def __call__(self, x, temb, textcontext):
                     use_self_and_cross=False,
                     precision=middle_attention.get("precision", self.precision),
                     only_pure_attention=middle_attention.get("only_pure_attention", True),
+                    force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    kernel_init=self.kernel_init(1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -148,7 +153,7 @@ def __call__(self, x, temb, textcontext):
                         up_conv_type,# if j == 0 else "separable",
                         name=f"up_{i}_residual_{j}",
                         features=dim_out,
-                        kernel_init=kernel_init(1.0),
+                        kernel_init=self.kernel_init(1.0),
                         kernel_size=kernel_size,
                         strides=(1, 1),
                         activation=self.activation,
@@ -165,6 +170,8 @@ def __call__(self, x, temb, textcontext):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -183,7 +190,7 @@ def __call__(self, x, temb, textcontext):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -194,7 +201,7 @@ def __call__(self, x, temb, textcontext):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -213,7 +220,7 @@ def __call__(self, x, temb, textcontext):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=kernel_init(0.0),
+            kernel_init=self.kernel_init(0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
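In simple_unet.py, the per-level attention config dicts gain a force_fp32_for_softmax key (read with a default of False), and the kernel initializer becomes a Unet field defaulting to partial(kernel_init, dtype=jnp.float32). A rough construction sketch follows; feature_depths, kernel_init, and the dict keys appear in this diff, while the attention_configs field name and the import path of kernel_init are assumptions for illustration.

```python
# Hedged sketch: supplying the new knobs when constructing the Unet.
# kernel_init, feature_depths and the dict keys appear in this diff;
# the attention_configs field name and the kernel_init import path are
# assumptions for illustration.
from functools import partial

import jax.numpy as jnp
from flaxdiff.models.common import kernel_init  # assumed import path
from flaxdiff.models.simple_unet import Unet

unet = Unet(
    feature_depths=(64, 128, 256),
    attention_configs=(                        # assumed field name
        {"force_fp32_for_softmax": False},     # new key; .get(...) falls back to False
        {"force_fp32_for_softmax": False},
        {"only_pure_attention": True, "force_fp32_for_softmax": True},
    ),
    dtype=jnp.bfloat16,
    # New field: a scale-taking initializer factory; the default shown in the
    # diff keeps parameter initialisation in float32.
    kernel_init=partial(kernel_init, dtype=jnp.float32),
)
```

Note that the UNet reads attention_config.get("force_fp32_for_softmax", False), so its attention layers default to the lower-precision softmax unless the key is set, the opposite of the True default on the attention modules themselves.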
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.10',
+    version='0.1.11',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',