diff --git a/helpers/arguments.py b/helpers/arguments.py index 4851917c..c98fa17c 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -690,6 +690,18 @@ def parse_args(input_args=None): " Default: ddim" ), ) + parser.add_argument( + "--enable_watermark", + type=bool, + default=False, + action="store_true", + help=( + "The SDXL 0.9 and 1.0 licenses both require a watermark be used to identify any images created to be shared." + " Since the images created during validation typically are not shared, and we want the most accurate results," + " this watermarker is disabled by default. If you are sharing the validation images, it is up to you" + " to ensure that you are complying with the license, whether that is through this watermarker, or another." + ) + ) parser.add_argument( "--mixed_precision", type=str, diff --git a/helpers/legacy/validation.py b/helpers/legacy/validation.py index a7fb0805..23e6597f 100644 --- a/helpers/legacy/validation.py +++ b/helpers/legacy/validation.py @@ -167,6 +167,7 @@ def log_validations( vae=vae, revision=args.revision, torch_dtype=weight_dtype, + add_watermarker=args.enable_watermark ) pipeline.scheduler = SCHEDULER_NAME_MAP[ args.validation_noise_scheduler diff --git a/train_sdxl.py b/train_sdxl.py index 826cd7ac..64d89885 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -1285,6 +1285,7 @@ def collate_fn(batch): vae=vae, unet=unet, revision=args.revision, + add_watermarker=args.enable_watermark ) pipeline.set_progress_bar_config(disable=True) pipeline.scheduler = SCHEDULER_NAME_MAP[