(experimental) Allow EMA on LoRA/Lycoris networks
bghira committed Nov 17, 2024
1 parent 9cdf13c commit fd4de68
Showing 7 changed files with 380 additions and 140 deletions.
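For context on what "EMA on LoRA/Lycoris networks" means in practice, here is a minimal sketch (illustrative only, not code from this commit) of the exponential-moving-average update that EMAModel maintains over the trainable adapter parameters:

import torch

def ema_update(shadow_params, live_params, decay=0.999):
    # shadow <- decay * shadow + (1 - decay) * live, applied in place.
    with torch.no_grad():
        for shadow, live in zip(shadow_params, live_params):
            shadow.mul_(decay).add_(live.detach(), alpha=1.0 - decay)

The diff below extends helpers/training/ema.py so this shadow copy can be saved, loaded, quantised, and treated like a module, and relaxes the argument check that previously blocked --use_ema for LoRA runs.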
14 changes: 5 additions & 9 deletions helpers/configuration/cmd_args.py
@@ -1338,7 +1338,7 @@ def get_argument_parser():
help=(
"Validations must be enabled for model evaluation to function. The default is to use no evaluator,"
" and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations."
)
),
)
parser.add_argument(
"--pretrained_evaluation_model_name_or_path",
@@ -1348,7 +1348,7 @@ def get_argument_parser():
"Optionally provide a custom model to use for ViT evaluations."
" The default is currently clip-vit-large-patch14-336, allowing for lower patch sizes (greater accuracy)"
" and an input resolution of 336x336."
)
),
)
parser.add_argument(
"--validation_on_startup",
@@ -2351,13 +2351,9 @@ def parse_cmdline_args(input_args=None):
)
args.gradient_precision = "fp32"

if args.use_ema:
if args.model_family == "sd3":
raise ValueError(
"Using EMA is not currently supported for Stable Diffusion 3 training."
)
if "lora" in args.model_type:
raise ValueError("Using EMA is not currently supported for LoRA training.")
# if args.use_ema:
# if "lora" in args.model_type:
# raise ValueError("Using EMA is not currently supported for LoRA training.")
args.logging_dir = os.path.join(args.output_dir, args.logging_dir)
args.accelerator_project_config = ProjectConfiguration(
project_dir=args.output_dir, logging_dir=args.logging_dir
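For reference, the validation this hunk disables, reconstructed from the removed lines above (the SD3 restriction is dropped outright; the LoRA restriction survives only as a comment):

# Pre-commit behaviour in parse_cmdline_args, per the removed lines:
if args.use_ema:
    if args.model_family == "sd3":
        raise ValueError(
            "Using EMA is not currently supported for Stable Diffusion 3 training."
        )
    if "lora" in args.model_type:
        raise ValueError("Using EMA is not currently supported for LoRA training.")
# After this commit neither check runs, so --use_ema together with a
# LoRA/Lycoris model_type is no longer rejected at argument-parsing time.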
215 changes: 138 additions & 77 deletions helpers/training/ema.py
@@ -7,9 +7,10 @@
from typing import Any, Dict, Iterable, Optional, Union
from diffusers.utils.deprecation_utils import deprecate
from diffusers.utils import is_transformers_available
from helpers.training.state_tracker import StateTracker

logger = logging.getLogger("EMAModel")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING"))


def should_update_ema(args, step):
@@ -119,6 +120,62 @@ def __init__(
self.model_config = model_config
self.args = args
self.accelerator = accelerator
self.training = True # To emulate nn.Module's training mode

def save_state_dict(self, path: str) -> None:
"""
Save the EMA model's state directly to a file.
Args:
path (str): The file path where the EMA state will be saved.
"""
# if the folder containing the path does not exist, create it
os.makedirs(os.path.dirname(path), exist_ok=True)
# grab state dict
state_dict = self.state_dict()
# save it using torch.save
torch.save(state_dict, path)
logger.info(f"EMA model state saved to {path}")

def load_state_dict(self, path: str) -> None:
"""
Load the EMA model's state from a file and apply it to this instance.
Args:
path (str): The file path from where the EMA state will be loaded.
"""
state_dict = torch.load(path, map_location="cpu", weights_only=True)

# Load metadata
self.decay = state_dict.get("decay", self.decay)
self.min_decay = state_dict.get("min_decay", self.min_decay)
self.optimization_step = state_dict.get(
"optimization_step", self.optimization_step
)
self.update_after_step = state_dict.get(
"update_after_step", self.update_after_step
)
self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
self.power = state_dict.get("power", self.power)

# Load shadow parameters
shadow_params = []
idx = 0
while f"shadow_params.{idx}" in state_dict:
shadow_params.append(state_dict[f"shadow_params.{idx}"])
idx += 1

if len(shadow_params) != len(self.shadow_params):
raise ValueError(
f"Mismatch in number of shadow parameters: expected {len(self.shadow_params)}, "
f"but found {len(shadow_params)} in the state dict."
)

for current_param, loaded_param in zip(self.shadow_params, shadow_params):
current_param.data.copy_(loaded_param.data)

logger.info(f"EMA model state loaded from {path}")

@classmethod
def from_pretrained(cls, path, model_cls) -> "EMAModel":
@@ -176,7 +233,6 @@ def get_decay(self, optimization_step: int = None) -> float:
@torch.no_grad()
def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None):
if not should_update_ema(self.args, global_step):

return

if self.args.ema_device == "cpu" and not self.args.ema_cpu_only:
@@ -290,6 +346,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
)
else:
for s_param, param in zip(self.shadow_params, parameters):
print(f"From shape: {s_param.shape}, to shape: {param.shape}")
param.data.copy_(s_param.to(param.device).data)

def pin_memory(self) -> None:
@@ -307,59 +364,71 @@ def pin_memory(self) -> None:
# This probably won't work, but we'll do it anyway.
self.shadow_params = [p.pin_memory() for p in self.shadow_params]

def to(self, device=None, dtype=None, non_blocking=False) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
def to(self, *args, **kwargs):
for param in self.shadow_params:
param.data = param.data.to(*args, **kwargs)
return self

Args:
device: like `device` argument to `torch.Tensor.to`
"""
# .to() on the tensors handles None correctly
self.shadow_params = [
(
p.to(device=device, dtype=dtype, non_blocking=non_blocking)
if p.is_floating_point()
else p.to(device=device, non_blocking=non_blocking)
)
for p in self.shadow_params
]
def cuda(self, device=None):
return self.to(device="cuda" if device is None else f"cuda:{device}")

def cpu(self):
return self.to(device="cpu")

def state_dict(self) -> dict:
def state_dict(self, destination=None, prefix="", keep_vars=False):
r"""
Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
checkpointing to save the ema state dict.
Returns a dictionary containing a whole state of the EMA model.
"""
# Following PyTorch conventions, references to tensors are returned:
# "returns a reference to the state and not its copy!" -
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
return {
state_dict = {
"decay": self.decay,
"min_decay": self.min_decay,
"optimization_step": self.optimization_step,
"update_after_step": self.update_after_step,
"use_ema_warmup": self.use_ema_warmup,
"inv_gamma": self.inv_gamma,
"power": self.power,
"shadow_params": self.shadow_params,
}
for idx, param in enumerate(self.shadow_params):
state_dict[f"{prefix}shadow_params.{idx}"] = (
param if keep_vars else param.detach()
)
return state_dict

# def load_state_dict(self, state_dict: dict, strict=True) -> None:
# r"""
# Loads the EMA model's state.
# """
# self.decay = state_dict.get("decay", self.decay)
# self.min_decay = state_dict.get("min_decay", self.min_decay)
# self.optimization_step = state_dict.get(
# "optimization_step", self.optimization_step
# )
# self.update_after_step = state_dict.get(
# "update_after_step", self.update_after_step
# )
# self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
# self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
# self.power = state_dict.get("power", self.power)

# # Load shadow parameters
# shadow_params = []
# idx = 0
# while f"shadow_params.{idx}" in state_dict:
# shadow_params.append(state_dict[f"shadow_params.{idx}"])
# idx += 1
# if len(shadow_params) != len(self.shadow_params):
# raise ValueError("Mismatch in number of shadow parameters")
# self.shadow_params = shadow_params

def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Save the current parameters for restoring later.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
validation (or model saving), use this to restore the former parameters.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
Restore the parameters stored with the `store` method.
"""
if self.temp_stored_params is None:
raise RuntimeError(
@@ -378,53 +447,45 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
# Better memory-wise.
self.temp_stored_params = None

def load_state_dict(self, state_dict: dict) -> None:
r"""
Args:
Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
ema state dict.
state_dict (dict): EMA state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = copy.deepcopy(state_dict)
def parameter_count(self) -> int:
return sum(p.numel() for p in self.shadow_params)

self.decay = state_dict.get("decay", self.decay)
if self.decay < 0.0 or self.decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
# Implementing nn.Module methods to emulate its behavior

self.min_decay = state_dict.get("min_decay", self.min_decay)
if not isinstance(self.min_decay, float):
raise ValueError("Invalid min_decay")
def named_children(self):
# No child modules
return iter([])

self.optimization_step = state_dict.get(
"optimization_step", self.optimization_step
)
if not isinstance(self.optimization_step, int):
raise ValueError("Invalid optimization_step")
def children(self):
return iter([])

self.update_after_step = state_dict.get(
"update_after_step", self.update_after_step
)
if not isinstance(self.update_after_step, int):
raise ValueError("Invalid update_after_step")
def modules(self):
yield self

self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
if not isinstance(self.use_ema_warmup, bool):
raise ValueError("Invalid use_ema_warmup")
def named_modules(self, memo=None, prefix=""):
yield prefix, self

self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
if not isinstance(self.inv_gamma, (float, int)):
raise ValueError("Invalid inv_gamma")
def parameters(self, recurse=True):
return iter(self.shadow_params)

self.power = state_dict.get("power", self.power)
if not isinstance(self.power, (float, int)):
raise ValueError("Invalid power")

shadow_params = state_dict.get("shadow_params", None)
if shadow_params is not None:
self.shadow_params = shadow_params
if not isinstance(self.shadow_params, list):
raise ValueError("shadow_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
raise ValueError("shadow_params must all be Tensors")
def named_parameters(self, prefix="", recurse=True):
for i, param in enumerate(self.shadow_params):
name = f"{prefix}shadow_params.{i}"
yield name, param

def buffers(self, recurse=True):
return iter([])

def named_buffers(self, prefix="", recurse=True):
return iter([])

def train(self, mode=True):
self.training = mode
return self

def eval(self):
return self.train(False)

def zero_grad(self):
# No gradients to zero in EMA model
pass
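A short usage sketch of the persistence and nn.Module-style surface added above. It assumes an already-constructed EMAModel instance named ema that is tracking the LoRA/Lycoris parameters; the path is illustrative:

# File-based persistence: save_state_dict/load_state_dict take a path, not a dict
# (the dict-based load_state_dict variant is left commented out in the diff above).
ema_path = "output/example/ema_model.pt"
ema.save_state_dict(ema_path)
ema.load_state_dict(ema_path)

# The state dict is flat: scalar metadata plus one tensor per shadow parameter,
# keyed "shadow_params.0", "shadow_params.1", ...
flat = ema.state_dict()
shadow_keys = [k for k in flat if k.startswith("shadow_params.")]

# nn.Module-style accessors, so downstream code (accelerate, quantisation) can
# treat the EMA object like a module:
total_params = ema.parameter_count()
for name, param in ema.named_parameters():
    pass
ema.eval()   # toggles the emulated training flag; does not change the tensors
ema.cpu()    # moves shadow params via the new to()/cuda()/cpu() helpers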
54 changes: 48 additions & 6 deletions helpers/training/quantisation/__init__.py
@@ -206,7 +206,15 @@ def get_quant_fn(base_model_precision):


def quantise_model(
unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet, args
unet=None,
transformer=None,
text_encoder_1=None,
text_encoder_2=None,
text_encoder_3=None,
controlnet=None,
ema=None,
args=None,
return_dict: bool = False,
):
"""
Quantizes the provided models using the specified precision settings.
@@ -218,6 +226,7 @@ def quantise_model(
text_encoder_2: The second text encoder to quantize.
text_encoder_3: The third text encoder to quantize.
controlnet: The ControlNet model to quantize.
ema: An EMAModel to quantize.
args: An object containing precision settings and other arguments.
Returns:
@@ -273,6 +282,14 @@
"base_model_precision": args.base_model_precision,
},
),
(
ema,
{
"quant_fn": get_quant_fn(args.base_model_precision),
"model_precision": args.base_model_precision,
"quantize_activations": args.quantize_activations,
},
),
]

# Iterate over the models and apply quantization if the model is not None
@@ -293,8 +310,33 @@
models[i] = (quant_fn(model, **quant_args_combined), quant_args)

# Unpack the quantized models
transformer, unet, controlnet, text_encoder_1, text_encoder_2, text_encoder_3 = [
model for model, _ in models
]

return unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet
(
transformer,
unet,
controlnet,
text_encoder_1,
text_encoder_2,
text_encoder_3,
ema,
) = [model for model, _ in models]

if return_dict:
return {
"unet": unet,
"transformer": transformer,
"text_encoder_1": text_encoder_1,
"text_encoder_2": text_encoder_2,
"text_encoder_3": text_encoder_3,
"controlnet": controlnet,
"ema": ema,
}

return (
unet,
transformer,
text_encoder_1,
text_encoder_2,
text_encoder_3,
controlnet,
ema,
)
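Finally, a sketch of how the reworked quantise_model signature might be called. The variable names (transformer, ema_model, args) are assumptions; the keyword arguments and return_dict keys follow the diff:

from helpers.training.quantisation import quantise_model

# All model slots now default to None, the EMA model is quantised with the same
# base_model_precision settings as the base model, and return_dict=True avoids
# depending on the positional return order.
quantised = quantise_model(
    transformer=transformer,   # or unet=... for UNet-based models
    ema=ema_model,
    args=args,
    return_dict=True,
)
transformer = quantised["transformer"]
ema_model = quantised["ema"]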
