
fix CLIP text numerics
sachinprasadhs committed Jul 23, 2024
1 parent bf7dfcb commit a452c60
Showing 1 changed file with 13 additions and 4 deletions.
keras_cv/src/models/stable_diffusion_v3/clip_utils.py
@@ -151,6 +151,7 @@ def __init__(
         hidden_dim,
         num_heads,
         intermediate_size,
+        intermediate_activation="quick_gelu",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -162,13 +163,13 @@ def __init__(
             self.num_heads,
             name="multi_head_attention",
         )
+        self.intermediate_activation = intermediate_activation
         self.layer_norm_1 = keras.layers.LayerNormalization(
             epsilon=1e-5, name="layer_norm_1"
         )
         self.mlp_dense_1 = keras.layers.Dense(
-            self.hidden_dim * 4,
+            self.intermediate_size,
             name="c_fc",
-            activation=quick_gelu,
         )
         self.mlp_dense_2 = keras.layers.Dense(
             self.hidden_dim,
@@ -177,6 +178,10 @@ def __init__(
         self.layer_norm_2 = keras.layers.LayerNormalization(
             epsilon=1e-5, name="layer_norm_2"
         )
+        if self.intermediate_activation == "quick_gelu":
+            self.activation = quick_gelu
+        else:
+            self.activation = keras.layers.Activation(self.intermediate_activation, name="activation")
 
     def compute_attention(
         self, x, causal_attention_mask=None, attention_mask=None
@@ -205,7 +210,7 @@ def build(self, input_shape):
         self.attn.build(None)
         self.layer_norm_1.build([None, None, self.hidden_dim])
         self.mlp_dense_1.build([None, None, self.hidden_dim])
-        self.mlp_dense_2.build([None, None, self.hidden_dim * 4])
+        self.mlp_dense_2.build([None, None, self.intermediate_size])
         self.layer_norm_2.build([None, None, self.hidden_dim])
         self.built = True
 
@@ -220,6 +225,7 @@ def call(self, x, causal_attention_mask=None, attention_mask=None):
         x = x + residual
         residual = x
         x = self.mlp_dense_1(self.layer_norm_2(residual))
+        x = self.activation(x)
         x = self.mlp_dense_2(x)
         x = residual + x
         return x
@@ -234,25 +240,28 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "num_heads": self.num_heads,
                 "intermediate_size": self.intermediate_size,
+                "intermediate_activation": self.intermediate_activation,
             }
         )
         return config


 class CLIPEncoder(keras.layers.Layer):
     def __init__(
-        self, width, num_layers, num_heads, intermediate_size, **kwargs
+        self, width, num_layers, num_heads, intermediate_size, intermediate_activation, **kwargs
     ):
         super().__init__(**kwargs)
         self.width = width
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.intermediate_size = intermediate_size
+        self.intermediate_activation = intermediate_activation
         self.resblocks = [
             CLIPLayer(
                 self.width,
                 self.num_heads,
                 self.intermediate_size,
+                self.intermediate_activation,
             )
             for _ in range(self.num_layers)
         ]
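What the fix is doing, in short: OpenAI's CLIP text encoder uses QuickGELU, a sigmoid-based approximation of GELU, in its MLP blocks, and the MLP width is its own hyperparameter (intermediate_size) rather than a hard-coded hidden_dim * 4. Swapping one activation variant for the other shifts MLP outputs by up to roughly 0.02 per unit, enough to throw off numerics under pretrained weights. A minimal sketch of the gap, assuming quick_gelu in clip_utils.py follows the standard x * sigmoid(1.702 * x) definition (the helper names below are illustrative, not from the file):

import math
import numpy as np

def quick_gelu(x):
    # CLIP-style sigmoid approximation: x * sigmoid(1.702 * x).
    return x / (1.0 + np.exp(-1.702 * x))

def exact_gelu(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return np.array([v * 0.5 * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x])

x = np.linspace(-4.0, 4.0, 801)
gap = np.max(np.abs(quick_gelu(x) - exact_gelu(x)))
print(f"max |quick_gelu - gelu| on [-4, 4]: {gap:.4f}")  # roughly 0.02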

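And how the new argument threads through the stack, as a hypothetical usage sketch (the import path mirrors the file path above, and the sizes are illustrative rather than the SD3 text-encoder configuration):

# Hypothetical import; mirrors the file path in this commit.
from keras_cv.src.models.stable_diffusion_v3.clip_utils import CLIPEncoder

# Illustrative sizes only. "quick_gelu" dispatches to the custom
# quick_gelu function; any other identifier is resolved through
# keras.layers.Activation (e.g. "gelu", "relu").
encoder = CLIPEncoder(
    width=768,
    num_layers=12,
    num_heads=12,
    intermediate_size=3072,
    intermediate_activation="quick_gelu",
)

Note that after this change only CLIPLayer defaults intermediate_activation to "quick_gelu"; CLIPEncoder requires the argument explicitly, so callers must state which variant they want.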