Skip to content

Commit

Permalink
num_inference_steps for animtatediff
Browse files Browse the repository at this point in the history
  • Loading branch information
dkackman committed Mar 23, 2024
1 parent d8ed7f7 commit 08f34de
Showing 2 changed files with 4 additions and 3 deletions.
6 changes: 3 additions & 3 deletions swarm/test.py
Original file line number Diff line number Diff line change
@@ -184,7 +184,7 @@
"prompt": "A dancing marmot, 4k, high resolution",
"negative_prompt": "bad quality, worse quality, low resolution",
"workflow": "txt2vid",
"num_inference_steps": 8,
"num_inference_steps": 6,
"guidance_scale": 2.0,
"outputs": ["primary"],
"num_frames": 32,
@@ -196,9 +196,9 @@
"motion_adapter":
{
"model_name": "ByteDance/AnimateDiff-Lightning",
"checkpoint_file": "animatediff_lightning_8step_diffusers.safetensors",
"num_inference_steps": 4,
"checkpoint_file": "animatediff_lightning_4step_diffusers.safetensors",
},

"scheduler_args": {
"scheduler_type": "EulerDiscreteScheduler",
"beta_schedule": "linear",
1 change: 1 addition & 0 deletions swarm/video/tx2vid.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ def txt2vid_diffusion_callback(device_identifier, model_name, **kwargs):
if "checkpoint_file" in motion_adapter_args:
motion_adapter = MotionAdapter()
motion_adapter.load_state_dict(load_file(hf_hub_download(motion_adapter_args["model_name"], motion_adapter_args["checkpoint_file"])))
kwargs["num_inference_steps"] = motion_adapter_args.pop("num_inference_steps", 4)
else:
motion_adapter = MotionAdapter.from_pretrained(motion_adapter_args["model_name"], torch_dtype=torch_dtype)

0 comments on commit 08f34de

Please sign in to comment.