Add clipUtils #2473

Merged, 17 commits, Aug 13, 2024
12 changes: 6 additions & 6 deletions keras_cv/src/layers/vit_layers.py
@@ -15,14 +15,14 @@
import math

import tensorflow as tf
-from keras import layers
-from keras import ops
+from keras_cv.src.backend import keras
+from keras_cv.src.backend import ops

-from keras_cv.api_export import keras_cv_export
+from keras_cv.src.api_export import keras_cv_export


@keras_cv_export("keras_cv.layers.PatchingAndEmbedding")
-class PatchingAndEmbedding(layers.Layer):
+class PatchingAndEmbedding(keras.layers.Layer):
"""
Layer to patchify images, prepend a class token, positionally embed and
create a projection of patches for Vision Transformers
@@ -225,7 +225,7 @@ def get_config(self):


@keras_cv_export("keras_cv.layers.Unpatching")
-class Unpatching(layers.Layer):
+class Unpatching(keras.layers.Layer):
"""
Layer to unpatchify image data.

@@ -284,4 +284,4 @@ def call(self, patches):
        else:
            corrected_patches = patches[:, :required_patches]

-        return ops.split(corrected_patches, patches_per_column, axis=1)
\ No newline at end of file
+        return ops.split(corrected_patches, patches_per_column, axis=1)
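
Aside: in the return statement above, ops.split with an integer second argument cuts axis 1 into that many equal sections, so Unpatching returns one tensor per patch column. A minimal NumPy sketch of the same semantics, with shapes invented for illustration:

import numpy as np

# np.split mirrors ops.split here: an integer second argument divides
# axis 1 into that many equal sections.
patches = np.arange(12).reshape(1, 6, 2)  # (batch, num_patches, patch_dim)
columns = np.split(patches, 3, axis=1)    # three equal groups of two patches
print([c.shape for c in columns])         # [(1, 2, 2), (1, 2, 2), (1, 2, 2)]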
2 changes: 1 addition & 1 deletion keras_cv/src/models/stable_diffusion_v3/MMDit.py
@@ -1,4 +1,4 @@
-from keras_cv.backend import keras
+from keras_cv.src.backend import keras
from keras_cv.layers.vit_layers import PatchingAndEmbedding
from keras_cv.models.stable_diffusion.v3 import embedding
from keras_cv.models.stable_diffusion.v3.MMDiT_block import MMDiTBlock
5 changes: 3 additions & 2 deletions keras_cv/src/models/stable_diffusion_v3/MMDit_block.py
@@ -1,5 +1,6 @@
-import keras
-from keras_cv.backend import ops
+from keras_cv.src.backend import keras
+
+from keras_cv.src.backend import ops


class MMDiTSelfAttention(keras.layers.Layer):
261 changes: 244 additions & 17 deletions keras_cv/src/models/stable_diffusion_v3/clip_utils.py
@@ -17,8 +17,8 @@

import regex as re

-from keras_cv.backend import keras
-from keras_cv.backend import ops
+from keras_cv.src.backend import keras
+from keras_cv.src.backend import ops


def quick_gelu(x):
@@ -151,7 +151,7 @@ def __init__(
hidden_dim,
num_heads,
intermediate_size,
-        intermediate_activation = 'quick_gelu',
+        intermediate_activation="quick_gelu",
**kwargs,
):
super().__init__(**kwargs)
@@ -178,10 +178,12 @@ def __init__(
self.layer_norm_2 = keras.layers.LayerNormalization(
epsilon=1e-5, name="layer_norm_2"
)
-        if self.intermediate_activation == 'quick_gelu':
-            self.activation = quick_gelu
+        if self.intermediate_activation == "quick_gelu":
+            self.activation = quick_gelu
         else:
-            self.activation = keras.layers.Activation(self.intermediate_activation, name="activation")
+            self.activation = keras.layers.Activation(
+                self.intermediate_activation, name="activation"
+            )

def compute_attention(
self, x, causal_attention_mask=None, attention_mask=None
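
Aside: the quick_gelu body itself is collapsed in this diff. CLIP conventionally uses the sigmoid approximation of GELU, so a plausible sketch of the referenced function looks like the following (this is the common convention, not necessarily the PR's exact body):

from keras import ops

def quick_gelu(x):
    # Sigmoid approximation of GELU popularized by the original CLIP:
    # x * sigmoid(1.702 * x)
    return x * ops.sigmoid(1.702 * x)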
@@ -240,15 +242,21 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"num_heads": self.num_heads,
"intermediate_size": self.intermediate_size,
"intermediate_activation":self.intermediate_activation,
"intermediate_activation": self.intermediate_activation,
}
)
return config


class CLIPEncoder(keras.layers.Layer):
def __init__(
-        self, width, num_layers, num_heads, intermediate_size, intermediate_activation, **kwargs
+        self,
+        width,
+        num_layers,
+        num_heads,
+        intermediate_size,
+        intermediate_activation,
+        **kwargs,
):
super().__init__(**kwargs)
self.width = width
@@ -261,7 +269,7 @@ def __init__(
self.width,
self.num_heads,
self.intermediate_size,
-                    self.intermediate_activation
+                    self.intermediate_activation,
)
for _ in range(self.num_layers)
]
@@ -276,16 +284,24 @@ def call(
x,
causal_attention_mask=None,
attention_mask=None,
+        intermediate_output=None,
):
-        for block in self.resblocks:
-            x = block(
-                x,
-                causal_attention_mask=causal_attention_mask,
-                attention_mask=attention_mask,
-            )
-        return x
+        if intermediate_output is not None:
+            if intermediate_output < 0:
+                intermediate_output = self.num_layers + intermediate_output
+        intermediate = None
+        for i, block in enumerate(self.resblocks):
+            x = block(
+                x,
+                causal_attention_mask=causal_attention_mask,
+                attention_mask=attention_mask,
+            )
+            if i == intermediate_output:
+                intermediate = ops.copy(x)
+        return x, intermediate

def compute_output_shape(self, inputs_shape):

return inputs_shape

def get_config(self):
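
Aside: the reworked call above now threads an optional intermediate_output index through the encoder and returns a copy of that block's hidden states alongside the final output; a negative index wraps to num_layers + index. A hedged usage sketch, with CLIP-L-like sizes invented for illustration:

from keras import ops

encoder = CLIPEncoder(
    width=768,
    num_layers=12,
    num_heads=12,
    intermediate_size=3072,
    intermediate_activation="quick_gelu",
)
x = ops.ones((1, 77, 768))  # (batch, sequence, width), made up for the sketch
# intermediate_output=-2 wraps to 12 + (-2) = 10, the penultimate block.
final_states, penultimate_states = encoder(x, intermediate_output=-2)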
@@ -630,6 +646,105 @@ def tokenize_with_weights(self, text: str):
return out


class CLIPTextModel_(keras.Model):
def __init__(
self,
num_layers,
hidden_dim,
num_heads,
intermediate_size,
intermediate_activation,
**kwargs,
):
super().__init__()
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = intermediate_activation
self.embeddings = CLIPEmbeddings(hidden_dim)
self.encoder = CLIPEncoder(
hidden_dim,
num_layers,
num_heads,
intermediate_size,
intermediate_activation,
)
self.final_layer_norm = keras.layers.LayerNormalization(axis=-1)

def build(self, input_shape):
self.embeddings.build(input_shape)
self.encoder.build(input_shape)
self.final_layer_norm.build([None, None, self.hidden_dim])

def call(
self,
input_tokens,
intermediate_output=None,
final_layer_norm_intermediate=True,
):
x = self.embeddings(input_tokens)
# Compute causal mask
causal_mask = ops.ones((ops.shape(x)[1], ops.shape(x)[1]))
causal_mask = ops.triu(causal_mask)
causal_mask = ops.cast(causal_mask, "float32")
x, i = self.encoder(
x,
causal_attention_mask=causal_mask,
intermediate_output=intermediate_output,
)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)

indices = ops.expand_dims(
ops.cast(ops.argmax(input_tokens, axis=-1), "int32"), axis=-1
)
pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1)
pooled_output = ops.squeeze(pooled_output)

return x, i, pooled_output
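
Aside: two details of the call above are easy to miss. ops.triu keeps entries on and above the diagonal, and the argmax pooling works because CLIP's end-of-text id is the largest id in its vocabulary. A NumPy sketch of both (the special-token ids come from the defaults later in this file; the word ids are made up):

import numpy as np

# 1) The causal mask is upper triangular; for sequence length 4:
causal = np.triu(np.ones((4, 4), dtype="float32"))
# [[1. 1. 1. 1.]
#  [0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]]

# 2) argmax over the raw token ids finds the first end token (49407)
#    even though padding reuses the same id, because argmax returns the
#    first occurrence of the maximum; that position's features are pooled.
tokens = np.array([[49406, 320, 1125, 49407, 49407, 49407]])
print(tokens.argmax(axis=-1))  # [3]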


class CLIPTextModel(keras.Model):
def __init__(
self,
num_layers,
hidden_dim,
num_heads,
intermediate_size,
intermediate_activation,
**kwargs,
):
super().__init__()
self.num_layers = num_layers
self.text_model = CLIPTextModel_(
num_layers,
hidden_dim,
num_heads,
intermediate_size,
intermediate_activation,
)
self.text_projection = keras.layers.Dense(
units=hidden_dim, use_bias=False
)

    def build(self, input_shape):
        self.text_model.build(input_shape)
        self.text_projection.build([None, self.text_model.hidden_dim])

def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding.weights[0]

def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding.weights[0].assign(embeddings)

def call(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])


class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
@@ -639,4 +754,116 @@ def encode_token_weights(self, token_weight_pairs):
else:
first_pooled = pooled
output = [out[0:1]]
-        return ops.concatenate(output, axis=-2), first_pooled
\ No newline at end of file
+        return ops.concatenate(output, axis=-2), first_pooled


class SDClipModel(keras.Model, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""

LAYERS = ["last", "pooled", "hidden"]

def __init__(
self,
num_layers,
hidden_dim,
num_heads,
intermediate_size,
intermediate_activation="quick_gelu",
max_length=77,
layer="last",
layer_idx=None,
model_class=CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
layer_norm_hidden_state=True,
return_projected_pooled=True,
**kwargs,
):
super().__init__(**kwargs)
assert layer in self.LAYERS
self.model_class = model_class
self.transformer = model_class(
num_layers,
hidden_dim,
num_heads,
intermediate_size,
intermediate_activation,
)
self.num_layers = num_layers
self.max_length = max_length
self.transformer.build((None, None))
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens
self.logit_scale = keras.Variable(4.6055)
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.set_clip_options({"layer": layer_idx})

def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get(
"projected_pooled", self.return_projected_pooled
)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx

def call(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
tokens = ops.cast(tokens, "int64")
outputs = self.transformer(
tokens,
intermediate_output=self.layer_idx,
final_layer_norm_intermediate=self.layer_norm_hidden_state,
)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
z = outputs[1]
pooled_output = None
if len(outputs) >= 3:
if (
not self.return_projected_pooled
and len(outputs) >= 4
and outputs[3] is not None
):
pooled_output = ops.cast(outputs[3], "float32")
elif outputs[2] is not None:
pooled_output = ops.cast(outputs[2], "float32")
return ops.cast(z, "float32"), pooled_output
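
Aside: a hedged instantiation sketch of the "hidden" mode above; the hyperparameters are CLIP-L-like values chosen for illustration, and only the special-token defaults come from this file:

import numpy as np

clip_l = SDClipModel(
    num_layers=12,
    hidden_dim=768,
    num_heads=12,
    intermediate_size=3072,
    layer="hidden",
    layer_idx=-2,  # penultimate hidden state; wrapped by set_clip_options
)
tokens = np.full((1, 77), 49407, dtype="int64")  # pad with the end id
tokens[0, 0] = 49406                             # start token
z, pooled = clip_l(tokens)  # z: penultimate hidden states, float32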


class SDXLClipG(SDClipModel):
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""

def __init__(
self,
num_layers,
hidden_dim,
num_heads,
intermediate_size,
intermediate_activation="gelu",
layer="penultimate",
layer_idx=None,
**kwargs,
):
if layer == "penultimate":
layer = "hidden"
layer_idx = -2
super().__init__(
num_layers=num_layers,
hidden_dim=hidden_dim,
num_heads=num_heads,
intermediate_size=intermediate_size,
intermediate_activation=intermediate_activation,
layer=layer,
layer_idx=layer_idx,
special_tokens={"start": 49406, "end": 49407, "pad": 0},
layer_norm_hidden_state=False,
)
2 changes: 1 addition & 1 deletion keras_cv/src/models/stable_diffusion_v3/clip_utils_test.py
@@ -95,4 +95,4 @@ def test_sd3_tokenizer(self):
self.assertEqual(out_keras["g"][0][:4], expected_g_tokens)
self.assertEqual(out_keras["l"][0][:4], expected_l_tokens)
# TODO - uncomment to test T5XXLTokenizer after it is added
# self.assertEqual(out_keras["t5xxl"][0][:4], expected_t5xxl_tokens)
# self.assertEqual(out_keras["t5xxl"][0][:4], expected_t5xxl_tokens)
4 changes: 2 additions & 2 deletions keras_cv/src/models/stable_diffusion_v3/embedding.py
@@ -1,6 +1,6 @@
from keras_nlp.layers import RotaryEmbedding

-from keras_cv.backend import keras
+from keras_cv.src.backend import keras


class TimestepEmbedding(keras.layers.Layer):
@@ -70,4 +70,4 @@ def __init__(self, hidden_dim, **kwargs):
def get_config(self):
config = super().get_config()
config.update({"hidden_dim": self.hidden_dim})
-        return config
\ No newline at end of file
+        return config