feat: parameterized a few attention options in ViT
Ashish Kumar Singh authored and AshishKumar4 committed Aug 27, 2024
1 parent ae85f41 commit 00ca69b
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions flaxdiff/models/simple_vit.py
@@ -58,6 +58,9 @@ class UViT(nn.Module):
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
     use_projection: bool = False
+    use_flash_attention: bool = False
+    use_self_and_cross: bool = False
+    force_fp32_for_softmax: bool = True
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
@@ -102,15 +105,15 @@ def __call__(self, x, temb, textcontext=None):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)

         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True,
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
                              kernel_init=self.kernel_init())(x)

@@ -121,7 +124,7 @@ def __call__(self, x, temb, textcontext=None):
                                  dtype=self.dtype, precision=self.precision)(skip)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(skip)

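For context, a minimal sketch of how the newly parameterized attention options might be passed when constructing the model. Only use_flash_attention, use_self_and_cross, force_fp32_for_softmax, dtype, and precision are taken from this diff; the concrete values chosen below and everything else are illustrative assumptions.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import UViT

# Hypothetical configuration; the three flags below were hard-coded before this commit.
model = UViT(
    use_flash_attention=False,        # new field, default False
    use_self_and_cross=True,          # new field, default False (the middle block previously used True)
    force_fp32_for_softmax=True,      # new field, default True
    dtype=jnp.bfloat16,               # illustrative override of the class default
    precision=jax.lax.Precision.HIGH,
)

With this change the three flags are forwarded unchanged to every TransformerBlock in the down, middle, and up paths, replacing values that were previously hard-coded at each call site.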
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.23',
+    version='0.1.24',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
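The setup.py change bumps the release to 0.1.24, so the new parameters become available from that version onward. A quick way to confirm which release is installed (standard-library only; no flaxdiff-specific API is assumed):

from importlib.metadata import version

# Prints the installed flaxdiff release, e.g. "0.1.24" once this commit is published.
print(version("flaxdiff"))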
