Skip to content

Commit

Permalink
Merge branch 'main' into lora_modules
Browse files Browse the repository at this point in the history
  • Loading branch information
linoytsaban authored Oct 25, 2024
2 parents 29152db + 435f6b7 commit 7276da7
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 25 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
- name: Environment
run: |
python utils/print_env.py
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
- name: PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
Expand Down Expand Up @@ -184,7 +184,7 @@ jobs:
run: |
python utils/print_env.py
- name: Run slow Flax TPU tests
- name: Run Flax TPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
Expand Down Expand Up @@ -232,7 +232,7 @@ jobs:
run: |
python utils/print_env.py
- name: Run slow ONNXRuntime CUDA tests
- name: Run ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
Expand Down
19 changes: 14 additions & 5 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
if "large" in base_model:
model_variant = "SD3.5-Large"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
else:
model_variant = "SD3"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
variant_tags = ["sd3", "sd3-diffusers"]

widget_dict = []
if images is not None:
for i, image in enumerate(images):
Expand All @@ -95,7 +104,7 @@ def save_model_card(
)

model_description = f"""
# SD3 DreamBooth LoRA - {repo_id}
# {model_variant} DreamBooth LoRA - {repo_id}
<Gallery />
Expand All @@ -120,7 +129,7 @@ def save_model_card(
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
Expand All @@ -135,7 +144,7 @@ def save_model_card(
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
Please adhere to the licensing terms as described [here]({license_url}).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
Expand All @@ -151,11 +160,11 @@ def save_model_card(
"diffusers-training",
"diffusers",
"lora",
"sd3",
"sd3-diffusers",
"template:sd-lora",
]

tags += variant_tags

model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))

Expand Down
16 changes: 12 additions & 4 deletions examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
if "large" in base_model:
model_variant = "SD3.5-Large"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
else:
model_variant = "SD3"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
variant_tags = ["sd3", "sd3-diffusers"]

widget_dict = []
if images is not None:
for i, image in enumerate(images):
Expand All @@ -86,7 +95,7 @@ def save_model_card(
)

model_description = f"""
# SD3 DreamBooth - {repo_id}
# {model_variant} DreamBooth - {repo_id}
<Gallery />
Expand All @@ -113,7 +122,7 @@ def save_model_card(
## License
Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
Please adhere to the licensing terms as described `[here]({license_url})`.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
Expand All @@ -128,10 +137,9 @@ def save_model_card(
"text-to-image",
"diffusers-training",
"diffusers",
"sd3",
"sd3-diffusers",
"template:sd-lora",
]
tags += variant_tags

model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
Expand Down
59 changes: 56 additions & 3 deletions src/diffusers/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s

class SDXLCFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""

tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
Expand All @@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids

return callback_kwargs


class SDXLControlnetCFGCutoffCallback(PipelineCallback):
    """
    Classifier-free guidance (CFG) cutoff callback for the ControlNet Stable Diffusion XL pipelines.

    The cutoff step is `cutoff_step_index` when it is set; otherwise it is computed as
    `int(pipeline.num_timesteps * cutoff_step_ratio)`. On the step whose index equals the cutoff step, the callback
    keeps only the last (conditional) entry of the four tensors named in `tensor_inputs` and disables the CFG.

    Note: This callback mutates the pipeline by setting its `_guidance_scale` attribute to 0.0 at the cutoff step.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "image",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Trim the callback tensors to their conditional entry and turn CFG off at the cutoff step.

        Args:
            pipeline: The running pipeline; its `num_timesteps` is read and its `_guidance_scale` is overwritten.
            step_index: Index of the current denoising step.
            timestep: Current timestep (not used by this callback).
            callback_kwargs: Mapping that holds the tensors named in `tensor_inputs`.

        Returns:
            The same `callback_kwargs` mapping, modified in place only on the cutoff step.
        """
        ratio = self.config.cutoff_step_ratio
        index = self.config.cutoff_step_index

        # An explicit step index takes precedence over the ratio-derived one.
        if index is not None:
            cutoff_step = index
        else:
            cutoff_step = int(pipeline.num_timesteps * ratio)

        # Every step other than the cutoff step passes through untouched.
        if step_index != cutoff_step:
            return callback_kwargs

        # Slot order: prompt embeddings, pooled text embeddings, added time ids, ControlNet image.
        # "[-1:]" keeps the last entry, which denotes the conditional half of each tensor.
        conditional_only = {}
        for slot in range(4):
            name = self.tensor_inputs[slot]
            conditional_only[name] = callback_kwargs[name][-1:]

        pipeline._guidance_scale = 0.0

        # Write back only after every tensor has been read and trimmed successfully.
        for name, trimmed in conditional_only.items():
            callback_kwargs[name] = trimmed

        return callback_kwargs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
"image",
]

def __init__(
Expand Down Expand Up @@ -1540,6 +1541,7 @@ def __call__(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
image = callback_outputs.pop("image", image)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,12 @@ def __call__(

timestep = t.expand(latents.shape[0]).to(latents.dtype)

guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds

guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

if isinstance(controlnet_keep[i], list):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,8 @@ def __call__(
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,8 @@ def __call__(
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -921,8 +921,8 @@ def __call__(
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
Expand Down
4 changes: 3 additions & 1 deletion tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
Expand Down Expand Up @@ -165,9 +166,10 @@ def test_modify_padding_mode(self):


@slow
@nightly
@require_torch_gpu
@require_peft_backend
# @unittest.skip("We cannot run inference on this model with the current CI hardware")
@unittest.skip("We cannot run inference on this model with the current CI hardware")
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
Expand Down
2 changes: 2 additions & 0 deletions tests/lora/test_lora_layers_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
load_image,
nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
Expand Down Expand Up @@ -207,6 +208,7 @@ def test_integration_move_lora_dora_cpu(self):


@slow
@nightly
@require_torch_gpu
@require_peft_backend
class LoraIntegrationTests(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions tests/lora/test_lora_layers_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def tearDown(self):


@slow
@nightly
@require_torch_gpu
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
Expand Down

0 comments on commit 7276da7

Please sign in to comment.