Skip to content

Commit

Permalink
Merge pull request #1163 from bghira/feature/validation-lycoris-strength
Browse files Browse the repository at this point in the history
add more info to model card, refine contents
  • Loading branch information
bghira authored Nov 16, 2024
2 parents d223171 + d33ecad commit b815fb4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
28 changes: 21 additions & 7 deletions helpers/publishing/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ def lora_info(args):
def model_card_note(args):
"""Return a string with the model card note."""
note_contents = args.model_card_note if args.model_card_note else ""
return f"\n{note_contents}\n"
if note_contents is None or note_contents == "":
return ""
return f"\n**Note:** {note_contents}\n"


def flux_schedule_info(args):
Expand All @@ -312,6 +314,7 @@ def flux_schedule_info(args):
output_args.append("flux_schedule_auto_shift")
if args.flux_schedule_shift is not None:
output_args.append(f"shift={args.flux_schedule_shift}")
output_args.append(f"flux_guidance_mode={args.flux_guidance_mode}")
if args.flux_guidance_value:
output_args.append(f"flux_guidance_value={args.flux_guidance_value}")
if args.flux_guidance_min:
Expand All @@ -324,6 +327,9 @@ def flux_schedule_info(args):
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_attention_masked_training:
output_args.append("flux_attention_masked_training")
if args.t5_padding != "unmodified":
output_args.append(f"t5_padding={args.t5_padding}")
output_args.append(f"flow_matching_loss={args.flow_matching_loss}")
if (
args.model_type == "lora"
and args.lora_type == "standard"
Expand Down Expand Up @@ -362,6 +368,7 @@ def sd3_schedule_info(args):

return output_str


def ddpm_schedule_info(args):
"""Information about DDPM schedules, eg. rescaled betas or offset noise"""
output_args = []
Expand All @@ -370,15 +377,21 @@ def ddpm_schedule_info(args):
if args.use_soft_min_snr:
output_args.append(f"use_soft_min_snr")
if args.soft_min_snr_sigma_data:
output_args.append(f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}")
output_args.append(
f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}"
)
if args.rescale_betas_zero_snr:
output_args.append(f"rescale_betas_zero_snr")
if args.offset_noise:
output_args.append(f"offset_noise")
output_args.append(f"noise_offset={args.noise_offset}")
output_args.append(f"noise_offset_probability={args.noise_offset_probability}")
output_args.append(f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}")
output_args.append(f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}")
output_args.append(
f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}"
)
output_args.append(
f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}"
)
output_str = (
f" (extra parameters={output_args})"
if output_args
Expand All @@ -387,6 +400,7 @@ def ddpm_schedule_info(args):

return output_str


def model_schedule_info(args):
if args.model_family == "flux":
return flux_schedule_info(args)
Expand All @@ -396,7 +410,6 @@ def model_schedule_info(args):
return ddpm_schedule_info(args)



def save_model_card(
repo_id: str,
images=None,
Expand Down Expand Up @@ -488,18 +501,19 @@ def save_model_card(
{'This is a **diffusion** model trained using DDPM objective instead of Flow matching. **Be sure to set the appropriate scheduler configuration.**' if args.model_family == "sd3" and args.flow_matching_loss == "diffusion" else ''}
{'The main validation prompt used during training was:' if prompt else 'Validation used ground-truth images as an input for partial denoising (img2img).' if args.validation_using_datasets else 'No validation prompt was used during training.'}
{model_card_note(args)}
{'```' if prompt else ''}
{prompt}
{'```' if prompt else ''}
{model_card_note(args)}
## Validation settings
- CFG: `{StateTracker.get_args().validation_guidance}`
- CFG Rescale: `{StateTracker.get_args().validation_guidance_rescale}`
- Steps: `{StateTracker.get_args().validation_num_inference_steps}`
- Sampler: `{StateTracker.get_args().validation_noise_scheduler}`
- Sampler: `{'FlowMatchEulerDiscreteScheduler' if args.model_family in ['sd3', 'flux'] else StateTracker.get_args().validation_noise_scheduler}`
- Seed: `{StateTracker.get_args().validation_seed}`
- Resolution{'s' if ',' in StateTracker.get_args().validation_resolution else ''}: `{StateTracker.get_args().validation_resolution}`
{f"- Skip-layer guidance: {_skip_layers(args)}" if args.model_family in ['sd3', 'flux'] else ''}
Note: The validation settings are not necessarily the same as the [training settings](#training-settings).
Expand Down
10 changes: 8 additions & 2 deletions tests/test_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def setUp(self):
self.args.lora_init_type = "kaiming_uniform"
self.args.model_card_note = "Test note"
self.args.validation_using_datasets = False
self.args.flow_matching_loss = "flow-matching"
self.args.flow_matching_loss = "compatible"
self.args.flux_fast_schedule = False
self.args.flux_schedule_auto_shift = False
self.args.flux_schedule_shift = None
Expand All @@ -61,6 +61,9 @@ def setUp(self):
self.args.optimizer_config = ""
self.args.mixed_precision = "fp16"
self.args.base_model_precision = "no_change"
self.args.flux_guidance_mode = "constant"
self.args.flux_guidance_value = 1.0
self.args.t5_padding = "unmodified"
self.args.enable_xformers_memory_efficient_attention = False

def test_model_imports(self):
Expand Down Expand Up @@ -203,7 +206,10 @@ def test_model_card_note(self):
def test_flux_schedule_info(self):
self.args.model_family = "flux"
output = flux_schedule_info(self.args)
self.assertIn("(no special parameters set)", output)
self.assertEqual(
" (extra parameters=['flux_guidance_mode=constant', 'flux_guidance_value=1.0', 'flow_matching_loss=compatible'])",
output,
)

self.args.flux_fast_schedule = True
output = flux_schedule_info(self.args)
Expand Down

0 comments on commit b815fb4

Please sign in to comment.