Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clipUtils #2473

Merged
merged 17 commits into from
Aug 13, 2024
42 changes: 25 additions & 17 deletions keras_cv/src/layers/vit_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import math

import tensorflow as tf
from keras import layers
from keras import ops

from keras_cv.api_export import keras_cv_export
from keras_cv.src.api_export import keras_cv_export
from keras_cv.src.backend import keras
from keras_cv.src.backend import ops


@keras_cv_export("keras_cv.layers.PatchingAndEmbedding")
class PatchingAndEmbedding(layers.Layer):
class PatchingAndEmbedding(keras.layers.Layer):
"""
Layer to patchify images, prepend a class token, positionally embed and
create a projection of patches for Vision Transformers
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, project_dim, patch_size, padding="VALID", **kwargs):
f"Padding must be either 'SAME' or 'VALID', but {padding} was "
"passed."
)
self.projection = layers.Conv2D(
self.projection = keras.layers.Conv2D(
filters=self.project_dim,
kernel_size=self.patch_size,
strides=self.patch_size,
Expand All @@ -88,7 +88,7 @@ def build(self, input_shape):
* input_shape[2]
// self.patch_size
)
self.position_embedding = layers.Embedding(
self.position_embedding = keras.layers.Embedding(
input_dim=self.num_patches + 1, output_dim=self.project_dim
)

Expand Down Expand Up @@ -225,32 +225,40 @@ def get_config(self):


@keras_cv_export("keras_cv.layers.Unpatching")
class Unpatching(layers.Layer):
class Unpatching(keras.layers.Layer):
"""
Layer to unpatchify image data.

This layer expects patches sorted by column and reorganizes the patches such that they will each be positioned as a
2D shape with some number of channels.
This layer expects patches sorted by column and reorganizes the patches
such that they will each be positioned as a 2D shape with some
number of channels.

Any necessary padding or truncation will be applied to reach the target shape.
Any necessary padding or truncation will be applied to reach the target
shape.

Args:
target_shape: The target image shape after unpatching, of form [height, width]
target_shape: The target image shape after unpatching,
of form [height, width]
"""

def __init__(self, target_shape, **kwargs):
    """Create the unpatching layer.

    Args:
        target_shape: The target image shape after unpatching,
            of form [height, width].
        **kwargs: Optional keyword arguments forwarded unchanged to the
            base layer constructor.
    """
    # The base layer must be initialised before any attribute is assigned
    # on `self`; the original constructor skipped this call entirely.
    super().__init__(**kwargs)
    self.target_shape = target_shape

def call(self, patches):
"""
Reconstructs an unpatched image from the sequence of column sequence patches.
Reconstructs an unpatched image from a column-ordered sequence of
patches.

If there are insufficient patches to construct the image of requested dimensions, additional zero-patches will
be appended. If excessive patches are provided, unnecessary patches will be truncated from the end.
If there are insufficient patches to construct the image of requested
dimensions, additional zero-patches will
be appended. If excessive patches are provided, unnecessary patches
will be truncated from the end.

Args:
patches: Patches of images in column sequence (i.e. each patch is vertically oriented relative to the
previous patch). Expected shape of [batch_size, patch_num, patch_height, patch_width, channels].
patches: Patches of images in column sequence (i.e. each patch
is vertically oriented relative to the
previous patch). Expected shape of [batch_size, patch_num,
patch_height, patch_width, channels].

Returns:
Unpatched image: Image reconstructed from the patches,
Expand Down Expand Up @@ -284,4 +292,4 @@ def call(self, patches):
else:
corrected_patches = patches[:, :required_patches]

return ops.split(corrected_patches, patches_per_column, axis=1)
return ops.split(corrected_patches, patches_per_column, axis=1)
22 changes: 18 additions & 4 deletions keras_cv/src/models/stable_diffusion_v3/MMDit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from keras_cv.backend import keras
from keras_cv.layers.vit_layers import PatchingAndEmbedding
from keras_cv.models.stable_diffusion.v3 import embedding
from keras_cv.models.stable_diffusion.v3.MMDiT_block import MMDiTBlock
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.src.backend import keras
from keras_cv.src.layers.vit_layers import PatchingAndEmbedding
from keras_cv.src.models.stable_diffusion_v3 import embedding
from keras_cv.src.models.stable_diffusion_v3.MMDit_block import MMDiTBlock


class MMDiT(keras.layers.Layer):
Expand Down
21 changes: 18 additions & 3 deletions keras_cv/src/models/stable_diffusion_v3/MMDit_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
import keras
from keras_cv.backend import ops
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.src.backend import keras
from keras_cv.src.backend import ops


class MMDiTSelfAttention(keras.layers.Layer):
Expand All @@ -17,7 +31,8 @@ def __init__(
self.cdense = keras.layers.Dense(key_dim)

if normalization_mode == "rms_normalization":
# TODO(varuns1997): Re-Implement RMSNormalization for Keras 2 Compatibility
# TODO(varuns1997): Re-Implement RMSNormalization
# for Keras 2 Compatibility
self.query_normalization = keras.layers.LayerNormalization(
rms_scaling=True
)
Expand Down
13 changes: 13 additions & 0 deletions keras_cv/src/models/stable_diffusion_v3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading