merge #1165

Merged
merged 17 commits on Nov 16, 2024

20 changes: 13 additions & 7 deletions documentation/LYCORIS.md
@@ -6,12 +6,17 @@

## Using LyCORIS

-To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file:
+To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file.
 
-```bash
-MODEL_TYPE=lora
-# We use trainer_extra_args for now, as Lycoris support is so new.
-TRAINER_EXTRA_ARGS+=" --lora_type=lycoris --lycoris_config=config/lycoris_config.json"
-```
+The following will go into your `config.json`:
+```json
+{
+    "model_type": "lora",
+    "lora_type": "lycoris",
+    "lycoris_config": "config/lycoris_config.json",
+    "validation_lycoris_strength": 1.0,
+    ...the rest of your settings...
+}
+```


@@ -48,7 +53,7 @@ Optional fields:
- any keyword arguments specific to the selected algorithm, at the end.

Mandatory fields:
-- multiplier
+- multiplier, which should be left at 1.0 unless you know what to expect
- linear_dim
- linear_alpha
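
For reference, here is a minimal sketch of such a configuration, written as Python that emits the JSON file. The `algo` choice and the `factor` keyword argument are illustrative assumptions, not values required by this PR:

```python
# Sketch: generate config/lycoris_config.json with the mandatory fields
# listed above. "algo" and "factor" are assumed example values.
import json
import os

lycoris_config = {
    "algo": "lokr",       # assumed algorithm selection
    "multiplier": 1.0,    # mandatory; leave at 1.0 unless you know what to expect
    "linear_dim": 10000,  # mandatory
    "linear_alpha": 1,    # mandatory
    "factor": 16,         # assumed algorithm-specific keyword argument
}

os.makedirs("config", exist_ok=True)
with open("config/lycoris_config.json", "w") as f:
    json.dump(lycoris_config, f, indent=4)
```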

@@ -81,7 +86,8 @@ vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer")

lycoris_safetensors_path = 'pytorch_lora_weights.safetensors'
-wrapper, _ = create_lycoris_from_weights(1.0, lycoris_safetensors_path, transformer)
+lycoris_strength = 1.0
+wrapper, _ = create_lycoris_from_weights(lycoris_strength, lycoris_safetensors_path, transformer)
wrapper.merge_to() # using apply_to() will be slower.

transformer.to(device, dtype=dtype)
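
As an aside, the strength passed at load time is not final: the wrapper exposes `set_multiplier()`, the same call this PR makes on `accelerator._lycoris_wrapped_network` during validation. A sketch building on the snippet above:

```python
# Sketch: scale the LyCORIS effect before merging. set_multiplier() is the
# method this PR uses for validation strength; 1.3 mirrors the range the
# --validation_lycoris_strength help text suggests.
wrapper, _ = create_lycoris_from_weights(1.0, lycoris_safetensors_path, transformer)
wrapper.set_multiplier(1.3)  # stronger than training strength
wrapper.merge_to()           # bake the scaled weights into the transformer
```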
9 changes: 9 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -1369,6 +1369,15 @@ def get_argument_parser():
" and submit debug.log to a new Github issue report."
),
)
parser.add_argument(
"--validation_lycoris_strength",
type=float,
default=1.0,
help=(
"When inferencing for validations, the Lycoris model will by default be run at its training strength, 1.0."
" However, this value can be increased to a value of around 1.3 or 1.5 to get a stronger effect from the model."
),
)
parser.add_argument(
"--validation_torch_compile",
action="store_true",
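
Since SimpleTuner reads its options from `config.json`, the new key can be set there as well; a minimal sketch (only `validation_lycoris_strength` comes from this hunk, the surrounding keys are taken from the documentation change above):

```python
# Sketch: the new option alongside the LyCORIS keys shown in the
# documentation hunk. Written as a Python dict; on disk this is JSON.
config = {
    "model_type": "lora",
    "lora_type": "lycoris",
    "lycoris_config": "config/lycoris_config.json",
    "validation_lycoris_strength": 1.3,  # stronger validation renders
}
```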
7 changes: 7 additions & 0 deletions helpers/data_backend/factory.py
@@ -1053,6 +1053,13 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
             raise ValueError(
                 f"VAE image embed cache directory {backend.get('cache_dir_vae')} is the same as the text embed cache directory. This is not allowed, the trainer will get confused."
             )
+
+        if backend["type"] == "local" and (
+            vae_cache_dir is None or vae_cache_dir == ""
+        ):
+            raise ValueError(
+                f"VAE image embed cache directory {backend.get('cache_dir_vae')} is not set. This is required for the VAE image embed cache."
+            )
         init_backend["vaecache"] = VAECache(
             id=init_backend["id"],
             vae=StateTracker.get_vae(),
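
For clarity, a sketch of a `local` backend entry that satisfies the new check. Only `type` and `cache_dir_vae` are keys this hunk actually inspects; the rest are illustrative assumptions:

```python
# Sketch: a dataloader backend entry that passes the new cache_dir_vae
# validation. Keys other than "type" and "cache_dir_vae" are assumed.
backend = {
    "id": "my-dataset",
    "type": "local",
    "instance_data_dir": "/data/images",
    "cache_dir_vae": "/data/cache/vae",  # must be present and non-empty
}
```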
103 changes: 77 additions & 26 deletions helpers/data_backend/local.py
@@ -7,6 +7,9 @@
 import torch
 from typing import Any
 from regex import regex
+import fcntl
+import tempfile
+import shutil
 
 logger = logging.getLogger("LocalDataBackend")
 logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
@@ -21,29 +24,50 @@ def __init__(self, accelerator, id: str, compress_cache: bool = False):

     def read(self, filepath, as_byteIO: bool = False):
         """Read and return the content of the file."""
-        # Openfilepath as BytesIO:
         with open(filepath, "rb") as file:
-            data = file.read()
-            if not as_byteIO:
-                return data
-            return BytesIO(data)
+            # Acquire a shared lock
+            fcntl.flock(file, fcntl.LOCK_SH)
+            try:
+                data = file.read()
+                if not as_byteIO:
+                    return data
+                return BytesIO(data)
+            finally:
+                # Release the lock
+                fcntl.flock(file, fcntl.LOCK_UN)

     def write(self, filepath: str, data: Any) -> None:
         """Write the provided data to the specified filepath."""
         os.makedirs(os.path.dirname(filepath), exist_ok=True)
-        with open(filepath, "wb") as file:
-            # Check if data is a Tensor, and if so, save it appropriately
-            if isinstance(data, torch.Tensor):
-                # logger.debug(f"Writing a torch file to disk.")
-                return self.torch_save(data, file)
-            elif isinstance(data, str):
-                # logger.debug(f"Writing a string to disk as {filepath}: {data}")
-                data = data.encode("utf-8")
-            else:
-                logger.debug(
-                    f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
-                )
-            file.write(data)
+        temp_dir = os.path.dirname(filepath)
+        temp_file_path = os.path.join(temp_dir, f".{os.path.basename(filepath)}.tmp")
+
+        # Open the temporary file for writing
+        with open(temp_file_path, "wb") as temp_file:
+            # Acquire an exclusive lock on the temporary file
+            fcntl.flock(temp_file, fcntl.LOCK_EX)
+            try:
+                # Write data to the temporary file
+                if isinstance(data, torch.Tensor):
+                    # Use the torch_save method, passing the temp file
+                    self.torch_save(data, temp_file)
+                    return  # torch_save handles closing the file
+                elif isinstance(data, str):
+                    data = data.encode("utf-8")
+                else:
+                    logger.debug(
+                        f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
+                    )
+                temp_file.write(data)
+                temp_file.flush()
+                os.fsync(temp_file.fileno())
+            finally:
+                # Release the lock
+                fcntl.flock(temp_file, fcntl.LOCK_UN)
+
+        # Atomically replace the target file with the temporary file
+        os.rename(temp_file_path, filepath)
 
     def delete(self, filepath):
         """Delete the specified file."""
@@ -212,16 +236,43 @@ def torch_save(self, data, original_location):
         Save a torch tensor to a file.
         """
         if isinstance(original_location, str):
-            location = self.open_file(original_location, "wb")
-        else:
-            location = original_location
-
-        if self.compress_cache:
-            compressed_data = self._compress_torch(data)
-            location.write(compressed_data)
+            filepath = original_location
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+            temp_dir = os.path.dirname(filepath)
+            temp_file_path = os.path.join(temp_dir, f".{os.path.basename(filepath)}.tmp")
+
+            with open(temp_file_path, "wb") as temp_file:
+                # Acquire an exclusive lock on the temporary file
+                fcntl.flock(temp_file, fcntl.LOCK_EX)
+                try:
+                    if self.compress_cache:
+                        compressed_data = self._compress_torch(data)
+                        temp_file.write(compressed_data)
+                    else:
+                        torch.save(data, temp_file)
+                    temp_file.flush()
+                    os.fsync(temp_file.fileno())
+                finally:
+                    # Release the lock
+                    fcntl.flock(temp_file, fcntl.LOCK_UN)
+            # Atomically replace the target file with the temporary file
+            os.rename(temp_file_path, filepath)
         else:
-            torch.save(data, location)
-        location.close()
+            # Handle the case where original_location is a file object
+            temp_file = original_location
+            # Acquire an exclusive lock on the file object
+            fcntl.flock(temp_file, fcntl.LOCK_EX)
+            try:
+                if self.compress_cache:
+                    compressed_data = self._compress_torch(data)
+                    temp_file.write(compressed_data)
+                else:
+                    torch.save(data, temp_file)
+                temp_file.flush()
+                os.fsync(temp_file.fileno())
+            finally:
+                # Release the lock
+                fcntl.flock(temp_file, fcntl.LOCK_UN)
 
     def write_batch(self, filepaths: list, data_list: list) -> None:
         """Write a batch of data to the specified filepaths."""
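
The pattern the two hunks above share, reduced to a standalone sketch (`atomic_write` is a hypothetical helper, not part of this PR): take an exclusive advisory lock on a hidden temporary file, write, flush, fsync, then rename over the target. Worth noting: `fcntl` locks are advisory and POSIX-only, so this protects cooperating processes on Linux and macOS but is unavailable on Windows.

```python
import fcntl
import os

def atomic_write(filepath: str, data: bytes) -> None:
    """Hypothetical helper illustrating the lock/temp/rename pattern above."""
    temp_path = os.path.join(
        os.path.dirname(filepath), f".{os.path.basename(filepath)}.tmp"
    )
    with open(temp_path, "wb") as temp_file:
        # Exclusive advisory lock while the temporary file is written.
        fcntl.flock(temp_file, fcntl.LOCK_EX)
        try:
            temp_file.write(data)
            temp_file.flush()
            os.fsync(temp_file.fileno())  # force bytes to stable storage
        finally:
            fcntl.flock(temp_file, fcntl.LOCK_UN)
    # Atomic on POSIX when both paths are on the same filesystem, so readers
    # observe either the old file or the complete new one, never a partial write.
    os.rename(temp_path, filepath)
```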
61 changes: 52 additions & 9 deletions helpers/publishing/metadata.py
@@ -255,7 +255,7 @@ def code_example(args, repo_id: str = None):
 image = pipeline(
     prompt=prompt,{_negative_prompt(args, in_call=True) if args.model_family.lower() != 'flux' else ''}
     num_inference_steps={args.validation_num_inference_steps},
-    generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
+    generator=torch.Generator(device={_torch_device()}).manual_seed({args.validation_seed or args.seed or 42}),
     {_validation_resolution(args)}
     guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
 ).images[0]
@@ -293,13 +293,15 @@ def lora_info(args):
             lycoris_config = json.load(file)
     except:
         lycoris_config = {"error": "could not locate or load LyCORIS config."}
-    return f"""- LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""
+    return f"""### LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""


 def model_card_note(args):
     """Return a string with the model card note."""
     note_contents = args.model_card_note if args.model_card_note else ""
-    return f"\n{note_contents}\n"
+    if note_contents is None or note_contents == "":
+        return ""
+    return f"\n**Note:** {note_contents}\n"


def flux_schedule_info(args):
@@ -312,6 +314,7 @@ def flux_schedule_info(args):
output_args.append("flux_schedule_auto_shift")
if args.flux_schedule_shift is not None:
output_args.append(f"shift={args.flux_schedule_shift}")
output_args.append(f"flux_guidance_mode={args.flux_guidance_mode}")
if args.flux_guidance_value:
output_args.append(f"flux_guidance_value={args.flux_guidance_value}")
if args.flux_guidance_min:
@@ -324,6 +327,9 @@
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_attention_masked_training:
output_args.append("flux_attention_masked_training")
if args.t5_padding != "unmodified":
output_args.append(f"t5_padding={args.t5_padding}")
output_args.append(f"flow_matching_loss={args.flow_matching_loss}")
if (
args.model_type == "lora"
and args.lora_type == "standard"
@@ -363,11 +369,45 @@ def sd3_schedule_info(args):
     return output_str


+def ddpm_schedule_info(args):
+    """Information about DDPM schedules, eg. rescaled betas or offset noise"""
+    output_args = []
+    if args.snr_gamma:
+        output_args.append(f"snr_gamma={args.snr_gamma}")
+    if args.use_soft_min_snr:
+        output_args.append(f"use_soft_min_snr")
+    if args.soft_min_snr_sigma_data:
+        output_args.append(
+            f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}"
+        )
+    if args.rescale_betas_zero_snr:
+        output_args.append(f"rescale_betas_zero_snr")
+    if args.offset_noise:
+        output_args.append(f"offset_noise")
+        output_args.append(f"noise_offset={args.noise_offset}")
+        output_args.append(f"noise_offset_probability={args.noise_offset_probability}")
+    output_args.append(
+        f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}"
+    )
+    output_args.append(
+        f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}"
+    )
+    output_str = (
+        f" (extra parameters={output_args})"
+        if output_args
+        else " (no special parameters set)"
+    )
+
+    return output_str
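
To make the new helper's output concrete, a quick usage sketch with assumed argument values (none of these settings come from this PR):

```python
# Hedged example of ddpm_schedule_info() output; the values are assumed.
from types import SimpleNamespace

example_args = SimpleNamespace(
    snr_gamma=5.0,
    use_soft_min_snr=False,
    soft_min_snr_sigma_data=None,
    rescale_betas_zero_snr=True,
    offset_noise=False,
    training_scheduler_timestep_spacing="trailing",
    validation_scheduler_timestep_spacing="trailing",
)
print(ddpm_schedule_info(example_args))
# " (extra parameters=['snr_gamma=5.0', 'rescale_betas_zero_snr',
#  'training_scheduler_timestep_spacing=trailing',
#  'validation_scheduler_timestep_spacing=trailing'])"
```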


 def model_schedule_info(args):
     if args.model_family == "flux":
         return flux_schedule_info(args)
     if args.model_family == "sd3":
         return sd3_schedule_info(args)
+    else:
+        return ddpm_schedule_info(args)


def save_model_card(
@@ -461,18 +501,19 @@

 {'This is a **diffusion** model trained using DDPM objective instead of Flow matching. **Be sure to set the appropriate scheduler configuration.**' if args.model_family == "sd3" and args.flow_matching_loss == "diffusion" else ''}
 {'The main validation prompt used during training was:' if prompt else 'Validation used ground-truth images as an input for partial denoising (img2img).' if args.validation_using_datasets else 'No validation prompt was used during training.'}
-{model_card_note(args)}
 {'```' if prompt else ''}
 {prompt}
 {'```' if prompt else ''}
 
+{model_card_note(args)}
 ## Validation settings
 - CFG: `{StateTracker.get_args().validation_guidance}`
 - CFG Rescale: `{StateTracker.get_args().validation_guidance_rescale}`
 - Steps: `{StateTracker.get_args().validation_num_inference_steps}`
-- Sampler: `{StateTracker.get_args().validation_noise_scheduler}`
+- Sampler: `{'FlowMatchEulerDiscreteScheduler' if args.model_family in ['sd3', 'flux'] else StateTracker.get_args().validation_noise_scheduler}`
 - Seed: `{StateTracker.get_args().validation_seed}`
 - Resolution{'s' if ',' in StateTracker.get_args().validation_resolution else ''}: `{StateTracker.get_args().validation_resolution}`
+{f"- Skip-layer guidance: {_skip_layers(args)}" if args.model_family in ['sd3', 'flux'] else ''}
 
 Note: The validation settings are not necessarily the same as the [training settings](#training-settings).

@@ -489,17 +530,19 @@ def save_model_card(
 - Training epochs: {StateTracker.get_epoch() - 1}
 - Training steps: {StateTracker.get_global_step()}
 - Learning rate: {StateTracker.get_args().learning_rate}
+- Learning rate schedule: {StateTracker.get_args().lr_scheduler}
+- Warmup steps: {StateTracker.get_args().lr_warmup_steps}
 - Max grad norm: {StateTracker.get_args().max_grad_norm}
 - Effective batch size: {StateTracker.get_args().train_batch_size * StateTracker.get_args().gradient_accumulation_steps * StateTracker.get_accelerator().num_processes}
 - Micro-batch size: {StateTracker.get_args().train_batch_size}
 - Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
 - Number of GPUs: {StateTracker.get_accelerator().num_processes}
+- Gradient checkpointing: {StateTracker.get_args().gradient_checkpointing}
 - Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
-- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
 - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
-- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
-- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
-- Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'}
+- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
+- Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}%
+{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''}
 {lora_info(args=StateTracker.get_args())}

## Datasets
2 changes: 1 addition & 1 deletion helpers/training/save_hooks.py
@@ -325,7 +325,7 @@ def _save_full_model(self, models, weights, output_dir):
                 shutil.copy2(s, d)
 
         # Remove the temporary directory
-        shutil.rmtree(temporary_dir)
+        shutil.rmtree(temporary_dir, ignore_errors=True)

     def save_model_hook(self, models, weights, output_dir):
         # Write "training_state.json" to the output directory containing the training state
2 changes: 1 addition & 1 deletion helpers/training/trainer.py
@@ -2717,7 +2717,7 @@ def train(self):
                         self.config.output_dir, removing_checkpoint
                     )
                     try:
-                        shutil.rmtree(removing_checkpoint)
+                        shutil.rmtree(removing_checkpoint, ignore_errors=True)
                     except Exception as e:
                         logger.error(
                             f"Failed to remove directory: {removing_checkpoint}"
6 changes: 6 additions & 0 deletions helpers/training/validation.py
@@ -957,6 +957,10 @@ def setup_scheduler(self):
         return scheduler
 
     def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
+        if hasattr(self.accelerator, "_lycoris_wrapped_network"):
+            self.accelerator._lycoris_wrapped_network.set_multiplier(float(getattr(
+                self.args, "validation_lycoris_strength", 1.0
+            )))
         if validation_type == "intermediary" and self.args.use_ema:
             if enable_ema_model:
                 if self.unet is not None:
@@ -1120,6 +1124,8 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):

     def clean_pipeline(self):
         """Remove the pipeline."""
+        if hasattr(self.accelerator, "_lycoris_wrapped_network"):
+            self.accelerator._lycoris_wrapped_network.set_multiplier(1.0)
         if self.pipeline is not None:
             del self.pipeline
             self.pipeline = None
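
Taken together, the two hunks above set the LyCORIS multiplier for validation and reset it afterwards. A condensed sketch of that round trip; the function and its arguments are stand-ins, only `set_multiplier()` and the `_lycoris_wrapped_network` attribute come from this PR:

```python
# Condensed sketch of the set/reset pattern implemented above.
def validate_with_lycoris_strength(accelerator, args, run_validation):
    wrapper = getattr(accelerator, "_lycoris_wrapped_network", None)
    strength = float(getattr(args, "validation_lycoris_strength", 1.0))
    if wrapper is not None:
        wrapper.set_multiplier(strength)  # applied while validation images render
    try:
        run_validation()                  # the validation pass itself
    finally:
        if wrapper is not None:
            wrapper.set_multiplier(1.0)   # restored so training resumes at 1.0
```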