Skip to content

Commit

Permalink
fix the precision bug in inference
Browse files Browse the repository at this point in the history
  • Loading branch information
lawrence-cj committed Jan 5, 2025
1 parent db768da commit ac55992
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion configs/sana_config/512ms/Sana_600M_img512.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ scheduler:
predict_v: true
noise_schedule: linear_flow
pred_sigma: false
flow_shift: 1.0
flow_shift: 3.0
# logit-normal timestep
weighting_scheme: logit_normal
logit_mean: 0.0
Expand Down
3 changes: 2 additions & 1 deletion scripts/infer_run_inference_geneval.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ echo "Add label: $add_label"
echo "Exist time prefix: $exist_time_prefix"

cmd_template="DPM_TQDM=True python scripts/inference_geneval.py --config={config_file} --model_path={model_path} \
--sampling_algo $sampling_algo --step $step --cfg_scale $cfg_scale \
--sampling_algo $sampling_algo --step $step --cfg_scale $cfg_scale --sample_nums $sample_nums \
--gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}"
if [ -n "${add_label}" ]; then
cmd_template="${cmd_template} --add_label ${add_label}"
Expand Down Expand Up @@ -108,6 +108,7 @@ if [[ "$model_paths" == *.pth ]]; then
cmd="${cmd//\{end_index\}/$end_index}"

echo "Running on GPU $gpu_id: samples $start_index to $end_index"
echo $cmd
eval CUDA_VISIBLE_DEVICES=$gpu_id $cmd &
done
wait
Expand Down
1 change: 1 addition & 0 deletions scripts/inference_geneval.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def visualize(sample_steps, cfg_scale, pag_scale):
latent_size,
device=device,
generator=generator,
dtype=weight_dtype,
)
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

Expand Down

0 comments on commit ac55992

Please sign in to comment.