Merge pull request #119 from invoke-ai/precision-settings
Add support for fp16 and bf16 training without mixed_precision
RyanJDick authored May 1, 2024
2 parents 0790df0 + d376cff commit f00f912
Showing 27 changed files with 226 additions and 148 deletions.
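For orientation, a minimal sketch of how the reworked precision settings are expected to combine (field names are taken from the config diffs below; the pairings are illustrative, not exhaustive):

```python
# Hedged sketch: representative precision setups under the new scheme.
# `weight_dtype` casts all model weights; `mixed_precision` is passed to Hugging Face
# Accelerate and additionally keeps trainable params in float32.

full_fp32 = {"weight_dtype": "float32", "mixed_precision": "no"}        # most stable, most VRAM
full_bf16 = {"weight_dtype": "bfloat16", "mixed_precision": "no"}       # new in this PR: train directly in bf16
classic_mixed = {"weight_dtype": "bfloat16", "mixed_precision": "bf16"} # bf16 fixed weights, fp32 trainable params
```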
12 changes: 12 additions & 0 deletions src/invoke_training/_shared/accelerator/accelerator_utils.py
@@ -1,5 +1,6 @@
import logging
import os
from typing import Literal

import datasets
import diffusers
@@ -89,3 +90,14 @@ def get_mixed_precision_dtype(accelerator: Accelerator):
    else:
        raise NotImplementedError(f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.")
    return weight_dtype


def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float32"]) -> torch.dtype:
    if dtype_str == "float16":
        return torch.float16
    elif dtype_str == "bfloat16":
        return torch.bfloat16
    elif dtype_str == "float32":
        return torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
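A quick usage sketch of the new helper (module path matches the file header above):

```python
import torch

from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str

assert get_dtype_from_str("bfloat16") is torch.bfloat16
assert get_dtype_from_str("float16") is torch.float16
assert get_dtype_from_str("float32") is torch.float32
# Any other string raises ValueError.
```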
src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
@@ -2,6 +2,7 @@
import typing
from enum import Enum

import torch
from diffusers import (
AutoencoderKL,
DDPMScheduler,
@@ -56,6 +57,7 @@ def load_models_sd(
    model_name_or_path: str,
    hf_variant: str | None = None,
    base_embeddings: dict[str, str] = None,
    dtype: torch.dtype | None = None,
) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]:
    """Load all models required for training from disk, transfer them to the
    target training device and cast their weight dtypes.
@@ -88,6 +90,11 @@ def load_models_sd(
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    if dtype is not None:
        text_encoder = text_encoder.to(dtype=dtype)
        vae = vae.to(dtype=dtype)
        unet = unet.to(dtype=dtype)

    # Put models in 'eval' mode.
    text_encoder.eval()
    vae.eval()
@@ -101,6 +108,7 @@ def load_models_sdxl(
    hf_variant: str | None = None,
    vae_model: str | None = None,
    base_embeddings: dict[str, str] = None,
    dtype: torch.dtype | None = None,
) -> tuple[
    CLIPTokenizer,
    CLIPTokenizer,
@@ -159,6 +167,12 @@ def load_models_sdxl(
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    if dtype is not None:
        text_encoder_1 = text_encoder_1.to(dtype=dtype)
        text_encoder_2 = text_encoder_2.to(dtype=dtype)
        vae = vae.to(dtype=dtype)
        unet = unet.to(dtype=dtype)

    # Put models in 'eval' mode.
    text_encoder_1.eval()
    text_encoder_2.eval()
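A hedged sketch of calling the extended loader. The `dtype` parameter is the addition from this diff; the import path and the example model name are assumptions, not shown in this view:

```python
import torch

# Assumed module path and example model; adjust to your setup.
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd

tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
    model_name_or_path="runwayml/stable-diffusion-v1-5",  # illustrative model name
    hf_variant="fp16",
    dtype=torch.bfloat16,  # new: cast text_encoder, vae, and unet to bf16; None keeps the loaded dtype
)
```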
28 changes: 18 additions & 10 deletions src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py
@@ -150,20 +150,28 @@ class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):
Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
"""

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified precision. The
trainable parameters are always kept in float32 precision to avoid issues with numerical stability.
weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
"""All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
result in faster training, but are more prone to issues with numerical stability.
Recommendations:
- `"no"`: Use this mode if you have plenty of VRAM available.
- `"bf16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"fp16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
- `"fp8"`: You are likely to run into numerical stability issues with this mode. Only use this mode if you know what you are doing and are willing to work through some issues.
- `"float32"`: Use this mode if you have plenty of VRAM available.
- `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
See also [`mixed_precision`][invoke_training.pipelines._experimental.sd_dpo_lora.config.SdDirectPreferenceOptimizationLoraConfig.mixed_precision].
""" # noqa: E501

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
trainable parameters are kept in float32 precision to avoid issues with numerical stability.
This value is passed to Hugging Face Accelerate. See `accelerate.Accelerator` for more details.
This value is passed to Hugging Face Accelerate. See
[`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
for more details.
""" # noqa: E501

xformers: bool = False
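A sketch of setting the new field on this config. Class, module, and field names come from the diff above; the remaining required fields (data loader, optimizer, etc.) are omitted and would be needed for a real run:

```python
from invoke_training.pipelines._experimental.sd_dpo_lora.config import (
    SdDirectPreferenceOptimizationLoraConfig,
)

# Illustrative only: other required fields are omitted.
config = SdDirectPreferenceOptimizationLoraConfig(
    model="runwayml/stable-diffusion-v1-5",  # example value
    weight_dtype="bfloat16",   # cast all weights (trainable and fixed) to bf16
    mixed_precision="no",      # train directly in bf16, with no float32 copy of the trainable params
)
```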
20 changes: 12 additions & 8 deletions src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py
@@ -19,7 +19,7 @@
from transformers import CLIPTextModel, CLIPTokenizer

from invoke_training._shared.accelerator.accelerator_utils import (
get_mixed_precision_dtype,
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
@@ -225,11 +225,14 @@ def train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: list[Pipe
with open(os.path.join(out_dir, "config.json"), "w") as f:
json.dump(config.dict(), f, indent=2, default=str)

weight_dtype = get_mixed_precision_dtype(accelerator)
weight_dtype = get_dtype_from_str(config.weight_dtype)

logger.info("Loading models.")
tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
model_name_or_path=config.model, hf_variant=config.hf_variant, base_embeddings=config.base_embeddings
model_name_or_path=config.model,
hf_variant=config.hf_variant,
base_embeddings=config.base_embeddings,
dtype=weight_dtype,
)
ref_text_encoder = copy.deepcopy(text_encoder)
ref_unet = copy.deepcopy(unet)
@@ -352,11 +355,12 @@ def prep_peft_model(model, lr: float | None = None):
training_unet = prep_peft_model(unet, config.unet_learning_rate)
training_text_encoder = prep_peft_model(text_encoder, config.text_encoder_learning_rate)

    # Make sure all trainable params are in float32.
    for trainable_model in all_trainable_models:
        for param in trainable_model.parameters():
            if param.requires_grad:
                param.data = param.to(torch.float32)
    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

if config.gradient_checkpointing:
# We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
24 changes: 15 additions & 9 deletions src/invoke_training/pipelines/stable_diffusion/lora/config.py
@@ -110,18 +110,24 @@ class SdLoraConfig(BasePipelineConfig):
Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
"""

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified precision. The
trainable parameters are always kept in float32 precision to avoid issues with numerical stability.
weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
"""All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
result in faster training, but are more prone to issues with numerical stability.
Recommendations:
- `"no"`: Use this mode if you have plenty of VRAM available.
- `"bf16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"fp16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
- `"fp8"`: You are likely to run into numerical stability issues with this mode. Only use this mode if you know what you are doing and are willing to work through some issues.
- `"float32"`: Use this mode if you have plenty of VRAM available.
- `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig.mixed_precision].
""" # noqa: E501

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
trainable parameters are kept in float32 precision to avoid issues with numerical stability.
This value is passed to Hugging Face Accelerate. See
[`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
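Following the recommendations in the docstring above, a small helper (not part of this PR) that picks a `weight_dtype` value using standard PyTorch capability checks:

```python
import torch


def pick_weight_dtype(plenty_of_vram: bool = False) -> str:
    """Illustrative helper: map the docstring's recommendations to a `weight_dtype` string."""
    if plenty_of_vram:
        return "float32"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "float16"
```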
20 changes: 12 additions & 8 deletions src/invoke_training/pipelines/stable_diffusion/lora/train.py
@@ -19,7 +19,7 @@
from transformers import CLIPTextModel, CLIPTokenizer

from invoke_training._shared.accelerator.accelerator_utils import (
get_mixed_precision_dtype,
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
@@ -288,11 +288,14 @@ def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | None = None
with open(os.path.join(out_dir, "config.json"), "w") as f:
json.dump(config.dict(), f, indent=2, default=str)

weight_dtype = get_mixed_precision_dtype(accelerator)
weight_dtype = get_dtype_from_str(config.weight_dtype)

logger.info("Loading models.")
tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
model_name_or_path=config.model, hf_variant=config.hf_variant, base_embeddings=config.base_embeddings
model_name_or_path=config.model,
hf_variant=config.hf_variant,
base_embeddings=config.base_embeddings,
dtype=weight_dtype,
)

if config.xformers:
@@ -398,11 +401,12 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
)
text_encoder = inject_lora_layers(text_encoder, text_encoder_lora_config, lr=config.text_encoder_learning_rate)

    # Make sure all trainable params are in float32.
    for trainable_model in all_trainable_models:
        for param in trainable_model.parameters():
            if param.requires_grad:
                param.data = param.to(torch.float32)
    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

if config.gradient_checkpointing:
# We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py
@@ -116,18 +116,24 @@ class SdTextualInversionConfig(BasePipelineConfig):
`train_batch_size` when training with limited VRAM.
"""

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified precision. The
trainable parameters are always kept in float32 precision to avoid issues with numerical stability.
weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
"""All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
result in faster training, but are more prone to issues with numerical stability.
Recommendations:
- `"no"`: Use this mode if you have plenty of VRAM available.
- `"bf16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"fp16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
- `"fp8"`: You are likely to run into numerical stability issues with this mode. Only use this mode if you know what you are doing and are willing to work through some issues.
- `"float32"`: Use this mode if you have plenty of VRAM available.
- `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig.mixed_precision].
""" # noqa: E501

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
trainable parameters are kept in float32 precision to avoid issues with numerical stability.
This value is passed to Hugging Face Accelerate. See
[`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py
@@ -13,7 +13,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer

from invoke_training._shared.accelerator.accelerator_utils import (
get_mixed_precision_dtype,
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
@@ -62,7 +62,7 @@ def _save_ti_embeddings(
.get_input_embeddings()
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
)
learned_embeds_dict = {"emb_params": learned_embeds.detach().cpu()}
learned_embeds_dict = {"emb_params": learned_embeds.detach().cpu().to(torch.float32)}

save_state_dict(learned_embeds_dict, save_path)

@@ -161,11 +161,11 @@ def train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] |
with open(os.path.join(out_dir, "config.json"), "w") as f:
json.dump(config.dict(), f, indent=2, default=str)

weight_dtype = get_mixed_precision_dtype(accelerator)
weight_dtype = get_dtype_from_str(config.weight_dtype)

logger.info("Loading models.")
tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
model_name_or_path=config.model, hf_variant=config.hf_variant
model_name_or_path=config.model, hf_variant=config.hf_variant, dtype=weight_dtype
)

placeholder_tokens, placeholder_token_ids = _initialize_placeholder_tokens(
@@ -226,9 +226,8 @@ def train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] |
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    # For mixed precision training, we cast all non-trainable weights (unet, vae) to half-precision as these weights are
    # only used for inference, keeping weights in full precision is not required.
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

# Initialize the optimizer to only optimize the token embeddings.
optimizer = initialize_optimizer(config.optimizer, text_encoder.get_input_embeddings().parameters())
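The `.to(torch.float32)` added in `_save_ti_embeddings` above means embeddings trained under a lower-precision `weight_dtype` are still written out as float32. A standalone sketch of the same idea, using `safetensors` directly (the pipeline's own `save_state_dict` helper is not shown in this diff):

```python
import torch
from safetensors.torch import save_file

# Stand-in for embeddings trained with weight_dtype="bfloat16".
learned_embeds = torch.randn(2, 768, dtype=torch.bfloat16)

# Cast back to float32 before saving so downstream consumers see a consistent dtype.
save_file({"emb_params": learned_embeds.detach().cpu().to(torch.float32)}, "learned_embeds.safetensors")
```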
24 changes: 15 additions & 9 deletions src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py
@@ -110,18 +110,24 @@ class SdxlLoraConfig(BasePipelineConfig):
Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
"""

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified precision. The
trainable parameters are always kept in float32 precision to avoid issues with numerical stability.
weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
"""All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
result in faster training, but are more prone to issues with numerical stability.
Recommendations:
- `"no"`: Use this mode if you have plenty of VRAM available.
- `"bf16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"fp16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
- `"fp8"`: You are likely to run into numerical stability issues with this mode. Only use this mode if you know what you are doing and are willing to work through some issues.
- `"float32"`: Use this mode if you have plenty of VRAM available.
- `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig.mixed_precision].
""" # noqa: E501

mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
trainable parameters are kept in float32 precision to avoid issues with numerical stability.
This value is passed to Hugging Face Accelerate. See
[`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
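To make the VRAM claim in the docstring above concrete, a small PyTorch illustration (not from this PR) comparing the memory footprint of the same weights at float32 vs bfloat16:

```python
import torch


def weight_bytes(module: torch.nn.Module) -> int:
    # Parameter storage only (ignores activations, gradients, and optimizer state).
    return sum(p.numel() * p.element_size() for p in module.parameters())


layer = torch.nn.Linear(4096, 4096)
print(weight_bytes(layer))                     # float32: ~67 MB
print(weight_bytes(layer.to(torch.bfloat16)))  # bfloat16: ~34 MB, half the weight memory
```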
16 changes: 9 additions & 7 deletions src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py
@@ -20,7 +20,7 @@
from transformers import CLIPPreTrainedModel, CLIPTextModel, PreTrainedTokenizer

from invoke_training._shared.accelerator.accelerator_utils import (
get_mixed_precision_dtype,
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
@@ -356,14 +356,15 @@ def train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | None = No
with open(os.path.join(out_dir, "config.json"), "w") as f:
json.dump(config.dict(), f, indent=2, default=str)

weight_dtype = get_mixed_precision_dtype(accelerator)
weight_dtype = get_dtype_from_str(config.weight_dtype)

logger.info("Loading models.")
tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(
model_name_or_path=config.model,
hf_variant=config.hf_variant,
vae_model=config.vae_model,
base_embeddings=config.base_embeddings,
dtype=weight_dtype,
)

if config.xformers:
@@ -479,11 +480,12 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate
)

    # Make sure all trainable params are in float32.
    for trainable_model in all_trainable_models:
        for param in trainable_model.parameters():
            if param.requires_grad:
                param.data = param.to(torch.float32)
    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

if config.gradient_checkpointing:
# We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
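The same guard appears in each train.py touched by this PR. A minimal restatement of the pattern as a standalone helper (an illustrative refactor, not code from this diff): trainable params are promoted to float32 only when mixed precision is enabled, otherwise they stay in the configured `weight_dtype`:

```python
import torch


def cast_trainable_params_to_fp32(models: list[torch.nn.Module]) -> None:
    # Under mixed precision, trainable params are kept in float32 for numerical
    # stability; fixed weights remain in the configured weight_dtype.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)
```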