From ac55992bce686fdc7b8e564c708954302f25fc3d Mon Sep 17 00:00:00 2001 From: junsong Date: Sun, 5 Jan 2025 05:47:08 -0800 Subject: [PATCH] fix the precision bug in inference --- configs/sana_config/512ms/Sana_600M_img512.yaml | 2 +- scripts/infer_run_inference_geneval.sh | 3 ++- scripts/inference_geneval.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs/sana_config/512ms/Sana_600M_img512.yaml b/configs/sana_config/512ms/Sana_600M_img512.yaml index b3f2e12..0857cc4 100644 --- a/configs/sana_config/512ms/Sana_600M_img512.yaml +++ b/configs/sana_config/512ms/Sana_600M_img512.yaml @@ -66,7 +66,7 @@ scheduler: predict_v: true noise_schedule: linear_flow pred_sigma: false - flow_shift: 1.0 + flow_shift: 3.0 # logit-normal timestep weighting_scheme: logit_normal logit_mean: 0.0 diff --git a/scripts/infer_run_inference_geneval.sh b/scripts/infer_run_inference_geneval.sh index b3dc3b4..f85011f 100644 --- a/scripts/infer_run_inference_geneval.sh +++ b/scripts/infer_run_inference_geneval.sh @@ -72,7 +72,7 @@ echo "Add label: $add_label" echo "Exist time prefix: $exist_time_prefix" cmd_template="DPM_TQDM=True python scripts/inference_geneval.py --config={config_file} --model_path={model_path} \ - --sampling_algo $sampling_algo --step $step --cfg_scale $cfg_scale \ + --sampling_algo $sampling_algo --step $step --cfg_scale $cfg_scale --sample_nums $sample_nums \ --gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}" if [ -n "${add_label}" ]; then cmd_template="${cmd_template} --add_label ${add_label}" @@ -108,6 +108,7 @@ if [[ "$model_paths" == *.pth ]]; then cmd="${cmd//\{end_index\}/$end_index}" echo "Running on GPU $gpu_id: samples $start_index to $end_index" + echo $cmd eval CUDA_VISIBLE_DEVICES=$gpu_id $cmd & done wait diff --git a/scripts/inference_geneval.py b/scripts/inference_geneval.py index 1eee141..e5c6e68 100644 --- a/scripts/inference_geneval.py +++ b/scripts/inference_geneval.py @@ -235,6 +235,7 @@ def visualize(sample_steps, cfg_scale, pag_scale): latent_size, device=device, generator=generator, + dtype=weight_dtype, ) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)