diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py
index 8fbe0c66..10011c96 100644
--- a/helpers/publishing/metadata.py
+++ b/helpers/publishing/metadata.py
@@ -299,7 +299,9 @@ def lora_info(args):
 def model_card_note(args):
     """Return a string with the model card note."""
     note_contents = args.model_card_note if args.model_card_note else ""
-    return f"\n{note_contents}\n"
+    if note_contents is None or note_contents == "":
+        return ""
+    return f"\n**Note:** {note_contents}\n"
 
 
 def flux_schedule_info(args):
@@ -312,6 +314,7 @@ def flux_schedule_info(args):
         output_args.append("flux_schedule_auto_shift")
     if args.flux_schedule_shift is not None:
         output_args.append(f"shift={args.flux_schedule_shift}")
+    output_args.append(f"flux_guidance_mode={args.flux_guidance_mode}")
     if args.flux_guidance_value:
         output_args.append(f"flux_guidance_value={args.flux_guidance_value}")
     if args.flux_guidance_min:
@@ -324,6 +327,9 @@ def flux_schedule_info(args):
         output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
     if args.flux_attention_masked_training:
         output_args.append("flux_attention_masked_training")
+    if args.t5_padding != "unmodified":
+        output_args.append(f"t5_padding={args.t5_padding}")
+    output_args.append(f"flow_matching_loss={args.flow_matching_loss}")
     if (
         args.model_type == "lora"
         and args.lora_type == "standard"
@@ -362,6 +368,7 @@ def sd3_schedule_info(args):
 
     return output_str
 
+
 def ddpm_schedule_info(args):
     """Information about DDPM schedules, eg. rescaled betas or offset noise"""
     output_args = []
@@ -370,15 +377,21 @@ def ddpm_schedule_info(args):
     if args.use_soft_min_snr:
         output_args.append(f"use_soft_min_snr")
     if args.soft_min_snr_sigma_data:
-        output_args.append(f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}")
+        output_args.append(
+            f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}"
+        )
     if args.rescale_betas_zero_snr:
         output_args.append(f"rescale_betas_zero_snr")
     if args.offset_noise:
         output_args.append(f"offset_noise")
         output_args.append(f"noise_offset={args.noise_offset}")
         output_args.append(f"noise_offset_probability={args.noise_offset_probability}")
-    output_args.append(f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}")
-    output_args.append(f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}")
+    output_args.append(
+        f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}"
+    )
+    output_args.append(
+        f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}"
+    )
     output_str = (
         f" (extra parameters={output_args})"
         if output_args
@@ -387,6 +400,7 @@ def ddpm_schedule_info(args):
 
     return output_str
 
+
 def model_schedule_info(args):
     if args.model_family == "flux":
         return flux_schedule_info(args)
@@ -396,7 +410,6 @@ def model_schedule_info(args):
         return ddpm_schedule_info(args)
 
 
-
 def save_model_card(
     repo_id: str,
     images=None,
@@ -488,18 +501,19 @@ def save_model_card(
 {'This is a **diffusion** model trained using DDPM objective instead of Flow matching. **Be sure to set the appropriate scheduler configuration.**' if args.model_family == "sd3" and args.flow_matching_loss == "diffusion" else ''}
 
 {'The main validation prompt used during training was:' if prompt else 'Validation used ground-truth images as an input for partial denoising (img2img).'
 if args.validation_using_datasets else 'No validation prompt was used during training.'}
-{model_card_note(args)}
 {'```' if prompt else ''}
 {prompt}
 {'```' if prompt else ''}
+{model_card_note(args)}
 
 ## Validation settings
 - CFG: `{StateTracker.get_args().validation_guidance}`
 - CFG Rescale: `{StateTracker.get_args().validation_guidance_rescale}`
 - Steps: `{StateTracker.get_args().validation_num_inference_steps}`
-- Sampler: `{StateTracker.get_args().validation_noise_scheduler}`
+- Sampler: `{'FlowMatchEulerDiscreteScheduler' if args.model_family in ['sd3', 'flux'] else StateTracker.get_args().validation_noise_scheduler}`
 - Seed: `{StateTracker.get_args().validation_seed}`
 - Resolution{'s' if ',' in StateTracker.get_args().validation_resolution else ''}: `{StateTracker.get_args().validation_resolution}`
+{f"- Skip-layer guidance: {_skip_layers(args)}" if args.model_family in ['sd3', 'flux'] else ''}
 
 Note: The validation settings are not necessarily the same as the [training settings](#training-settings).
diff --git a/tests/test_model_card.py b/tests/test_model_card.py
index 51f9385e..1c596be5 100644
--- a/tests/test_model_card.py
+++ b/tests/test_model_card.py
@@ -36,7 +36,7 @@ def setUp(self):
         self.args.lora_init_type = "kaiming_uniform"
         self.args.model_card_note = "Test note"
         self.args.validation_using_datasets = False
-        self.args.flow_matching_loss = "flow-matching"
+        self.args.flow_matching_loss = "compatible"
         self.args.flux_fast_schedule = False
         self.args.flux_schedule_auto_shift = False
         self.args.flux_schedule_shift = None
@@ -61,6 +61,9 @@ def setUp(self):
         self.args.optimizer_config = ""
         self.args.mixed_precision = "fp16"
         self.args.base_model_precision = "no_change"
+        self.args.flux_guidance_mode = "constant"
+        self.args.flux_guidance_value = 1.0
+        self.args.t5_padding = "unmodified"
         self.args.enable_xformers_memory_efficient_attention = False
 
     def test_model_imports(self):
@@ -203,7 +206,10 @@ def test_model_card_note(self):
     def test_flux_schedule_info(self):
         self.args.model_family = "flux"
         output = flux_schedule_info(self.args)
-        self.assertIn("(no special parameters set)", output)
+        self.assertEqual(
+            " (extra parameters=['flux_guidance_mode=constant', 'flux_guidance_value=1.0', 'flow_matching_loss=compatible'])",
+            output,
+        )
 
         self.args.flux_fast_schedule = True
         output = flux_schedule_info(self.args)
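For reviewers who want to sanity-check the new behaviour locally, here is a minimal sketch. The `SimpleNamespace` stub only mirrors the fields touched in this diff and the test's `setUp()`; `flux_schedule_info()` may read additional attributes not stubbed here, and `flux_guidance_max`/`flux_beta_schedule_alpha` in particular are assumptions, not names confirmed by this patch.

```python
from types import SimpleNamespace

from helpers.publishing.metadata import flux_schedule_info, model_card_note

# Hypothetical stub of the parsed args; attribute names are taken from this
# diff and from tests/test_model_card.py where possible.
args = SimpleNamespace(
    model_card_note="Test note",
    flux_fast_schedule=False,
    flux_schedule_auto_shift=False,
    flux_schedule_shift=None,
    flux_guidance_mode="constant",
    flux_guidance_value=1.0,
    flux_guidance_min=None,
    flux_guidance_max=None,          # assumed field, by analogy with flux_guidance_min
    flux_beta_schedule_alpha=None,   # assumed field, by analogy with flux_beta_schedule_beta
    flux_beta_schedule_beta=None,
    flux_attention_masked_training=False,
    t5_padding="unmodified",
    flow_matching_loss="compatible",
    model_type="full",               # skips the lora-specific branch
)

# Empty notes now collapse to "" instead of emitting two stray newlines:
assert model_card_note(SimpleNamespace(model_card_note=None)) == ""
# Non-empty notes gain a bold "**Note:**" prefix:
assert model_card_note(args) == "\n**Note:** Test note\n"

# Because flux_guidance_mode and flow_matching_loss are now appended
# unconditionally, the default flux schedule string is no longer
# "(no special parameters set)":
print(flux_schedule_info(args))
# -> " (extra parameters=['flux_guidance_mode=constant',
#      'flux_guidance_value=1.0', 'flow_matching_loss=compatible'])"
```

This matches the updated expectation in `test_flux_schedule_info`, which is why that test switched from `assertIn("(no special parameters set)", ...)` to an exact `assertEqual`.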