
Commit

fix: fixed initializers
AshishKumar4 committed Aug 10, 2024
1 parent ae44f38 commit 9445d03
Showing 3 changed files with 63 additions and 56 deletions.
93 changes: 50 additions & 43 deletions evaluate.ipynb

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions flaxdiff/models/attention.py
@@ -22,7 +22,7 @@ class EfficientAttention(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
- kernel_init: Callable = lambda : kernel_init(1.0)
+ kernel_init: Callable = kernel_init(1.0)
force_fp32_for_softmax: bool = True

def setup(self):
Expand All @@ -33,15 +33,15 @@ def setup(self):
self.heads * self.dim_head,
precision=self.precision,
use_bias=self.use_bias,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
dtype=self.dtype
)
self.query = dense(name="to_q")
self.key = dense(name="to_k")
self.value = dense(name="to_v")

self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
- kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
+ kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
# self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)

def _reshape_tensor_to_head_dim(self, tensor):
@@ -114,7 +114,7 @@ class NormalAttention(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
- kernel_init: Callable = lambda : kernel_init(1.0)
+ kernel_init: Callable = kernel_init(1.0)
force_fp32_for_softmax: bool = True

def setup(self):
Expand All @@ -125,7 +125,7 @@ def setup(self):
axis=-1,
precision=self.precision,
use_bias=self.use_bias,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
dtype=self.dtype
)
self.query = dense(name="to_q")
Expand All @@ -139,7 +139,7 @@ def setup(self):
use_bias=self.use_bias,
dtype=self.dtype,
name="to_out_0",
- kernel_init=self.kernel_init()
+ kernel_init=self.kernel_init
# kernel_init=jax.nn.initializers.xavier_uniform()
)

@@ -235,7 +235,7 @@ class BasicTransformerBlock(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
- kernel_init: Callable = lambda : kernel_init(1.0)
+ kernel_init: Callable = kernel_init(1.0)
use_flash_attention:bool = False
use_cross_only:bool = False
only_pure_attention:bool = False
@@ -302,7 +302,7 @@ class TransformerBlock(nn.Module):
use_self_and_cross:bool = True
only_pure_attention:bool = False
force_fp32_for_softmax: bool = True
- kernel_init: Callable = lambda : kernel_init(1.0)
+ kernel_init: Callable = kernel_init(1.0)

@nn.compact
def __call__(self, x, context=None):
Expand All @@ -313,12 +313,12 @@ def __call__(self, x, context=None):
if self.use_linear_attention:
projected_x = nn.Dense(features=inner_dim,
use_bias=False, precision=self.precision,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
dtype=self.dtype, name=f'project_in')(normed_x)
else:
projected_x = nn.Conv(
features=inner_dim, kernel_size=(1, 1),
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
precision=self.precision, name=f'project_in_conv',
)(normed_x)
@@ -347,12 +347,12 @@ def __call__(self, x, context=None):
if self.use_linear_attention:
projected_x = nn.Dense(features=C, precision=self.precision,
dtype=self.dtype, use_bias=False,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
name=f'project_out')(projected_x)
else:
projected_x = nn.Conv(
features=C, kernel_size=(1, 1),
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
precision=self.precision, name=f'project_out_conv',
)(projected_x)
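The attention.py changes above all apply one pattern: the kernel_init field previously defaulted to a zero-argument lambda that every call site had to unwrap with self.kernel_init(); it now defaults to the initializer itself and is passed through unchanged as self.kernel_init. Below is a minimal sketch of a module written against the new convention (not flaxdiff's actual code; scaled_init is a hypothetical stand-in for its kernel_init(scale) helper):

from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn


def scaled_init(scale: float = 1.0) -> Callable:
    # Hypothetical stand-in for flaxdiff's kernel_init(scale) factory.
    return nn.initializers.variance_scaling(scale, "fan_avg", "uniform")


class TinyProjection(nn.Module):
    features: int = 64
    # New convention: the field default is the initializer itself,
    # not a lambda that has to be unwrapped with self.kernel_init().
    kernel_init: Callable = scaled_init(1.0)

    @nn.compact
    def __call__(self, x):
        # Passed straight through; no trailing () as in the old code.
        return nn.Dense(self.features, use_bias=False,
                        kernel_init=self.kernel_init, name="to_out_0")(x)


x = jnp.ones((2, 16, 64))
params = TinyProjection().init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(jnp.shape, params))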
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
- version='0.1.11',
+ version='0.1.12',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
