Merge pull request #394 from Stability-AI/yiming/sv4d

merge sv4d changes: 1. reduce memory consumption (40 GB -> 20 GB) and speed up sampling (500 s -> 200 s); 2. add a Gradio demo
This commit is contained in:
chunhanyao-stable
2024-08-02 22:57:15 -07:00
committed by GitHub
20 changed files with 750 additions and 120 deletions


@@ -9,22 +9,23 @@
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
- To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D: first sample 5 anchor frames, then densely sample the remaining frames while maintaining temporal consistency (see the index sketch below).
- To run the community-built gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.
- Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.
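
The two-pass schedule above reduces to simple index arithmetic; a minimal sketch (constants mirror `scripts/sampling/simple_video_sample_4d.py`, names are illustrative):

```python
import numpy as np

T, n_frames = 5, 21  # frames per sampling window, total output frames
# Anchor pass: one sparse run covers frames [4, 8, 12, 16, 20]
anchor_indices = np.arange(T - 1, n_frames, T - 1)
# Dense passes: overlapping 5-frame windows that share the anchor frames,
# which is what keeps the interpolated frames temporally consistent
dense_windows = [t0 + np.arange(T) for t0 in np.arange(0, n_frames - 1, T - 1)]
```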
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
To run **SV4D** on a single input video of 21 frames:
- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
- `input_path` : The input video `<path/to/video>` can be
- a single video file in `gif` or `mp4` format, such as `assets/test_video1.mp4`, or
- a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
- a file name pattern matching images of video frames.
- `num_steps` : default is 20; increase to 50 for better quality at the cost of longer sampling time.
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time), or lowering the video resolution with `--img_size=512` (see the sketch below).
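
Peak VRAM in the VAE stage scales with how many frames are encoded or decoded at once, which is exactly what `--encoding_t`/`--decoding_t` cap; a minimal sketch of the idea (not the repo's exact implementation, `decode_fn` is a stand-in):

```python
import torch

def decode_in_chunks(decode_fn, latents: torch.Tensor, decoding_t: int) -> torch.Tensor:
    # Decode `decoding_t` frames at a time: slower, but peak VRAM drops accordingly.
    chunks = [decode_fn(latents[i : i + decoding_t]) for i in range(0, latents.shape[0], decoding_t)]
    return torch.cat(chunks, dim=0)
```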
![tile](assets/sv4d.gif)

13 binary files not shown.


@@ -0,0 +1,496 @@
# Add this at the very top of app.py to make the 'generative-models' directory discoverable
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
from glob import glob
from typing import List, Optional, Union
import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import hf_hub_download
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
decode_latents,
load_model,
initial_model_load,
read_video,
run_img2vid,
prepare_inputs,
do_sample_per_step,
sample_sv3d,
save_video,
preprocess_video,
)
# The temp path; if /tmp/gradio is not writable, change it to a writable path
# os.environ["GRADIO_TEMP_DIR"] = "gradio_tmp"
version = "sv4d" # replace with 'sv3d_p' or 'sv3d_u' for other models
# Define the repo, local directory and filename
repo_id = "stabilityai/sv4d"
filename = f"{version}.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = os.path.join(local_dir, filename)
# Check if the file already exists
if not os.path.exists(local_ckpt_path):
# If the file doesn't exist, download it
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
print("File downloaded. (sv4d)")
else:
print("File already exists. No need to download. (sv4d)")
device = "cuda"
max_64_bit_int = 2**63 - 1
num_frames = 21
num_steps = 20
model_config = f"scripts/sampling/configs/{version}.yaml"
# Set model config
T = 5 # number of frames per sample
V = 8 # number of views per sample
F = 8 # vae factor to downsize image->latent
C = 4 # latent channels
H, W = 576, 576
n_frames = 21 # number of input and output video frames
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
n_views_sv3d = 21
subsampled_views = np.array(
[0, 2, 5, 7, 9, 12, 14, 16, 19]
) # subsample (V+1=)9 (uniform) views from 21 SV3D views
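# Each SV4D pass jointly denoises T * V = 40 frames (5 time steps x 8 novel views), matching the training setup described in the README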
version_dict = {
"T": T * V,
"H": H,
"W": W,
"C": C,
"f": F,
"options": {
"discretization": 1,
"cfg": 3,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 5,
"num_steps": num_steps,
"force_uc_zero_embeddings": [
"cond_frames",
"cond_frames_without_noise",
"cond_view",
"cond_motion",
],
"additional_guider_kwargs": {
"additional_cond_keys": ["cond_view", "cond_motion"]
},
},
}
# Load SV4D model
model, filter = load_model(
model_config,
device,
version_dict["T"],
num_steps,
)
model = initial_model_load(model)
# -----------sv3d config and model loading----------------
# if version == "sv3d_u":
sv3d_model_config = "scripts/sampling/configs/sv3d_u.yaml"
# elif version == "sv3d_p":
# sv3d_model_config = "scripts/sampling/configs/sv3d_p.yaml"
# else:
# raise ValueError(f"Version {version} does not exist.")
# Define the repo, local directory and filename
repo_id = "stabilityai/sv3d"
filename = f"sv3d_u.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = os.path.join(local_dir, filename)
# Check if the file already exists
if not os.path.exists(local_ckpt_path):
# If the file doesn't exist, download it
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
print("File downloaded. (sv3d)")
else:
print("File already exists. No need to download. (sv3d)")
# load sv3d model
sv3d_model, filter = load_model(
sv3d_model_config,
device,
21,
num_steps,
verbose=False,
)
sv3d_model = initial_model_load(sv3d_model)
# ------------------
def sample_anchor(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
seed: Optional[int] = None,
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
num_steps: int = 20,
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 1e-5,
device: str = "cuda",
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
azimuths_deg: Optional[List[float]] = None,
verbose: Optional[bool] = False,
):
"""
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
output_folder = os.path.dirname(input_path)
torch.manual_seed(seed)
os.makedirs(output_folder, exist_ok=True)
# Read input video frames i.e. images at view 0
print(f"Reading {input_path}")
images_v0 = read_video(
input_path,
n_frames=n_frames,
device=device,
)
# Get camera viewpoints
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
elevations_deg = [elevations_deg] * n_views_sv3d
assert (
len(elevations_deg) == n_views_sv3d
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
if azimuths_deg is None:
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
assert (
len(azimuths_deg) == n_views_sv3d
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
azimuths_rad = np.array(
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
)
# Sample multi-view images of the first frame using SV3D i.e. images at time 0
sv3d_model.sampler.num_steps = num_steps
print("sv3d_model.sampler.num_steps", sv3d_model.sampler.num_steps)
images_t0 = sample_sv3d(
images_v0[0],
n_views_sv3d,
num_steps,
sv3d_version,
fps_id,
motion_bucket_id,
cond_aug,
decoding_t,
device,
polars_rad,
azimuths_rad,
verbose,
sv3d_model,
)
images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame
sv3d_file = os.path.join(output_folder, "t000.mp4")
save_video(sv3d_file, images_t0.unsqueeze(1))
for emb in model.conditioner.embedders:
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
emb.en_and_decode_n_samples_a_time = encoding_t
model.en_and_decode_n_samples_a_time = decoding_t
# Initialize image matrix
img_matrix = [[None] * n_views for _ in range(n_frames)]
for i, v in enumerate(subsampled_views):
img_matrix[0][i] = images_t0[v].unsqueeze(0)
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]
# Interleaved sampling for anchor frames
t0, v0 = 0, 0
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
view_indices = np.arange(V) + 1
print(f"Sampling anchor frames {frame_indices}")
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
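# Re-center azimuths on the conditioning view v0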
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
model.sampler.num_steps = num_steps
version_dict["options"]["num_steps"] = num_steps
samples = run_img2vid(
version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
)
samples = samples.view(T, V, 3, H, W)
for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
if img_matrix[t][v] is None:
img_matrix[t][v] = samples[i, j][None] * 2 - 1  # rescale [0, 1] -> [-1, 1]
# concat video
grid_list = []
for t in frame_indices:
imgs_view = torch.cat(img_matrix[t])
grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
# save output videos
anchor_vis_file = os.path.join(output_folder, "anchor_vis.mp4")
save_video(anchor_vis_file, grid_list, fps=3)
anchor_file = os.path.join(output_folder, "anchor.mp4")
image_list = samples.view(T*V, 3, H, W).unsqueeze(1) * 2 - 1
save_video(anchor_file, image_list)
return sv3d_file, anchor_vis_file, anchor_file
def sample_all(
input_path: str = "inputs/test_video1.mp4", # Can either be video file or folder with image files
sv3d_path: str = "outputs/sv4d/000000_t000.mp4",
anchor_path: str = "outputs/sv4d/000000_anchor.mp4",
seed: Optional[int] = None,
num_steps: int = 20,
device: str = "cuda",
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
azimuths_deg: Optional[List[float]] = None,
):
"""
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
output_folder = os.path.dirname(input_path)
torch.manual_seed(seed)
os.makedirs(output_folder, exist_ok=True)
# Read input video frames i.e. images at view 0
print(f"Reading {input_path}")
images_v0 = read_video(
input_path,
n_frames=n_frames,
device=device,
)
images_t0 = read_video(
sv3d_path,
n_frames=n_views_sv3d,
device=device,
)
# Get camera viewpoints
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
elevations_deg = [elevations_deg] * n_views_sv3d
assert (
len(elevations_deg) == n_views_sv3d
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
if azimuths_deg is None:
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
assert (
len(azimuths_deg) == n_views_sv3d
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
azimuths_rad = np.array(
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
)
# Initialize image matrix
img_matrix = [[None] * n_views for _ in range(n_frames)]
for i, v in enumerate(subsampled_views):
img_matrix[0][i] = images_t0[v]
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]
# load interleaved sampling for anchor frames
t0, v0 = 0, 0
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
view_indices = np.arange(V) + 1
anchor_frames = read_video(
anchor_path,
n_frames=T * V,
device=device,
)
anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W)
for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
if img_matrix[t][v] is None:
img_matrix[t][v] = anchor_frames[i, j][None]
# Dense sampling for the rest
print(f"Sampling dense frames:")
for t0 in np.arange(0, n_frames - 1, T - 1): # [0, 4, 8, 12, 16]
frame_indices = t0 + np.arange(T)
print(f"Sampling dense frames {frame_indices}")
latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
# alternate between forward and backward conditioning
forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
frame_indices,
img_matrix,
v0,
view_indices,
model,
version_dict,
seed,
polars,
azims
)
for step in range(num_steps):
if step % 2 == 1:
c, uc, additional_model_inputs, sampler = forward_inputs
frame_indices = forward_frame_indices
else:
c, uc, additional_model_inputs, sampler = backward_inputs
frame_indices = backward_frame_indices
noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
samples = do_sample_per_step(
model,
sampler,
noisy_latents,
c,
uc,
step,
additional_model_inputs,
)
samples = samples.view(T, V, C, H // F, W // F)
for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
latent_matrix[t, v] = samples[i, j]
img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)
# concat video
grid_list = []
for t in range(n_frames):
imgs_view = torch.cat(img_matrix[t])
grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
# save output videos
vid_file = os.path.join(output_folder, "sv4d_final.mp4")
save_video(vid_file, grid_list)
return vid_file, seed
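# Illustrative direct use of the two stages above (hypothetical paths; checkpoints
# must already be in checkpoints/):
#   sv3d_file, anchor_vis_file, anchor_file = sample_anchor(
#       "assets/sv4d_videos/test_video1.mp4", seed=23)
#   final_video, _ = sample_all(
#       "assets/sv4d_videos/test_video1.mp4", sv3d_file, anchor_file, seed=23)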
with gr.Blocks() as demo:
gr.Markdown(
"""# Demo for SV4D from Stability AI ([model](https://huggingface.co/stabilityai/sv4d), [news](https://stability.ai/news/stable-video-4d))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv4d/blob/main/LICENSE.md)): generate 8 novel view videos from a single-view video (with white background).
#### It takes ~45s to generate anchor frames and another ~160s to generate full results (21 frames).
#### Hints for improving performance:
- Use a white background;
- Center the object in the image;
- SV4D processes the first 21 frames of the uploaded video; Gradio's built-in trim option can shorten longer uploads.
"""
)
with gr.Row():
with gr.Column():
input_video = gr.Video(label="Upload your video")
generate_btn = gr.Button("Step 1: Generate 8 novel-view videos (5 anchor frames each)")
interpolate_btn = gr.Button("Step 2: Extend the novel-view videos to 21 frames")
with gr.Column():
anchor_video = gr.Video(label="SV4D outputs (anchor frames)")
sv3d_video = gr.Video(label="SV3D outputs", interactive=False)
with gr.Column():
sv4d_interpolated_video = gr.Video(label="SV4D outputs (21 frames)")
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(
label="Seed",
value=23,
# randomize=True,
minimum=0,
maximum=100,
step=1,
)
encoding_t = gr.Slider(
label="Encode n frames at a time",
info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.",
value=8,
minimum=1,
maximum=40,
)
decoding_t = gr.Slider(
label="Decode n frames at a time",
info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
value=4,
minimum=1,
maximum=14,
)
denoising_steps = gr.Slider(
label="Number of denoising steps",
info="Increase will improve the performance but needs more time.",
value=20,
minimum=10,
maximum=50,
step=1,
)
remove_bg = gr.Checkbox(
label="Remove background",
info="We use rembg. Users can check the alternative way: SAM2 (https://github.com/facebookresearch/segment-anything-2)",
)
input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False)
with gr.Row(visible=False):
anchor_frames = gr.Video()
generate_btn.click(
fn=sample_anchor,
inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],
outputs=[sv3d_video, anchor_video, anchor_frames],
api_name="SV4D output (5 frames)",
)
interpolate_btn.click(
fn=sample_all,
inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps],
outputs=[sv4d_interpolated_video, seed],
api_name="SV4D interpolation (21 frames)",
)
examples = gr.Examples(
fn=preprocess_video,
examples=[
"./assets/sv4d_videos/test_video1.mp4",
"./assets/sv4d_videos/test_video2.mp4",
"./assets/sv4d_videos/green_robot.mp4",
"./assets/sv4d_videos/dolphin.mp4",
"./assets/sv4d_videos/lucia_v000.mp4",
"./assets/sv4d_videos/snowboard_v000.mp4",
"./assets/sv4d_videos/stroller_v000.mp4",
"./assets/sv4d_videos/human5.mp4",
"./assets/sv4d_videos/bunnyman.mp4",
"./assets/sv4d_videos/hiphop_parrot.mp4",
"./assets/sv4d_videos/guppie_v0.mp4",
"./assets/sv4d_videos/wave_hello.mp4",
"./assets/sv4d_videos/pistol_v0.mp4",
"./assets/sv4d_videos/human7.mp4",
"./assets/sv4d_videos/monkey.mp4",
"./assets/sv4d_videos/train_v0.mp4",
],
inputs=[input_video],
run_on_click=True,
outputs=[input_video],
)
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(share=True)
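# As noted in the README, this demo can be launched locally with: python -m scripts.demo.gradio_app_sv4d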


@@ -121,10 +121,6 @@ def save_video(file_name, imgs, fps=10):
def read_video(
input_path: str,
n_frames: int,
W: int,
H: int,
remove_bg: bool = False,
image_frame_ratio: Optional[float] = None,
device: str = "cuda",
):
path = Path(input_path)
@@ -158,47 +154,121 @@ def read_video(
if len(images) < n_frames:
images = (images + images[::-1])[:n_frames]
if len(images) != n_frames:
raise ValueError(f"Input video contains fewer than {n_frames} frames.")
# Remove background and crop video frames
images_v0 = []
for t, image in enumerate(images):
if remove_bg:
if image.mode != "RGBA":
image.thumbnail([W, H], Image.Resampling.LANCZOS)
image = remove(image.convert("RGBA"), alpha_matting=True)
image_arr = np.array(image)
in_w, in_h = image_arr.shape[:2]
ret, mask = cv2.threshold(
np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
)
x, y, w, h = cv2.boundingRect(mask)
max_size = max(w, h)
if t == 0:
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
center - h // 2 : center - h // 2 + h,
center - w // 2 : center - w // 2 + w,
] = image_arr[y : y + h, x : x + w]
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
image = Image.fromarray((rgb * 255).astype(np.uint8))
else:
image = image.convert("RGB").resize((W, H), Image.LANCZOS)
for image in images:
image = ToTensor()(image).unsqueeze(0).to(device)
images_v0.append(image * 2.0 - 1.0)
return images_v0
def preprocess_video(input_path, remove_bg=False, n_frames=21, W=576, H=576, output_folder=None, image_frame_ratio=0.917):
print(f"preprocess {input_path}")
if output_folder is None:
output_folder = os.path.dirname(input_path)
path = Path(input_path)
is_video_file = False
all_img_paths = []
if path.is_file():
if any([input_path.endswith(x) for x in [".gif", ".mp4"]]):
is_video_file = True
else:
raise ValueError("Path is not a valid video file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)[:n_frames]
elif "*" in input_path:
all_img_paths = sorted(glob(input_path))[:n_frames]
else:
raise ValueError(f"Invalid input path: {input_path}")
if is_video_file and input_path.endswith(".gif"):
images = read_gif(input_path, n_frames)[:n_frames]
elif is_video_file and input_path.endswith(".mp4"):
images = read_mp4(input_path, n_frames)[:n_frames]
else:
print(f"Loading {len(all_img_paths)} video frames...")
images = [Image.open(img_path) for img_path in all_img_paths]
if len(images) != n_frames:
raise ValueError(f"Input video contains {len(images)} frames, fewer than {n_frames} frames.")
# Remove background
for i, image in enumerate(images):
if remove_bg:
if image.mode == "RGBA":
pass
else:
# image.thumbnail([W, H], Image.Resampling.LANCZOS)
image = remove(image.convert("RGBA"), alpha_matting=True)
images[i] = image
# Crop video frames, assume the object is already in the center of the image
white_thresh = 250
images_v0 = []
box_coord = [np.inf, np.inf, 0, 0]
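# Accumulate one bounding box over all frames so the crop stays fixed across time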
for image in images:
image_arr = np.array(image)
in_w, in_h = image_arr.shape[:2]
original_center = (in_w // 2, in_h // 2)
if image.mode == "RGBA":
ret, mask = cv2.threshold(
np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
)
else:
# assume the input image has white background
ret, mask = cv2.threshold(
(np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255, 0, 255, cv2.THRESH_BINARY
)
x, y, w, h = cv2.boundingRect(mask)
box_coord[0] = min(box_coord[0], x)
box_coord[1] = min(box_coord[1], y)
box_coord[2] = max(box_coord[2], x + w)
box_coord[3] = max(box_coord[3], y + h)
box_square = max(original_center[0] - box_coord[0], original_center[1] - box_coord[1])
box_square = max(box_square, box_coord[2] - original_center[0])
box_square = max(box_square, box_coord[3] - original_center[1])
x, y, w, h = original_center[0] - box_square, original_center[1] - box_square, 2 * box_square, 2 * box_square
box_size = box_square * 2
for image in images:
if image.mode == "RGB":
image = image.convert("RGBA")
image_arr = np.array(image)
side_len = (
int(box_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
center - box_size // 2 : center - box_size // 2 + box_size,
center - box_size // 2 : center - box_size // 2 + box_size,
] = image_arr[x : x + w, y : y + h]
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
# rgba = image.resize((W, H), Image.LANCZOS)
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
image = (rgb * 255).astype(np.uint8)
images_v0.append(image)
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4")
imageio.mimwrite(processed_file, images_v0, fps=10)
return processed_file
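# Illustrative usage (hypothetical path; writes <output_folder>/<index>_process_input.mp4):
#   processed_file = preprocess_video("assets/sv4d_videos/test_video1.mp4", remove_bg=True)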
def sample_sv3d(
image,
num_frames: Optional[int] = None, # 21 for SV3D
@@ -212,26 +282,32 @@ def sample_sv3d(
polar_rad: Optional[Union[float, List[float]]] = None,
azim_rad: Optional[List[float]] = None,
verbose: Optional[bool] = False,
sv3d_model=None,
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
if version == "sv3d_u":
model_config = "scripts/sampling/configs/sv3d_u.yaml"
elif version == "sv3d_p":
model_config = "scripts/sampling/configs/sv3d_p.yaml"
else:
raise ValueError(f"Version {version} does not exist.")
if sv3d_model is None:
if version == "sv3d_u":
model_config = "scripts/sampling/configs/sv3d_u.yaml"
elif version == "sv3d_p":
model_config = "scripts/sampling/configs/sv3d_p.yaml"
else:
raise ValueError(f"Version {version} does not exist.")
model, filter = load_model(
model_config,
device,
num_frames,
num_steps,
verbose,
)
model, filter = load_model(
model_config,
device,
num_frames,
num_steps,
verbose,
)
else:
model = sv3d_model
load_module_gpu(model)
H, W = image.shape[2:]
F = 8
@@ -286,25 +362,32 @@ def sample_sv3d(
)
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
model.en_and_decode_n_samples_a_time = decoding_t
samples_x = model.decode_first_stage(samples_z)
samples_x[-1:] = value_dict["cond_frames_without_noise"]
samples = torch.clamp(samples_x, min=-1.0, max=1.0)
return samples
def decode_latents(model, samples_z, timesteps):
load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
unload_module_gpu(model)
return samples
def decode_latents(model, samples_z, img_matrix, frame_indices, view_indices, timesteps):
load_module_gpu(model.first_stage_model)
for t in frame_indices:
for v in view_indices:
if t != 0 and v != 0:
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(samples_z[t, v][None], timesteps=timesteps)
else:
samples_x = model.decode_first_stage(samples_z[t, v][None])
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
img_matrix[t][v] = samples * 2 - 1
unload_module_gpu(model.first_stage_model)
return img_matrix
def init_embedder_options_no_st(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future
@@ -604,6 +687,7 @@ def run_img2vid(
azim_rad=np.linspace(0, 360, 21 + 1)[1:],
cond_motion=None,
cond_view=None,
decoding_t=None,
):
options = version_dict["options"]
H = version_dict["H"]
@@ -670,12 +754,53 @@ def run_img2vid(
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None),
return_latents=False,
decoding_t=options.get("decoding_T", T),
decoding_t=decoding_t,
)
return samples
def prepare_inputs(frame_indices, img_matrix, v0, view_indices, model, version_dict, seed, polars, azims):
load_module_gpu(model.conditioner)
forward_frame_indices = frame_indices.copy()
t0 = forward_frame_indices[0]
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in forward_frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
forward_inputs = prepare_sampling(
version_dict,
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
)
# backward sampling
backward_frame_indices = frame_indices[::-1].copy()
t0 = backward_frame_indices[0]
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in backward_frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
backward_inputs = prepare_sampling(
version_dict,
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
)
unload_module_gpu(model.conditioner)
return forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices
def do_sample(
model,
sampler,
@@ -761,13 +886,11 @@ def do_sample(
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(
@@ -777,17 +900,15 @@ def do_sample(
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
if filter is not None:
samples = filter(samples)
if return_latents:
return samples, samples_z
return samples
def do_sample_per_step(
def prepare_sampling_(
model,
sampler,
value_dict,
@@ -797,8 +918,6 @@ def do_sample_per_step(
batch2model_input: List = None,
T=None,
additional_batch_uc_fields=None,
step=None,
noisy_latents=None,
):
force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
batch2model_input = default(batch2model_input, [])
@@ -812,8 +931,6 @@ def do_sample_per_step(
num_samples = [num_samples, T]
else:
num_samples = [num_samples]
load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
@@ -827,8 +944,6 @@ def do_sample_per_step(
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
@@ -859,7 +974,14 @@ def do_sample_per_step(
)
else:
additional_model_inputs[k] = batch[k]
return c, uc, additional_model_inputs
def do_sample_per_step(model, sampler, noisy_latents, c, uc, step, additional_model_inputs):
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
noisy_latents_scaled, s_in, sigmas, num_sigmas, _, _ = (
sampler.prepare_sampling_loop(
noisy_latents.clone(), c, uc, sampler.num_steps
@@ -893,13 +1015,10 @@ def do_sample_per_step(
uc,
gamma,
)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
return samples_z
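# Splitting prepare_sampling_ (conditioning, computed once) from do_sample_per_step
# (one denoising step) lets the alternating forward/backward loop reuse conditioning
# across steps -- part of this PR's memory and speed savings.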
def run_img2vid_per_step(
def prepare_sampling(
version_dict,
model,
image,
@@ -908,8 +1027,6 @@ def run_img2vid_per_step(
azim_rad=np.linspace(0, 360, 21 + 1)[1:],
cond_motion=None,
cond_view=None,
step=None,
noisy_latents=None,
):
options = version_dict["options"]
H = version_dict["H"]
@@ -962,7 +1079,7 @@ def run_img2vid_per_step(
sampler, num_rows, num_cols = init_sampling_no_st(options=options)
num_samples = num_rows * num_cols
samples = do_sample_per_step(
c, uc, additional_model_inputs = prepare_sampling_(
model,
sampler,
value_dict,
@@ -971,11 +1088,9 @@ def run_img2vid_per_step(
force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None),
batch2model_input=["num_video_frames", "image_only_indicator"],
T=T,
step=step,
noisy_latents=noisy_latents,
)
return samples
return c, uc, additional_model_inputs, sampler
def get_unique_embedder_keys_from_conditioner(conditioner):


@@ -10,15 +10,19 @@ import numpy as np
import torch
from fire import Fire
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
decode_latents,
load_model,
initial_model_load,
read_video,
run_img2vid,
run_img2vid_per_step,
prepare_sampling,
prepare_inputs,
do_sample_per_step,
sample_sv3d,
save_video,
preprocess_video,
)
@@ -32,17 +36,18 @@ def sample(
motion_bucket_id: int = 127,
cond_aug: float = 1e-5,
seed: int = 23,
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
azimuths_deg: Optional[List[float]] = None,
image_frame_ratio: Optional[float] = None,
image_frame_ratio: Optional[float] = 0.917,
verbose: Optional[bool] = False,
remove_bg: bool = False,
):
"""
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
"""
# Set model config
T = 5 # number of frames per sample
@@ -89,15 +94,16 @@ def sample(
# Read input video frames i.e. images at view 0
print(f"Reading {input_path}")
images_v0 = read_video(
processed_input_path = preprocess_video(
input_path,
remove_bg=remove_bg,
n_frames=n_frames,
W=W,
H=H,
remove_bg=remove_bg,
output_folder=output_folder,
image_frame_ratio=image_frame_ratio,
device=device,
)
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
# Get camera viewpoints
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
@@ -139,7 +145,7 @@ def sample(
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
save_video(
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
img_matrix[0],
@@ -158,6 +164,10 @@ def sample(
verbose,
)
model = initial_model_load(model)
for emb in model.conditioner.embedders:
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
emb.en_and_decode_n_samples_a_time = encoding_t
model.en_and_decode_n_samples_a_time = decoding_t
# Interleaved sampling for anchor frames
t0, v0 = 0, 0
@@ -171,7 +181,7 @@ def sample(
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
samples = run_img2vid(
version_dict, model, image, seed, polars, azims, cond_motion, cond_view
version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
)
samples = samples.view(T, V, 3, H, W)
for i, t in enumerate(frame_indices):
@@ -185,40 +195,48 @@ def sample(
frame_indices = t0 + np.arange(T)
print(f"Sampling dense frames {frame_indices}")
latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
# alternate between forward and backward conditioning
forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
frame_indices,
img_matrix,
v0,
view_indices,
model,
version_dict,
seed,
polars,
azims
)
for step in tqdm(range(num_steps)):
frame_indices = frame_indices[::-1].copy() # alternate between forward and backward conditioning
t0 = frame_indices[0]
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
if step % 2 == 1:
c, uc, additional_model_inputs, sampler = forward_inputs
frame_indices = forward_frame_indices
else:
c, uc, additional_model_inputs, sampler = backward_inputs
frame_indices = backward_frame_indices
noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
samples = run_img2vid_per_step(
version_dict,
samples = do_sample_per_step(
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
step,
sampler,
noisy_latents,
c,
uc,
step,
additional_model_inputs,
)
samples = samples.view(T, V, C, H // F, W // F)
for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
latent_matrix[t, v] = samples[i, j]
for t in frame_indices:
for v in view_indices:
if t != 0 and v != 0:
img = decode_latents(model, latent_matrix[t, v][None], T)
img_matrix[t][v] = img * 2 - 1
img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)
# Save output videos
for v in view_indices: