diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py
index 239f49b55..d3194a581 100644
--- a/keras_hub/src/models/image_to_image.py
+++ b/keras_hub/src/models/image_to_image.py
@@ -234,7 +234,7 @@ def normalize_images(x):
                 input_is_scalar = True
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
diff --git a/keras_hub/src/models/inpaint.py b/keras_hub/src/models/inpaint.py
index 1c475f5b8..40bcc7ad1 100644
--- a/keras_hub/src/models/inpaint.py
+++ b/keras_hub/src/models/inpaint.py
@@ -202,7 +202,7 @@ def normalize(x):
                 input_is_scalar = True
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -240,7 +240,7 @@ def normalize(x):
             x = ops.cast(x, "float32")
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -303,7 +303,7 @@ def normalize_images(x):
                 input_is_scalar = True
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -323,7 +323,7 @@ def normalize_masks(x):
             x = ops.cast(x, "float32")
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -384,8 +384,8 @@ def generate(

         Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"`
         keys. `"images"` are reference images within a value range of
-        `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and
-        `self.backbone.width`, then encoded into latent space by the VAE
+        `[-1.0, 1.0]`, which will be resized to height and width from
+        `self.backbone.image_shape`, then encoded into latent space by the VAE
         encoder. `"masks"` are mask images with a boolean dtype, where white
         pixels are repainted while black pixels are preserved. `"prompts"` are
         strings that will be tokenized and encoded by the text encoder.
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
index 485340fbd..4dd3e4403 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -215,8 +215,8 @@ class StableDiffusion3Backbone(Backbone):
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
             `3.0`.
-        height: optional int. The output height of the image.
-        width: optional int. The output width of the image.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape @@ -270,23 +270,21 @@ def __init__( output_channels=3, num_train_timesteps=1000, shift=3.0, - height=None, - width=None, + image_shape=(1024, 1024, 3), data_format=None, dtype=None, **kwargs, ): - height = int(height or 1024) - width = int(width or 1024) - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - "`height` and `width` must be divisible by 8. " - f"Received: height={height}, width={width}" - ) data_format = standardize_data_format(data_format) if data_format != "channels_last": raise NotImplementedError - image_shape = (height, width, int(vae.input_channels)) + height = image_shape[0] + width = image_shape[1] + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + "height and width in `image_shape` must be divisible by 8. " + f"Received: image_shape={image_shape}" + ) latent_shape = (height // 8, width // 8, int(latent_channels)) context_shape = (None, 4096 if t5 is None else t5.hidden_dim) pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,) @@ -452,8 +450,7 @@ def __init__( self.output_channels = output_channels self.num_train_timesteps = num_train_timesteps self.shift = shift - self.height = height - self.width = width + self.image_shape = image_shape @property def latent_shape(self): @@ -585,8 +582,7 @@ def get_config(self): "output_channels": self.output_channels, "num_train_timesteps": self.num_train_timesteps, "shift": self.shift, - "height": self.height, - "width": self.width, + "image_shape": self.image_shape, } ) return config diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py index 37723b0b5..77415a6ee 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py @@ -11,7 +11,8 @@ class StableDiffusion3BackboneTest(TestCase): def setUp(self): - height, width = 64, 64 + image_shape = (64, 64, 3) + height, width = image_shape[0], image_shape[1] vae = VAEBackbone( [32, 32, 32, 32], [1, 1, 1, 1], @@ -36,8 +37,7 @@ def setUp(self): "vae": vae, "clip_l": clip_l, "clip_g": clip_g, - "height": height, - "width": width, + "image_shape": image_shape, } self.input_data = { "images": ops.ones((2, height, width, 3)), @@ -82,7 +82,6 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, init_kwargs={ - "height": self.init_kwargs["height"], - "width": self.init_kwargs["width"], + "image_shape": self.init_kwargs["image_shape"], }, ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py index 29c939e75..285ba834b 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -27,7 +27,7 @@ class StableDiffusion3ImageToImage(ImageToImage): Use `generate()` to do image generation. 
     ```python
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
index 7debb6963..8fa5b167a 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
index 90c11a723..8d5ed7c6a 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -29,7 +29,7 @@ class StableDiffusion3Inpaint(Inpaint):
     reference_image = np.ones((1024, 1024, 3), dtype="float32")
     reference_mask = np.ones((1024, 1024), dtype="float32")
     inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     inpaint.generate(
         reference_image,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
index faade4b1e..5e8ddd32c 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
index 88d9b3d4b..a7756fc64 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -13,6 +13,6 @@
             "path": "stable_diffusion_3",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/2",
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
     }
 }
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
index 623088f23..739c6f465 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
     Use `generate()` to do image generation.
     ```python
     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
index bbbb55b27..69d30de83 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
        )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 261d1eda5..52aad373a 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -563,10 +563,8 @@ def get_backbone_kwargs(self, **kwargs):
         backbone_kwargs["dtype"] = kwargs.pop("dtype", None)

-        # Forward `height` and `width` to backbone when using `TextToImage`.
-        if "height" in kwargs:
-            backbone_kwargs["height"] = kwargs.pop("height", None)
-        if "width" in kwargs:
-            backbone_kwargs["width"] = kwargs.pop("width", None)
+        # Forward `image_shape` to backbone when using `TextToImage`.
+        if "image_shape" in kwargs:
+            backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)

         return backbone_kwargs, kwargs
diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py
index 51b9082cc..38e19cf10 100644
--- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py
@@ -113,8 +113,7 @@ def convert_model(preset, height, width):
         vae,
         clip_l,
         clip_g,
-        height=height,
-        width=width,
+        image_shape=(height, width, 3),
         name="stable_diffusion_3_backbone",
     )
     return backbone
@@ -532,8 +531,7 @@ def main(_):
     keras_preprocessor.save_to_preset(preset)

     # Set the image size to 1024, the same as in huggingface/diffusers.
-    keras_model.height = 1024
-    keras_model.width = 1024
+    keras_model.image_shape = (1024, 1024, 3)
     keras_model.save_to_preset(preset)

     print(f"🏁 Preset saved to ./{preset}.")
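Note for downstream callers: the migration is mechanical, since every `height=`/`width=` pair collapses into a single `image_shape` tuple. A minimal before/after sketch, assuming a `keras_hub` build that includes this change (the preset name and prompt are taken from the docstrings above):

```python
import keras_hub

# Before this change, height and width were separate keyword arguments:
#   from_preset("stable_diffusion_3_medium", height=512, width=512)
# After it, a single (height, width, channels) tuple is passed instead.
# Height and width must each be divisible by 8, or the backbone raises a
# ValueError at construction time.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
)
```

The same substitution applies to `StableDiffusion3ImageToImage` and `StableDiffusion3Inpaint`, whose `generate()` inputs are resized to the height and width from `self.backbone.image_shape`.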