Skip to content

Commit

Permalink
allow conditioning on an input PNG/JPG
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-q committed Nov 7, 2024
1 parent de8a97e commit 870f4d5
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 19 deletions.
29 changes: 13 additions & 16 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dit import DiT_models
from vae import VAE_models
from torchvision.io import read_video, write_video
from utils import one_hot_actions, sigmoid_beta_schedule
from utils import load_prompt, load_actions, sigmoid_beta_schedule
from tqdm import tqdm
from einops import rearrange
from torch import autocast
Expand All @@ -31,7 +31,7 @@ def main(args):

# load VAE checkpoint
vae = VAE_models["vit-l-20-shallow-encoder"]()
print(f"loading ViT-VAE-L-20 from oasis-ckpt={os.path.abspath(args.oasis_ckpt)}...")
print(f"loading ViT-VAE-L/20 from vae-ckpt={os.path.abspath(args.vae_ckpt)}...")
if args.vae_ckpt.endswith(".pt"):
vae_ckpt = torch.load(args.vae_ckpt, weights_only=True)
vae.load_state_dict(vae_ckpt)
Expand All @@ -41,33 +41,26 @@ def main(args):

# sampling params
B = 1
n_prompt_frames = args.n_prompt_frames
total_frames = args.num_frames
max_noise_level = 1000
ddim_noise_steps = args.ddim_steps
noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
noise_abs_max = 20
ctx_max_noise_idx = ddim_noise_steps // 10 * 3

# get input video
video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
mp4_path = f"sample_data/{video_id}.mp4"
actions_path = f"sample_data/{video_id}.actions.pt"
video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
actions = one_hot_actions(torch.load(actions_path))
offset = 12*20 # change to where you want to start in the video!
video = video[offset:offset+total_frames].unsqueeze(0)
actions = actions[offset:offset+total_frames].unsqueeze(0)
actions[:, :1] = torch.zeros_like(actions[:, :1]) # zero-init first frame's action
# get prompt image/video
x = load_prompt(args.prompt_path, video_offset=args.video_offset, n_prompt_frames=n_prompt_frames)
# get input action stream
actions = load_actions(args.actions_path, action_offset=args.video_offset)[:, :total_frames]

# sampling inputs
n_prompt_frames = 1
x = video[:, :n_prompt_frames]
x = x.to(device)
actions = actions.to(device)

# vae encoding
scaling_factor = 0.07843137255
x = rearrange(x, "b t h w c -> (b t) c h w")
x = rearrange(x, "b t c h w -> (b t) c h w")
H, W = x.shape[-2:]
with torch.no_grad():
x = vae.encode(x * 2 - 1).mean * scaling_factor
Expand Down Expand Up @@ -137,7 +130,11 @@ def main(args):

parse.add_argument('--oasis-ckpt', type=str, help='Path to Oasis DiT checkpoint.', default="oasis500m.safetensors")
parse.add_argument('--vae-ckpt', type=str, help='Path to Oasis ViT-VAE checkpoint.', default="vit-l-20.safetensors")
parse.add_argument('--num-frames', type=int, help='How many frames should be generated?', default=32)
parse.add_argument('--num-frames', type=int, help='How many frames should the output be?', default=32)
parse.add_argument('--prompt-path', type=str, help='Path to image or video to condition generation on.', default="sample_data/sample_image_0.png")
parse.add_argument('--actions-path', type=str, help='File to load actions from (.actions.pt or .one_hot_actions.pt)', default="sample_data/sample_actions_0.one_hot_actions.pt")
parse.add_argument('--video-offset', type=int, help='If loading prompt from video, index of frame to start reading from.', default=None)
parse.add_argument('--n-prompt-frames', type=int, help='If the prompt is a video, how many frames to condition on.', default=1)
parse.add_argument('--output-path', type=str, help='Path where generated video should be saved.', default="video.mp4")
parse.add_argument('--fps', type=int, help='What framerate should be used to save the output?', default=20)
parse.add_argument('--ddim-steps', type=int, help='How many DDIM steps?', default=50)
Expand Down
Binary file not shown.
Binary file added sample_data/sample_actions_0.one_hot_actions.pt
Binary file not shown.
Binary file added sample_data/sample_image_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
43 changes: 40 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import math
import torch
from torch import nn
from einops import rearrange, parse_shape
from typing import Mapping, Sequence
import torch
from torchvision.io import read_image, read_video
from torchvision.transforms.functional import resize
from einops import rearrange
from typing import Mapping, Sequence


def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
Expand Down Expand Up @@ -77,3 +77,40 @@ def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
actions_one_hot[i, j] = value

return actions_one_hot

# File extensions (lower-case, without the dot) accepted by load_prompt.
IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
VIDEO_EXTENSIONS = {"mp4"}


def load_prompt(path, video_offset=None, n_prompt_frames=1):
    """Load an image or video prompt as a batched float tensor.

    Args:
        path: Path to the prompt file. The text after the last "." (compared
            case-insensitively) must be in IMAGE_EXTENSIONS or
            VIDEO_EXTENSIONS.
        video_offset: For video prompts, index of the first frame to keep.
            ``None`` keeps the video from its first frame. Ignored for images.
        n_prompt_frames: For video prompts, how many frames to keep starting
            at ``video_offset``. An image always yields exactly one frame, so
            any other value fails the frame-count check below.

    Returns:
        Tensor laid out as (1, t, c, 360, 640): the frames resized to
        360x640, converted to float and divided by 255.0.

    Raises:
        ValueError: If the file extension is not recognized.
        AssertionError: If the loaded prompt does not contain exactly
            ``n_prompt_frames`` frames.
    """
    extension = path.lower().split(".")[-1]
    if extension in IMAGE_EXTENSIONS:
        print("prompt is image; ignoring video_offset and n_prompt_frames")
        prompt = read_image(path)
        # add frame dimension
        prompt = rearrange(prompt, "c h w -> 1 c h w")
    elif extension in VIDEO_EXTENSIONS:
        prompt = read_video(path, pts_unit="sec")[0]
        # read_video yields frames as (t, h, w, c) by default, while resize
        # and the final rearrange below expect channels before height/width.
        prompt = rearrange(prompt, "t h w c -> t c h w")
        if video_offset is not None:
            prompt = prompt[video_offset:]
        prompt = prompt[:n_prompt_frames]
    else:
        raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
    # Raised explicitly (rather than via `assert`) so the check still runs
    # when Python is started with -O; the exception type is unchanged.
    if prompt.shape[0] != n_prompt_frames:
        raise AssertionError(f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames")
    prompt = resize(prompt, (360, 640))
    # add batch dimension
    prompt = rearrange(prompt, "t c h w -> 1 t c h w")
    prompt = prompt.float() / 255.0
    return prompt

def load_actions(path, action_offset=None):
    """Load an action stream from disk and return it with a batch dimension.

    Args:
        path: File ending in ".actions.pt" (contents are passed through
            one_hot_actions) or ".one_hot_actions.pt" (contents are used
            as loaded).
        action_offset: Index of the first action to keep; ``None`` keeps the
            stream from its start.

    Returns:
        The actions rearranged from "t d" to "1 t d", with the first
        timestep overwritten by zeros.

    Raises:
        ValueError: If ``path`` has neither of the supported suffixes.
    """
    is_raw = path.endswith(".actions.pt")
    is_one_hot = path.endswith(".one_hot_actions.pt")
    if not (is_raw or is_one_hot):
        raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")

    if is_raw:
        # NOTE(review): no weights_only is passed here, so on torch versions
        # where it defaults to False this unpickles arbitrary objects; only
        # load trusted files.
        loaded = torch.load(path)
        actions = one_hot_actions(loaded)
    else:
        actions = torch.load(path, weights_only=True)

    if action_offset is not None:
        actions = actions[action_offset:]

    # add batch dimension
    actions = rearrange(actions, "t d -> 1 t d")
    # zero-init first frame's action
    first_step = actions[:, :1]
    actions[:, :1] = torch.zeros_like(first_step)
    return actions

0 comments on commit 870f4d5

Please sign in to comment.