update sv4d sampling script and readme #392

Merged
merged 2 commits into from Aug 2, 2024
Changes from 1 commit
23 changes: 23 additions & 0 deletions .vscode/launch.json
chunhanyao-stable marked this conversation as resolved.
@@ -0,0 +1,23 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Remote Attach",
"type": "debugpy",
"request": "attach",
"connect": {
"host": "localhost",
"port": 5678
},
"pathMappings": [
{
"localRoot": "${workspaceFolder}",
"remoteRoot": "."
}
]
}
]
}
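
This attach configuration assumes the target process is already listening for a debugger on port 5678. One way to do that (using the standard debugpy CLI; the script and arguments here are only an example) is to launch the sampling script with `python -m debugpy --listen 5678 --wait-for-client scripts/sampling/simple_video_sample_4d.py ...` and then start the "Python Debugger: Remote Attach" configuration from VS Code.
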
7 changes: 4 additions & 3 deletions README.md
100644 → 100755
@@ -9,9 +9,9 @@
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
- To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D: first sample 5 anchor frames, then densely sample the remaining frames while maintaining temporal consistency (see the sketch below).
- Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.
- Please check our [project page](), [tech report]() and [video summary]() for more details.
chunhanyao-stable marked this conversation as resolved.
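
A minimal sketch of the shapes described above (illustrative only; the variable names are hypothetical and the even anchor spacing is an assumption, not necessarily the schedule used by the sampling script):

```python
import numpy as np

n_frames, n_views = 21, 1 + 8            # 21 output frames x (1 input view + 8 novel views)
frames_per_pass, views_per_pass = 5, 8   # one SV4D pass yields 5 x 8 = 40 images
img_matrix = [[None] * n_views for _ in range(n_frames)]  # [time][view] grid to fill

# 1) sample a handful of anchor time steps spanning the input video
#    (even spacing assumed here for illustration),
# 2) then densely sample the remaining time steps, conditioning on the anchors
#    to keep the novel-view videos temporally consistent.
anchor_ts = np.linspace(0, n_frames - 1, frames_per_pass).round().astype(int)
print(anchor_ts)  # -> [ 0  5 10 15 20]
```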

**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [SV4D](https://huggingface.co/stabilityai/sv4d) and [SV3D_u](https://huggingface.co/stabilityai/sv3d) from HuggingFace)

To run **SV4D** on a single input video of 21 frames:
- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
@@ -23,7 +23,8 @@ To run **SV4D** on a single input video of 21 frames:
- `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos (with noisy background), try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) before running SV4D.
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) before running SV4D.
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (number of frames decoded at a time) or a lower video resolution such as `--img_size=512`.

![tile](assets/sv4d.gif)

125 changes: 47 additions & 78 deletions scripts/demo/sv4d_helpers.py
100644 → 100755
@@ -36,6 +36,20 @@
from sgm.util import default, instantiate_from_config


def load_module_gpu(model):
model.cuda()


def unload_module_gpu(model):
model.cpu()
torch.cuda.empty_cache()


def initial_model_load(model):
model.model.half()
return model
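
# Illustrative usage of the helpers above (a sketch mirroring how they are used
# later in this file): keep each heavy sub-module on the CPU and move it onto the
# GPU only while it is needed, e.g.
#   load_module_gpu(model.conditioner)
#   c, uc = model.conditioner.get_unconditional_conditioning(batch, batch_uc=batch_uc)
#   unload_module_gpu(model.conditioner)  # free VRAM before the denoising loop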


def get_resizing_factor(
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
@@ -60,75 +74,11 @@ def get_resizing_factor(
return factor


def load_img_for_prediction_no_st(
image_path: str,
mask_path: str,
W: int,
H: int,
crop_h: int,
crop_w: int,
device="cuda",
) -> torch.Tensor:
image = Image.open(image_path)
if image is None:
return None
image = np.array(image).astype(np.float32) / 255
h, w = image.shape[:2]
rotated = 0

mask = None
if mask_path is not None:
mask = Image.open(mask_path)
mask = np.array(mask).astype(np.float32) / 255
mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
np.float32
)
elif image.shape[-1] == 4:
mask = image[:, :, 3:]

if mask is not None:
image = image[:, :, :3] * mask + (1 - mask)
# if "DAVIS" in image_path:
# y, x, _ = np.where(mask > 0)
# x_mean, y_mean = np.mean(x), np.mean(y)
# else:
# x_mean, y_mean = w//2, h//2
# h_new = int(max(crop_h, crop_w) * 1.33)
# x_min = max(int(x_mean - h_new//2), 0)
# y_min = max(int(y_mean - h_new//2), 0)
# image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
# h_crop, w_crop = image_cropped.shape[:2]
# h_new = max(h_crop, w_crop)
# top = max((h_new - h_crop) // 2, 0)
# left = max((h_new - w_crop) // 2, 0)
# image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
# image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
# image = image_padded
# h, w = image.shape[:2]

image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)

rfs = get_resizing_factor((H, W), (h, w))
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
top = (resize_size[0] - H) // 2
left = (resize_size[1] - W) // 2

image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
return image.to(device) * 2.0 - 1.0, rotated


def read_gif(input_path, n_frames):
frames = []
video = Image.open(input_path)
if video.n_frames < n_frames:
return frames
for img in ImageSequence.Iterator(video):
frames.append(img.convert("RGB"))
frames.append(img.convert("RGBA"))
if len(frames) == n_frames:
break
return frames
@@ -206,16 +156,17 @@ def read_video(
print(f"Loading {len(all_img_paths)} video frames...")
images = [Image.open(img_path) for img_path in all_img_paths]

if len(images) < n_frames:
images = (images + images[::-1])[:n_frames]
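# e.g. a 15-frame clip becomes frames [0..14, 14, 13, ..., 9]: mirrored ("ping-pong") padding up to n_frames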

if len(images) != n_frames:
raise ValueError("Input video contains fewer than {n_frames} frames.")
raise ValueError(f"Input video contains fewer than {n_frames} frames.")

# Remove background and crop video frames
images_v0 = []
for image in images:
for t, image in enumerate(images):
if remove_bg:
if image.mode == "RGBA":
pass
else:
if image.mode != "RGBA":
image.thumbnail([W, H], Image.Resampling.LANCZOS)
image = remove(image.convert("RGBA"), alpha_matting=True)
image_arr = np.array(image)
@@ -225,11 +176,12 @@
)
x, y, w, h = cv2.boundingRect(mask)
max_size = max(w, h)
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
if t == 0:
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
@@ -239,7 +191,9 @@
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
images = Image.fromarray((rgb * 255).astype(np.uint8))
image = Image.fromarray((rgb * 255).astype(np.uint8))
else:
image = image.convert("RGB").resize((W, H), Image.LANCZOS)
image = ToTensor()(image).unsqueeze(0).to(device)
images_v0.append(image * 2.0 - 1.0)
return images_v0
@@ -341,11 +295,13 @@ def denoiser(input, sigma, c):


def decode_latents(model, samples_z, timesteps):
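# move the VAE onto the GPU only for decoding, then release the memory again below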
load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
return samples


@@ -751,20 +707,21 @@ def do_sample(
else:
num_samples = [num_samples]

load_module_gpu(model.conditioner)
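# the conditioner is needed only while building the (un)conditional embeddings below; it is moved back to the CPU right after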
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)

c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)

for k in c:
if not k == "crossattn":
@@ -805,15 +762,21 @@ def denoiser(input, sigma, c):
model.model, input, sigma, c, **additional_model_inputs
)

load_module_gpu(model.model)
load_module_gpu(model.denoiser)
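# the UNet and denoiser stay on the GPU only for the sampling call, then are offloaded again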
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)

load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(
samples_z, timesteps=default(decoding_t, T)
)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)

if filter is not None:
samples = filter(samples)
@@ -850,20 +813,21 @@ def do_sample_per_step(
else:
num_samples = [num_samples]

load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)

c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)

for k in c:
if not k == "crossattn":
@@ -917,6 +881,9 @@ def denoiser(input, sigma, c):
if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
else 0.0
)

load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler.sampler_step(
s_in * sigmas[step],
s_in * sigmas[step + 1],
@@ -926,6 +893,8 @@ def denoiser(input, sigma, c):
uc,
gamma,
)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)

return samples_z

6 changes: 0 additions & 6 deletions scripts/sampling/configs/sv4d.yaml
100644 → 100755
@@ -93,12 +93,6 @@ model:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

# - input_key: cond_aug
# is_trainable: False
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
# params:
# outdim: 256

- input_key: polar_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
9 changes: 6 additions & 3 deletions scripts/sampling/simple_video_sample_4d.py
100644 → 100755
@@ -13,6 +13,7 @@
from scripts.demo.sv4d_helpers import (
decode_latents,
load_model,
initial_model_load,
read_video,
run_img2vid,
run_img2vid_per_step,
@@ -26,6 +27,7 @@ def sample(
output_folder: Optional[str] = "outputs/sv4d",
num_steps: Optional[int] = 20,
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
img_size: int = 576, # image resolution
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 1e-5,
@@ -47,7 +49,7 @@
V = 8 # number of views per sample
F = 8 # vae factor to downsize image->latent
C = 4
H, W = 576, 576
H, W = img_size, img_size
n_frames = 21 # number of input and output video frames
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
n_views_sv3d = 21
@@ -64,7 +66,7 @@
"f": F,
"options": {
"discretization": 1,
"cfg": 2.5,
"cfg": 3.0,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
@@ -137,7 +139,7 @@ def sample(
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]

base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
save_video(
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
img_matrix[0],
@@ -155,6 +157,7 @@
num_steps,
verbose,
)
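# cast the diffusion model weights to half precision (fp16) to reduce GPU memory use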
model = initial_model_load(model)

# Interleaved sampling for anchor frames
t0, v0 = 0, 0
2 changes: 1 addition & 1 deletion sgm/modules/spacetime_attention.py
100644 → 100755
@@ -593,4 +593,4 @@ def forward(
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
return out
chunhanyao-stable marked this conversation as resolved.