sv4d: fix readme;

rename video example folder;
add encoding_t as an input parameter.
ymxie97
2024-08-02 17:19:03 +00:00
parent da40ebad4e
commit e90e953330
22 changed files with 43 additions and 29 deletions


@@ -9,23 +9,23 @@
 - We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
 - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
 - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
-- You can run the community-build gradio demo locally by running `python -m scripts.demo.gradio_app_sv4d`.
+- To run the community-built gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.
 - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.
-**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_example_video/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
+**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
 To run **SV4D** on a single input video of 21 frames:
 - Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
 - Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
 - `input_path` : The input video `<path/to/video>` can be
-- a single video file in `gif` or `mp4` format, such as `assets/sv4d_example_video/test_video1.mp4`, or
+- a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or
 - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
 - a file name pattern matching images of video frames.
 - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
 - `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
 - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
-- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
-- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.
+- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
+- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or lower video resolution like `--img_size=512`.
 ![tile](assets/sv4d.gif)
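For the **Low VRAM environment** bullet above, a concrete (hedged) example: the documented CLI form would be `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --encoding_t 1 --decoding_t 1 --img_size 512`. The Python sketch below assumes the Fire-wrapped `sample` function in that script can be imported and called directly, which the README itself does not state.

```python
# Sketch only: flags mirror the README bullets above; importing `sample` directly
# (instead of the documented command line) is an assumption.
from scripts.sampling.simple_video_sample_4d import sample

sample(
    input_path="assets/sv4d_videos/test_video1.mp4",  # example clip from the renamed assets folder
    output_folder="outputs/sv4d",
    encoding_t=1,   # encode one frame at a time to lower peak VRAM
    decoding_t=1,   # decode one frame at a time to lower peak VRAM
    img_size=512,   # reducing resolution also reduces memory use
)
```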


@@ -14,6 +14,7 @@ from huggingface_hub import hf_hub_download
 from typing import List, Optional, Union
 import torchvision
+from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
 from scripts.demo.sv4d_helpers import (
     decode_latents,
     load_model,
@@ -138,6 +139,7 @@ sv3d_model = initial_model_load(sv3d_model)
 def sample_anchor(
     input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
     seed: Optional[int] = None,
+    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
     decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     num_steps: int = 20,
     sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
@@ -205,6 +207,10 @@ def sample_anchor(
     sv3d_file = os.path.join(output_folder, "t000.mp4")
     save_video(sv3d_file, images_t0.unsqueeze(1))
+    for emb in model.conditioner.embedders:
+        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
+            emb.en_and_decode_n_samples_a_time = encoding_t
+    model.en_and_decode_n_samples_a_time = decoding_t
     # Initialize image matrix
     img_matrix = [[None] * n_views for _ in range(n_frames)]
     for i, v in enumerate(subsampled_views):
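The four lines added above bound peak VRAM: each `VideoPredictionEmbedderWithEncoder` now encodes at most `encoding_t` conditioning frames per forward pass, and the model decodes at most `decoding_t` latents at a time. A minimal, self-contained sketch of that chunking idea follows; `encode_fn`, `frames`, and `n_at_a_time` are placeholder names for illustration, not sgm APIs.

```python
import torch

# Illustrative chunked encoding: memory scales with the chunk size `n_at_a_time`
# rather than with the full batch (e.g. 40 frames for SV4D).
def encode_in_chunks(encode_fn, frames: torch.Tensor, n_at_a_time: int) -> torch.Tensor:
    outs = []
    with torch.no_grad():
        for i in range(0, frames.shape[0], n_at_a_time):
            outs.append(encode_fn(frames[i : i + n_at_a_time]))  # encode one small batch
    return torch.cat(outs, dim=0)

# Example: a dummy "encoder" applied to 40 frames, 8 at a time.
latents = encode_in_chunks(lambda x: x.mean(dim=1, keepdim=True), torch.randn(40, 3, 64, 64), 8)
```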
@@ -413,6 +419,13 @@ with gr.Blocks() as demo:
     maximum=100,
     step=1,
 )
+encoding_t = gr.Slider(
+    label="Encode n frames at a time",
+    info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.",
+    value=8,
+    minimum=1,
+    maximum=40,
+)
 decoding_t = gr.Slider(
     label="Decode n frames at a time",
     info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
@@ -440,7 +453,7 @@ with gr.Blocks() as demo:
 generate_btn.click(
     fn=sample_anchor,
-    inputs=[input_video, seed, decoding_t, denoising_steps],
+    inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],
     outputs=[sv3d_video, anchor_video, anchor_frames],
     api_name="SV4D output (5 frames)",
 )
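One wiring detail worth noting in the hunk above: Gradio passes `inputs` to `fn` positionally, so `encoding_t` is inserted between `seed` and `decoding_t` to mirror the updated `sample_anchor(..., seed, encoding_t, decoding_t, ...)` signature earlier in this diff. A generic, hypothetical sketch of that pattern (not repo code; widget names and ranges are made up):

```python
import gradio as gr

def run(video_path, seed, encoding_t, decoding_t, steps):
    # Echo the received values; stands in for sample_anchor in this sketch.
    return f"video={video_path}, seed={seed}, enc={encoding_t}, dec={decoding_t}, steps={steps}"

with gr.Blocks() as demo:
    video = gr.Textbox(label="Input video path")
    seed = gr.Slider(label="Seed", minimum=0, maximum=100, step=1, value=23)
    encoding_t = gr.Slider(label="Encode n frames at a time", minimum=1, maximum=40, value=8)
    decoding_t = gr.Slider(label="Decode n frames at a time", minimum=1, maximum=40, value=4)
    steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, value=20)
    out = gr.Textbox(label="Result")
    # The order of `inputs` must match run()'s positional parameters.
    gr.Button("Generate").click(fn=run, inputs=[video, seed, encoding_t, decoding_t, steps], outputs=[out])
```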
@@ -455,22 +468,22 @@ with gr.Blocks() as demo:
 examples = gr.Examples(
     fn=preprocess_video,
     examples=[
-        "./assets/sv4d_example_video/test_video1.mp4",
-        "./assets/sv4d_example_video/test_video2.mp4",
-        "./assets/sv4d_example_video/green_robot.mp4",
-        "./assets/sv4d_example_video/dolphin.mp4",
-        "./assets/sv4d_example_video/lucia_v000.mp4",
-        "./assets/sv4d_example_video/snowboard_v000.mp4",
-        "./assets/sv4d_example_video/stroller_v000.mp4",
-        "./assets/sv4d_example_video/human5.mp4",
-        "./assets/sv4d_example_video/bunnyman.mp4",
-        "./assets/sv4d_example_video/hiphop_parrot.mp4",
-        "./assets/sv4d_example_video/guppie_v0.mp4",
-        "./assets/sv4d_example_video/wave_hello.mp4",
-        "./assets/sv4d_example_video/pistol_v0.mp4",
-        "./assets/sv4d_example_video/human7.mp4",
-        "./assets/sv4d_example_video/monkey.mp4",
-        "./assets/sv4d_example_video/train_v0.mp4",
+        "./assets/sv4d_videos/test_video1.mp4",
+        "./assets/sv4d_videos/test_video2.mp4",
+        "./assets/sv4d_videos/green_robot.mp4",
+        "./assets/sv4d_videos/dolphin.mp4",
+        "./assets/sv4d_videos/lucia_v000.mp4",
+        "./assets/sv4d_videos/snowboard_v000.mp4",
+        "./assets/sv4d_videos/stroller_v000.mp4",
+        "./assets/sv4d_videos/human5.mp4",
+        "./assets/sv4d_videos/bunnyman.mp4",
+        "./assets/sv4d_videos/hiphop_parrot.mp4",
+        "./assets/sv4d_videos/guppie_v0.mp4",
+        "./assets/sv4d_videos/wave_hello.mp4",
+        "./assets/sv4d_videos/pistol_v0.mp4",
+        "./assets/sv4d_videos/human7.mp4",
+        "./assets/sv4d_videos/monkey.mp4",
+        "./assets/sv4d_videos/train_v0.mp4",
     ],
     inputs=[input_video],
     run_on_click=True,


@@ -264,7 +264,7 @@ def preprocess_video(input_path, remove_bg=False, n_frames=21, W=576, H=576, out
     images_v0.append(image)
-    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
+    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
     processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4")
     imageio.mimwrite(processed_file, images_v0, fps=10)
     return processed_file
@@ -892,7 +892,6 @@ def do_sample(
     unload_module_gpu(model.model)
     unload_module_gpu(model.denoiser)
     load_module_gpu(model.first_stage_model)
-    model.en_and_decode_n_samples_a_time = decoding_t
     if isinstance(model.first_stage_model.decoder, VideoDecoder):
         samples_x = model.decode_first_stage(
             samples_z, timesteps=default(decoding_t, T)


@@ -1,7 +1,6 @@
 N_TIME: 5
 N_VIEW: 8
 N_FRAMES: 40
-ENCODE_N_A_TIME: 8
 model:
   target: sgm.models.diffusion.DiffusionEngine
@@ -68,7 +67,6 @@ model:
         is_ae: True
         n_cond_frames: ${N_FRAMES}
         n_copies: 1
-        en_and_decode_n_samples_a_time: ${ENCODE_N_A_TIME}
         encoder_config:
           target: sgm.models.autoencoder.AutoencoderKLModeOnly
           params:
@@ -133,7 +131,6 @@ model:
         is_ae: True
         n_cond_frames: ${N_VIEW}
         n_copies: 1
-        en_and_decode_n_samples_a_time: ${ENCODE_N_A_TIME}
         sigma_sampler_config:
           target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
@@ -144,7 +141,6 @@ model:
         is_ae: True
         n_cond_frames: ${N_TIME}
         n_copies: 1
-        en_and_decode_n_samples_a_time: ${ENCODE_N_A_TIME}
         encoder_config:
           target: sgm.models.autoencoder.AutoencoderKLModeOnly
           params:


@@ -10,6 +10,7 @@ import numpy as np
 import torch
 from fire import Fire
+from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
 from scripts.demo.sv4d_helpers import (
     decode_latents,
     load_model,
@@ -35,6 +36,7 @@ def sample(
     motion_bucket_id: int = 127,
     cond_aug: float = 1e-5,
     seed: int = 23,
+    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
     decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     device: str = "cuda",
     elevations_deg: Optional[Union[float, List[float]]] = 10.0,
@@ -45,7 +47,7 @@ def sample(
 ):
     """
     Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
-    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
+    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
     """
     # Set model config
     T = 5  # number of frames per sample
@@ -162,6 +164,10 @@ def sample(
         verbose,
     )
     model = initial_model_load(model)
+    for emb in model.conditioner.embedders:
+        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
+            emb.en_and_decode_n_samples_a_time = encoding_t
+    model.en_and_decode_n_samples_a_time = decoding_t
     # Interleaved sampling for anchor frames
     t0, v0 = 0, 0