Compare commits
24 Commits
sv3d_gradi
...
sp4d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
787abc0be9 | ||
|
|
0aee97d395 | ||
|
|
fd9d14e02f | ||
|
|
8f41cbc50b | ||
|
|
f87e52e72c | ||
|
|
0ad7de9a5c | ||
|
|
c3147b86db | ||
|
|
1659a1c09b | ||
|
|
37ab71e234 | ||
|
|
e90e953330 | ||
|
|
da40ebad4e | ||
|
|
50364a7d2f | ||
|
|
2cea114cc1 | ||
|
|
734195d1c9 | ||
|
|
854bd4f0df | ||
|
|
e0596f1aca | ||
|
|
ce1576bfca | ||
|
|
1cd0cbaff4 | ||
|
|
863665548f | ||
|
|
e3e4b9d263 | ||
|
|
1aa06e5995 | ||
|
|
998cb122d3 | ||
|
|
31fe459a85 | ||
|
|
abe9ed3d40 |
4
.gitignore
vendored
@@ -11,4 +11,6 @@
|
||||
/dist
|
||||
/outputs
|
||||
/build
|
||||
/src
|
||||
/src
|
||||
/.vscode
|
||||
**/__pycache__/
|
||||
|
||||
79
README.md
Normal file → Executable file
@@ -4,6 +4,84 @@
|
||||
|
||||
## News
|
||||
|
||||
|
||||
**Nov 4, 2025**
|
||||
- We are releasing **[Stable Part Diffusion 4D (SP4D)](https://huggingface.co/stabilityai/sp4d)**, a video-to-4D diffusion model for multi-view part video synthesis and animatable 3D asset generation. For research purposes:
|
||||
- **SP4D** was trained to generate 48 RGB frames and part segmentation maps (4 video frames x 12 camera views) at 576x576 resolution, given a 4-frame input video of the same size, ideally consisting of white-background images of a moving object.
|
||||
- Based on our previous 4D model [SV4D 2.0](https://huggingface.co/stabilityai/sv4d2.0), **SP4D** can simultaneously generate multi-view RGB videos as well as the corresponding kinematic part segmentations that are consistent across time and camera views.
|
||||
- The generated part videos can then be used to create animation-ready 3D assets with part-aware rigging capabilities.
|
||||
- Please check our [project page](https://stablepartdiffusion4d.github.io/), [arxiv paper](https://arxiv.org/pdf/2509.10687) and [video summary](https://www.youtube.com/watch?v=FXEFeh8tf0k) for more details.
|
||||
|
||||
**QUICKSTART** :
|
||||
- Setup environment following the SV4D instructions and download [sp4d.safetensors](https://huggingface.co/stabilityai/sp4d) from HuggingFace into `checkpoints/`
|
||||
- Run `python scripts/sampling/simple_video_sample_sp4d.py --input_path assets/sv4d_videos/cows.gif --output_folder outputs` to generate multi-view part videos given the sample input.
|
||||
|
||||
|
||||
**May 20, 2025**
|
||||
- We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:
|
||||
- **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.
|
||||
- Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on refernce multi-view of the first frame generated by SV3D, making it more robust to self-occlusions.
|
||||
- To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames.
|
||||
- Please check our [project page](https://sv4d20.github.io), [arxiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details.
|
||||
|
||||
**QUICKSTART** :
|
||||
- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from HuggingFace into `checkpoints/`)
|
||||
|
||||
To run **SV4D 2.0** on a single input video of 21 frames:
|
||||
- Download SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints`
|
||||
- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`
|
||||
- `input_path` : The input video `<path/to/video>` can be
|
||||
- a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or
|
||||
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
|
||||
- a file name pattern matching images of video frames.
|
||||
- `num_steps` : default is 50, can decrease to it to shorten sampling time.
|
||||
- `elevations_deg` : specified elevations (reletive to input view), default is 0.0 (same as input view).
|
||||
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
|
||||
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (of frames encoded at a time) and `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.
|
||||
|
||||
Notes:
|
||||
- We also train a 8-view model that generates 5 frames x 8 views at a time (same as SV4D).
|
||||
- Download the model from huggingface: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints`
|
||||
- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs`
|
||||
- The 5x8 model takes 5 frames of input at a time. But the inference scripts for both model take 21-frame video as input by default (same as SV3D and SV4D), we run the model autoregressively until we generate 21 frames.
|
||||
- Install dependencies before running:
|
||||
```
|
||||
python3.10 -m venv .generativemodels
|
||||
source .generativemodels/bin/activate
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version
|
||||
pip3 install -r requirements/pt2.txt
|
||||
pip3 install .
|
||||
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
**July 24, 2024**
|
||||
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
|
||||
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
|
||||
- To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
|
||||
- To run the community-build gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.
|
||||
- Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.
|
||||
|
||||
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
|
||||
|
||||
To run **SV4D** on a single input video of 21 frames:
|
||||
- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
|
||||
- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
|
||||
- `input_path` : The input video `<path/to/video>` can be
|
||||
- a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or
|
||||
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
|
||||
- a file name pattern matching images of video frames.
|
||||
- `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
|
||||
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
|
||||
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
|
||||
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
|
||||
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (of frames encoded at a time) and `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.
|
||||
|
||||

|
||||
|
||||
|
||||
**March 18, 2024**
|
||||
- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
|
||||
- **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
|
||||
@@ -138,6 +216,7 @@ This is assuming you have navigated to the `generative-models` root after clonin
|
||||
# install required packages from pypi
|
||||
python3 -m venv .pt2
|
||||
source .pt2/bin/activate
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
pip3 install -r requirements/pt2.txt
|
||||
```
|
||||
|
||||
|
||||
BIN
assets/sv4d.gif
Normal file
|
After Width: | Height: | Size: 8.0 MiB |
BIN
assets/sv4d2.gif
Normal file
|
After Width: | Height: | Size: 9.7 MiB |
BIN
assets/sv4d_videos/bear.gif
Normal file
|
After Width: | Height: | Size: 2.2 MiB |
BIN
assets/sv4d_videos/bee.gif
Normal file
|
After Width: | Height: | Size: 638 KiB |
BIN
assets/sv4d_videos/bmx-bumps.gif
Normal file
|
After Width: | Height: | Size: 2.2 MiB |
BIN
assets/sv4d_videos/camel.gif
Normal file
|
After Width: | Height: | Size: 1.9 MiB |
BIN
assets/sv4d_videos/chameleon.gif
Normal file
|
After Width: | Height: | Size: 1.4 MiB |
BIN
assets/sv4d_videos/chest.gif
Normal file
|
After Width: | Height: | Size: 2.2 MiB |
BIN
assets/sv4d_videos/cows.gif
Normal file
|
After Width: | Height: | Size: 1.7 MiB |
BIN
assets/sv4d_videos/dance-twirl.gif
Normal file
|
After Width: | Height: | Size: 1.2 MiB |
BIN
assets/sv4d_videos/flag.gif
Normal file
|
After Width: | Height: | Size: 2.1 MiB |
BIN
assets/sv4d_videos/gear.gif
Normal file
|
After Width: | Height: | Size: 446 KiB |
BIN
assets/sv4d_videos/hike.gif
Normal file
|
After Width: | Height: | Size: 1.6 MiB |
BIN
assets/sv4d_videos/horsejump-low.gif
Normal file
|
After Width: | Height: | Size: 1.4 MiB |
BIN
assets/sv4d_videos/robot.gif
Normal file
|
After Width: | Height: | Size: 946 KiB |
BIN
assets/sv4d_videos/snowboard.gif
Normal file
|
After Width: | Height: | Size: 1.5 MiB |
BIN
assets/sv4d_videos/test_video1.mp4
Normal file
BIN
assets/sv4d_videos/windmill.gif
Normal file
|
After Width: | Height: | Size: 2.4 MiB |
@@ -5,13 +5,16 @@ einops>=0.6.1
|
||||
fairscale>=0.4.13
|
||||
fire>=0.5.0
|
||||
fsspec>=2023.6.0
|
||||
imageio[ffmpeg]
|
||||
imageio[pyav]
|
||||
invisible-watermark>=0.2.0
|
||||
kornia==0.6.9
|
||||
matplotlib>=3.7.2
|
||||
natsort>=8.4.0
|
||||
ninja>=1.11.1
|
||||
numpy>=1.24.4
|
||||
numpy==2.1
|
||||
omegaconf>=2.3.0
|
||||
onnxruntime
|
||||
open-clip-torch>=2.20.0
|
||||
opencv-python==4.6.0.66
|
||||
pandas>=2.0.3
|
||||
|
||||
496
scripts/demo/gradio_app_sv4d.py
Normal file
@@ -0,0 +1,496 @@
|
||||
# Adding this at the very top of app.py to make 'generative-models' directory discoverable
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
|
||||
|
||||
from glob import glob
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from typing import List, Optional, Union
|
||||
import torchvision
|
||||
|
||||
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
||||
from scripts.demo.sv4d_helpers import (
|
||||
decode_latents,
|
||||
load_model,
|
||||
initial_model_load,
|
||||
read_video,
|
||||
run_img2vid,
|
||||
prepare_inputs,
|
||||
do_sample_per_step,
|
||||
sample_sv3d,
|
||||
save_video,
|
||||
preprocess_video,
|
||||
)
|
||||
|
||||
|
||||
# the tmp path, if /tmp/gradio is not writable, change it to a writable path
|
||||
# os.environ["GRADIO_TEMP_DIR"] = "gradio_tmp"
|
||||
|
||||
version = "sv4d" # replace with 'sv3d_p' or 'sv3d_u' for other models
|
||||
|
||||
# Define the repo, local directory and filename
|
||||
repo_id = "stabilityai/sv4d"
|
||||
filename = f"{version}.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
|
||||
local_dir = "checkpoints"
|
||||
local_ckpt_path = os.path.join(local_dir, filename)
|
||||
|
||||
# Check if the file already exists
|
||||
if not os.path.exists(local_ckpt_path):
|
||||
# If the file doesn't exist, download it
|
||||
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
||||
print("File downloaded. (sv4d)")
|
||||
else:
|
||||
print("File already exists. No need to download. (sv4d)")
|
||||
|
||||
device = "cuda"
|
||||
max_64_bit_int = 2**63 - 1
|
||||
|
||||
num_frames = 21
|
||||
num_steps = 20
|
||||
model_config = f"scripts/sampling/configs/{version}.yaml"
|
||||
|
||||
# Set model config
|
||||
T = 5 # number of frames per sample
|
||||
V = 8 # number of views per sample
|
||||
F = 8 # vae factor to downsize image->latent
|
||||
C = 4
|
||||
H, W = 576, 576
|
||||
n_frames = 21 # number of input and output video frames
|
||||
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||
n_views_sv3d = 21
|
||||
subsampled_views = np.array(
|
||||
[0, 2, 5, 7, 9, 12, 14, 16, 19]
|
||||
) # subsample (V+1=)9 (uniform) views from 21 SV3D views
|
||||
|
||||
version_dict = {
|
||||
"T": T * V,
|
||||
"H": H,
|
||||
"W": W,
|
||||
"C": C,
|
||||
"f": F,
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 3,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 5,
|
||||
"num_steps": num_steps,
|
||||
"force_uc_zero_embeddings": [
|
||||
"cond_frames",
|
||||
"cond_frames_without_noise",
|
||||
"cond_view",
|
||||
"cond_motion",
|
||||
],
|
||||
"additional_guider_kwargs": {
|
||||
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Load SV4D model
|
||||
model, filter = load_model(
|
||||
model_config,
|
||||
device,
|
||||
version_dict["T"],
|
||||
num_steps,
|
||||
)
|
||||
model = initial_model_load(model)
|
||||
|
||||
# -----------sv3d config and model loading----------------
|
||||
# if version == "sv3d_u":
|
||||
sv3d_model_config = "scripts/sampling/configs/sv3d_u.yaml"
|
||||
# elif version == "sv3d_p":
|
||||
# sv3d_model_config = "scripts/sampling/configs/sv3d_p.yaml"
|
||||
# else:
|
||||
# raise ValueError(f"Version {version} does not exist.")
|
||||
|
||||
# Define the repo, local directory and filename
|
||||
repo_id = "stabilityai/sv3d"
|
||||
filename = f"sv3d_u.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
|
||||
local_dir = "checkpoints"
|
||||
local_ckpt_path = os.path.join(local_dir, filename)
|
||||
|
||||
# Check if the file already exists
|
||||
if not os.path.exists(local_ckpt_path):
|
||||
# If the file doesn't exist, download it
|
||||
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
||||
print("File downloaded. (sv3d)")
|
||||
else:
|
||||
print("File already exists. No need to download. (sv3d)")
|
||||
|
||||
# load sv3d model
|
||||
sv3d_model, filter = load_model(
|
||||
sv3d_model_config,
|
||||
device,
|
||||
21,
|
||||
num_steps,
|
||||
verbose=False,
|
||||
)
|
||||
sv3d_model = initial_model_load(sv3d_model)
|
||||
# ------------------
|
||||
|
||||
def sample_anchor(
|
||||
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
|
||||
seed: Optional[int] = None,
|
||||
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
num_steps: int = 20,
|
||||
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
|
||||
fps_id: int = 6,
|
||||
motion_bucket_id: int = 127,
|
||||
cond_aug: float = 1e-5,
|
||||
device: str = "cuda",
|
||||
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
|
||||
azimuths_deg: Optional[List[float]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
|
||||
"""
|
||||
output_folder = os.path.dirname(input_path)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Read input video frames i.e. images at view 0
|
||||
print(f"Reading {input_path}")
|
||||
images_v0 = read_video(
|
||||
input_path,
|
||||
n_frames=n_frames,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Get camera viewpoints
|
||||
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||
elevations_deg = [elevations_deg] * n_views_sv3d
|
||||
assert (
|
||||
len(elevations_deg) == n_views_sv3d
|
||||
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
|
||||
if azimuths_deg is None:
|
||||
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
|
||||
assert (
|
||||
len(azimuths_deg) == n_views_sv3d
|
||||
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||
azimuths_rad = np.array(
|
||||
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||
)
|
||||
|
||||
# Sample multi-view images of the first frame using SV3D i.e. images at time 0
|
||||
sv3d_model.sampler.num_steps = num_steps
|
||||
print("sv3d_model.sampler.num_steps", sv3d_model.sampler.num_steps)
|
||||
images_t0 = sample_sv3d(
|
||||
images_v0[0],
|
||||
n_views_sv3d,
|
||||
num_steps,
|
||||
sv3d_version,
|
||||
fps_id,
|
||||
motion_bucket_id,
|
||||
cond_aug,
|
||||
decoding_t,
|
||||
device,
|
||||
polars_rad,
|
||||
azimuths_rad,
|
||||
verbose,
|
||||
sv3d_model,
|
||||
)
|
||||
images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame
|
||||
|
||||
sv3d_file = os.path.join(output_folder, "t000.mp4")
|
||||
save_video(sv3d_file, images_t0.unsqueeze(1))
|
||||
|
||||
for emb in model.conditioner.embedders:
|
||||
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
|
||||
emb.en_and_decode_n_samples_a_time = encoding_t
|
||||
model.en_and_decode_n_samples_a_time = decoding_t
|
||||
# Initialize image matrix
|
||||
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||
for i, v in enumerate(subsampled_views):
|
||||
img_matrix[0][i] = images_t0[v].unsqueeze(0)
|
||||
for t in range(n_frames):
|
||||
img_matrix[t][0] = images_v0[t]
|
||||
|
||||
# Interleaved sampling for anchor frames
|
||||
t0, v0 = 0, 0
|
||||
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
|
||||
view_indices = np.arange(V) + 1
|
||||
print(f"Sampling anchor frames {frame_indices}")
|
||||
image = img_matrix[t0][v0]
|
||||
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||
model.sampler.num_steps = num_steps
|
||||
version_dict["options"]["num_steps"] = num_steps
|
||||
samples = run_img2vid(
|
||||
version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
|
||||
)
|
||||
samples = samples.view(T, V, 3, H, W)
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
if img_matrix[t][v] is None:
|
||||
img_matrix[t][v] = samples[i, j][None] * 2 - 1
|
||||
|
||||
# concat video
|
||||
grid_list = []
|
||||
for t in frame_indices:
|
||||
imgs_view = torch.cat(img_matrix[t])
|
||||
grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
|
||||
# save output videos
|
||||
anchor_vis_file = os.path.join(output_folder, "anchor_vis.mp4")
|
||||
save_video(anchor_vis_file, grid_list, fps=3)
|
||||
anchor_file = os.path.join(output_folder, "anchor.mp4")
|
||||
image_list = samples.view(T*V, 3, H, W).unsqueeze(1) * 2 - 1
|
||||
save_video(anchor_file, image_list)
|
||||
|
||||
return sv3d_file, anchor_vis_file, anchor_file
|
||||
|
||||
|
||||
def sample_all(
|
||||
input_path: str = "inputs/test_video1.mp4", # Can either be video file or folder with image files
|
||||
sv3d_path: str = "outputs/sv4d/000000_t000.mp4",
|
||||
anchor_path: str = "outputs/sv4d/000000_anchor.mp4",
|
||||
seed: Optional[int] = None,
|
||||
num_steps: int = 20,
|
||||
device: str = "cuda",
|
||||
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
|
||||
azimuths_deg: Optional[List[float]] = None,
|
||||
):
|
||||
"""
|
||||
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
|
||||
"""
|
||||
output_folder = os.path.dirname(input_path)
|
||||
torch.manual_seed(seed)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Read input video frames i.e. images at view 0
|
||||
print(f"Reading {input_path}")
|
||||
images_v0 = read_video(
|
||||
input_path,
|
||||
n_frames=n_frames,
|
||||
device=device,
|
||||
)
|
||||
|
||||
images_t0 = read_video(
|
||||
sv3d_path,
|
||||
n_frames=n_views_sv3d,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Get camera viewpoints
|
||||
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||
elevations_deg = [elevations_deg] * n_views_sv3d
|
||||
assert (
|
||||
len(elevations_deg) == n_views_sv3d
|
||||
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
|
||||
if azimuths_deg is None:
|
||||
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
|
||||
assert (
|
||||
len(azimuths_deg) == n_views_sv3d
|
||||
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||
azimuths_rad = np.array(
|
||||
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||
)
|
||||
|
||||
# Initialize image matrix
|
||||
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||
for i, v in enumerate(subsampled_views):
|
||||
img_matrix[0][i] = images_t0[v]
|
||||
for t in range(n_frames):
|
||||
img_matrix[t][0] = images_v0[t]
|
||||
|
||||
# load interleaved sampling for anchor frames
|
||||
t0, v0 = 0, 0
|
||||
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
|
||||
view_indices = np.arange(V) + 1
|
||||
|
||||
anchor_frames = read_video(
|
||||
anchor_path,
|
||||
n_frames=T * V,
|
||||
device=device,
|
||||
)
|
||||
anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W)
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
if img_matrix[t][v] is None:
|
||||
img_matrix[t][v] = anchor_frames[i, j][None]
|
||||
|
||||
# Dense sampling for the rest
|
||||
print(f"Sampling dense frames:")
|
||||
for t0 in np.arange(0, n_frames - 1, T - 1): # [0, 4, 8, 12, 16]
|
||||
frame_indices = t0 + np.arange(T)
|
||||
print(f"Sampling dense frames {frame_indices}")
|
||||
latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
|
||||
|
||||
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||
|
||||
# alternate between forward and backward conditioning
|
||||
forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
|
||||
frame_indices,
|
||||
img_matrix,
|
||||
v0,
|
||||
view_indices,
|
||||
model,
|
||||
version_dict,
|
||||
seed,
|
||||
polars,
|
||||
azims
|
||||
)
|
||||
|
||||
for step in range(num_steps):
|
||||
if step % 2 == 1:
|
||||
c, uc, additional_model_inputs, sampler = forward_inputs
|
||||
frame_indices = forward_frame_indices
|
||||
else:
|
||||
c, uc, additional_model_inputs, sampler = backward_inputs
|
||||
frame_indices = backward_frame_indices
|
||||
noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
|
||||
|
||||
samples = do_sample_per_step(
|
||||
model,
|
||||
sampler,
|
||||
noisy_latents,
|
||||
c,
|
||||
uc,
|
||||
step,
|
||||
additional_model_inputs,
|
||||
)
|
||||
samples = samples.view(T, V, C, H // F, W // F)
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
latent_matrix[t, v] = samples[i, j]
|
||||
|
||||
img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)
|
||||
|
||||
|
||||
# concat video
|
||||
grid_list = []
|
||||
for t in range(n_frames):
|
||||
imgs_view = torch.cat(img_matrix[t])
|
||||
grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
|
||||
# save output videos
|
||||
vid_file = os.path.join(output_folder, "sv4d_final.mp4")
|
||||
save_video(vid_file, grid_list)
|
||||
|
||||
return vid_file, seed
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(
|
||||
"""# Demo for SV4D from Stability AI ([model](https://huggingface.co/stabilityai/sv4d), [news](https://stability.ai/news/stable-video-4d))
|
||||
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv4d/blob/main/LICENSE.md)): generate 8 novel view videos from a single-view video (with white background).
|
||||
#### It takes ~45s to generate anchor frames and another ~160s to generate full results (21 frames).
|
||||
#### Hints for improving performance:
|
||||
- Use a white background;
|
||||
- Make the object in the center of the image;
|
||||
- The SV4D process the first 21 frames of the uploaded video. Gradio provides a nice option of trimming the uploaded video if needed.
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_video = gr.Video(label="Upload your video")
|
||||
generate_btn = gr.Button("Step 1: generate 8 novel view videos (5 anchor frames each)")
|
||||
interpolate_btn = gr.Button("Step 2: Extend novel view videos to 21 frames")
|
||||
with gr.Column():
|
||||
anchor_video = gr.Video(label="SV4D outputs (anchor frames)")
|
||||
sv3d_video = gr.Video(label="SV3D outputs", interactive=False)
|
||||
with gr.Column():
|
||||
sv4d_interpolated_video = gr.Video(label="SV4D outputs (21 frames)")
|
||||
|
||||
with gr.Accordion("Advanced options", open=False):
|
||||
seed = gr.Slider(
|
||||
label="Seed",
|
||||
value=23,
|
||||
# randomize=True,
|
||||
minimum=0,
|
||||
maximum=100,
|
||||
step=1,
|
||||
)
|
||||
encoding_t = gr.Slider(
|
||||
label="Encode n frames at a time",
|
||||
info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.",
|
||||
value=8,
|
||||
minimum=1,
|
||||
maximum=40,
|
||||
)
|
||||
decoding_t = gr.Slider(
|
||||
label="Decode n frames at a time",
|
||||
info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
|
||||
value=4,
|
||||
minimum=1,
|
||||
maximum=14,
|
||||
)
|
||||
denoising_steps = gr.Slider(
|
||||
label="Number of denoising steps",
|
||||
info="Increase will improve the performance but needs more time.",
|
||||
value=20,
|
||||
minimum=10,
|
||||
maximum=50,
|
||||
step=1,
|
||||
)
|
||||
remove_bg = gr.Checkbox(
|
||||
label="Remove background",
|
||||
info="We use rembg. Users can check the alternative way: SAM2 (https://github.com/facebookresearch/segment-anything-2)",
|
||||
)
|
||||
|
||||
input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
anchor_frames = gr.Video()
|
||||
|
||||
generate_btn.click(
|
||||
fn=sample_anchor,
|
||||
inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],
|
||||
outputs=[sv3d_video, anchor_video, anchor_frames],
|
||||
api_name="SV4D output (5 frames)",
|
||||
)
|
||||
|
||||
interpolate_btn.click(
|
||||
fn=sample_all,
|
||||
inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps],
|
||||
outputs=[sv4d_interpolated_video, seed],
|
||||
api_name="SV4D interpolation (21 frames)",
|
||||
)
|
||||
|
||||
examples = gr.Examples(
|
||||
fn=preprocess_video,
|
||||
examples=[
|
||||
"./assets/sv4d_videos/test_video1.mp4",
|
||||
"./assets/sv4d_videos/test_video2.mp4",
|
||||
"./assets/sv4d_videos/green_robot.mp4",
|
||||
"./assets/sv4d_videos/dolphin.mp4",
|
||||
"./assets/sv4d_videos/lucia_v000.mp4",
|
||||
"./assets/sv4d_videos/snowboard_v000.mp4",
|
||||
"./assets/sv4d_videos/stroller_v000.mp4",
|
||||
"./assets/sv4d_videos/human5.mp4",
|
||||
"./assets/sv4d_videos/bunnyman.mp4",
|
||||
"./assets/sv4d_videos/hiphop_parrot.mp4",
|
||||
"./assets/sv4d_videos/guppie_v0.mp4",
|
||||
"./assets/sv4d_videos/wave_hello.mp4",
|
||||
"./assets/sv4d_videos/pistol_v0.mp4",
|
||||
"./assets/sv4d_videos/human7.mp4",
|
||||
"./assets/sv4d_videos/monkey.mp4",
|
||||
"./assets/sv4d_videos/train_v0.mp4",
|
||||
],
|
||||
inputs=[input_video],
|
||||
run_on_click=True,
|
||||
outputs=[input_video],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue(max_size=20)
|
||||
demo.launch(share=True)
|
||||
|
||||
1421
scripts/demo/sv4d_helpers.py
Executable file
210
scripts/sampling/configs/sp4d.yaml
Executable file
@@ -0,0 +1,210 @@
|
||||
N_TIME: 4
|
||||
N_VIEW: 12
|
||||
N_FRAMES: 48
|
||||
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
en_and_decode_n_samples_a_time: 8
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/sp4d.safetensors
|
||||
dual_concat: True
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.DualSpatialUNetWithCrossComm
|
||||
params:
|
||||
unet_config:
|
||||
adm_in_channels: 1280
|
||||
attention_resolutions: [4, 2, 1]
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
context_dim: 1024
|
||||
motion_context_dim: 4
|
||||
extra_ff_mix_layer: True
|
||||
in_channels: 8
|
||||
legacy: False
|
||||
model_channels: 320
|
||||
num_classes: sequential
|
||||
num_head_channels: 64
|
||||
num_res_blocks: 2
|
||||
out_channels: 4
|
||||
replicate_time_mix_bug: True
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
time_block_merge_factor: 0.0
|
||||
time_block_merge_strategy: learned_with_images
|
||||
time_kernel_size: [3, 1, 1]
|
||||
time_mix_legacy: False
|
||||
transformer_depth: 1
|
||||
use_checkpoint: False
|
||||
use_linear_in_transformer: True
|
||||
use_spatial_context: True
|
||||
use_spatial_transformer: True
|
||||
separate_motion_merge_factor: True
|
||||
use_motion_attention: True
|
||||
use_3d_attention: True
|
||||
use_camera_emb: True
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
|
||||
- input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
is_trainable: False
|
||||
params:
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: cond_frames
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
is_trainable: False
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_FRAMES}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
embed_dim: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
monitor: val/rec_loss
|
||||
sigma_cond_config:
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: polar_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: azimuth_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: cond_view
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_VIEW}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: cond_motion
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: torch.nn.Identity
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.DecoderDual
|
||||
params:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
num_steps: 50
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 500.0
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
|
||||
params:
|
||||
max_scale: 1.5
|
||||
min_scale: 1.5
|
||||
num_frames: ${N_FRAMES}
|
||||
num_views: ${N_VIEW}
|
||||
additional_cond_keys: [ cond_view, cond_motion ]
|
||||
203
scripts/sampling/configs/sv4d.yaml
Executable file
@@ -0,0 +1,203 @@
|
||||
N_TIME: 5
|
||||
N_VIEW: 8
|
||||
N_FRAMES: 40
|
||||
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
en_and_decode_n_samples_a_time: 7
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/sv4d.safetensors
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
|
||||
params:
|
||||
adm_in_channels: 1280
|
||||
attention_resolutions: [4, 2, 1]
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
context_dim: 1024
|
||||
motion_context_dim: 4
|
||||
extra_ff_mix_layer: True
|
||||
in_channels: 8
|
||||
legacy: False
|
||||
model_channels: 320
|
||||
num_classes: sequential
|
||||
num_head_channels: 64
|
||||
num_res_blocks: 2
|
||||
out_channels: 4
|
||||
replicate_time_mix_bug: True
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
time_block_merge_factor: 0.0
|
||||
time_block_merge_strategy: learned_with_images
|
||||
time_kernel_size: [3, 1, 1]
|
||||
time_mix_legacy: False
|
||||
transformer_depth: 1
|
||||
use_checkpoint: False
|
||||
use_linear_in_transformer: True
|
||||
use_spatial_context: True
|
||||
use_spatial_transformer: True
|
||||
use_motion_attention: True
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
|
||||
- input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
is_trainable: False
|
||||
params:
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: cond_frames
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
is_trainable: False
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_FRAMES}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
embed_dim: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
monitor: val/rec_loss
|
||||
sigma_cond_config:
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: polar_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: azimuth_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: cond_view
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_VIEW}
|
||||
n_copies: 1
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: cond_motion
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: torch.nn.Identity
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 500.0
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
||||
params:
|
||||
max_scale: 2.5
|
||||
num_frames: ${N_FRAMES}
|
||||
additional_cond_keys: [ cond_view, cond_motion ]
|
||||
208
scripts/sampling/configs/sv4d2.yaml
Executable file
@@ -0,0 +1,208 @@
|
||||
N_TIME: 12
|
||||
N_VIEW: 4
|
||||
N_FRAMES: 48
|
||||
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
en_and_decode_n_samples_a_time: 8
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/sv4d2.safetensors
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
|
||||
params:
|
||||
adm_in_channels: 1280
|
||||
attention_resolutions: [4, 2, 1]
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
context_dim: 1024
|
||||
motion_context_dim: 4
|
||||
extra_ff_mix_layer: True
|
||||
in_channels: 8
|
||||
legacy: False
|
||||
model_channels: 320
|
||||
num_classes: sequential
|
||||
num_head_channels: 64
|
||||
num_res_blocks: 2
|
||||
out_channels: 4
|
||||
replicate_time_mix_bug: True
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
time_block_merge_factor: 0.0
|
||||
time_block_merge_strategy: learned_with_images
|
||||
time_kernel_size: [3, 1, 1]
|
||||
time_mix_legacy: False
|
||||
transformer_depth: 1
|
||||
use_checkpoint: False
|
||||
use_linear_in_transformer: True
|
||||
use_spatial_context: True
|
||||
use_spatial_transformer: True
|
||||
separate_motion_merge_factor: True
|
||||
use_motion_attention: True
|
||||
use_3d_attention: True
|
||||
use_camera_emb: True
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
|
||||
- input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
is_trainable: False
|
||||
params:
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: cond_frames
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
is_trainable: False
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_FRAMES}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
embed_dim: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
monitor: val/rec_loss
|
||||
sigma_cond_config:
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: polar_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: azimuth_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: cond_view
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_VIEW}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: cond_motion
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: torch.nn.Identity
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
num_steps: 50
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 500.0
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
|
||||
params:
|
||||
max_scale: 1.5
|
||||
min_scale: 1.5
|
||||
num_frames: ${N_FRAMES}
|
||||
num_views: ${N_VIEW}
|
||||
additional_cond_keys: [ cond_view, cond_motion ]
|
||||
208
scripts/sampling/configs/sv4d2_8views.yaml
Executable file
@@ -0,0 +1,208 @@
|
||||
N_TIME: 5
|
||||
N_VIEW: 8
|
||||
N_FRAMES: 40
|
||||
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
en_and_decode_n_samples_a_time: 8
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/sv4d2_8views.safetensors
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
|
||||
params:
|
||||
adm_in_channels: 1280
|
||||
attention_resolutions: [4, 2, 1]
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
context_dim: 1024
|
||||
motion_context_dim: 4
|
||||
extra_ff_mix_layer: True
|
||||
in_channels: 8
|
||||
legacy: False
|
||||
model_channels: 320
|
||||
num_classes: sequential
|
||||
num_head_channels: 64
|
||||
num_res_blocks: 2
|
||||
out_channels: 4
|
||||
replicate_time_mix_bug: True
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
time_block_merge_factor: 0.0
|
||||
time_block_merge_strategy: learned_with_images
|
||||
time_kernel_size: [3, 1, 1]
|
||||
time_mix_legacy: False
|
||||
transformer_depth: 1
|
||||
use_checkpoint: False
|
||||
use_linear_in_transformer: True
|
||||
use_spatial_context: True
|
||||
use_spatial_transformer: True
|
||||
separate_motion_merge_factor: True
|
||||
use_motion_attention: True
|
||||
use_3d_attention: False
|
||||
use_camera_emb: True
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
|
||||
- input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
is_trainable: False
|
||||
params:
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: cond_frames
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
is_trainable: False
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_FRAMES}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
embed_dim: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
monitor: val/rec_loss
|
||||
sigma_cond_config:
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: polar_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: azimuth_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 512
|
||||
|
||||
- input_key: cond_view
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_VIEW}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
- input_key: cond_motion
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
is_ae: True
|
||||
n_cond_frames: ${N_TIME}
|
||||
n_copies: 1
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: torch.nn.Identity
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params:
|
||||
attn_resolutions: []
|
||||
attn_type: vanilla-xformers
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
double_z: True
|
||||
dropout: 0.0
|
||||
in_channels: 3
|
||||
num_res_blocks: 2
|
||||
out_ch: 3
|
||||
resolution: 256
|
||||
z_channels: 4
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
num_steps: 50
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 500.0
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
|
||||
params:
|
||||
max_scale: 2.0
|
||||
min_scale: 1.5
|
||||
num_frames: ${N_FRAMES}
|
||||
num_views: ${N_VIEW}
|
||||
additional_cond_keys: [ cond_view, cond_motion ]
|
||||
@@ -163,7 +163,7 @@ def sample(
|
||||
else:
|
||||
with Image.open(input_img_path) as image:
|
||||
if image.mode == "RGBA":
|
||||
input_image = image.convert("RGB")
|
||||
image = image.convert("RGB")
|
||||
w, h = image.size
|
||||
|
||||
if h % 64 != 0 or w % 64 != 0:
|
||||
@@ -172,7 +172,8 @@ def sample(
|
||||
print(
|
||||
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
|
||||
)
|
||||
|
||||
input_image = np.array(image)
|
||||
|
||||
image = ToTensor()(input_image)
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
|
||||
259
scripts/sampling/simple_video_sample_4d.py
Executable file
@@ -0,0 +1,259 @@
|
||||
import os
|
||||
import sys
|
||||
from glob import glob
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||
import numpy as np
|
||||
import torch
|
||||
from fire import Fire
|
||||
|
||||
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
||||
from scripts.demo.sv4d_helpers import (
|
||||
decode_latents,
|
||||
load_model,
|
||||
initial_model_load,
|
||||
read_video,
|
||||
run_img2vid,
|
||||
prepare_sampling,
|
||||
prepare_inputs,
|
||||
do_sample_per_step,
|
||||
sample_sv3d,
|
||||
save_video,
|
||||
preprocess_video,
|
||||
)
|
||||
|
||||
|
||||
def sample(
|
||||
input_path: str = "assets/sv4d_videos/test_video1.mp4", # Can either be image file or folder with image files
|
||||
output_folder: Optional[str] = "outputs/sv4d",
|
||||
num_steps: Optional[int] = 20,
|
||||
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
|
||||
img_size: int = 576, # image resolution
|
||||
fps_id: int = 6,
|
||||
motion_bucket_id: int = 127,
|
||||
cond_aug: float = 1e-5,
|
||||
seed: int = 23,
|
||||
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
device: str = "cuda",
|
||||
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
|
||||
azimuths_deg: Optional[List[float]] = None,
|
||||
image_frame_ratio: Optional[float] = 0.917,
|
||||
verbose: Optional[bool] = False,
|
||||
remove_bg: bool = False,
|
||||
):
|
||||
"""
|
||||
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
|
||||
"""
|
||||
# Set model config
|
||||
T = 5 # number of frames per sample
|
||||
V = 8 # number of views per sample
|
||||
F = 8 # vae factor to downsize image->latent
|
||||
C = 4
|
||||
H, W = img_size, img_size
|
||||
n_frames = 21 # number of input and output video frames
|
||||
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||
n_views_sv3d = 21
|
||||
subsampled_views = np.array(
|
||||
[0, 2, 5, 7, 9, 12, 14, 16, 19]
|
||||
) # subsample (V+1=)9 (uniform) views from 21 SV3D views
|
||||
|
||||
model_config = "scripts/sampling/configs/sv4d.yaml"
|
||||
version_dict = {
|
||||
"T": T * V,
|
||||
"H": H,
|
||||
"W": W,
|
||||
"C": C,
|
||||
"f": F,
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.0,
|
||||
"num_views": V,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 5,
|
||||
"num_steps": num_steps,
|
||||
"force_uc_zero_embeddings": [
|
||||
"cond_frames",
|
||||
"cond_frames_without_noise",
|
||||
"cond_view",
|
||||
"cond_motion",
|
||||
],
|
||||
"additional_guider_kwargs": {
|
||||
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
torch.manual_seed(seed)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Read input video frames i.e. images at view 0
|
||||
print(f"Reading {input_path}")
|
||||
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
|
||||
processed_input_path = preprocess_video(
|
||||
input_path,
|
||||
remove_bg=remove_bg,
|
||||
n_frames=n_frames,
|
||||
W=W,
|
||||
H=H,
|
||||
output_folder=output_folder,
|
||||
image_frame_ratio=image_frame_ratio,
|
||||
base_count=base_count,
|
||||
)
|
||||
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
|
||||
|
||||
# Get camera viewpoints
|
||||
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||
elevations_deg = [elevations_deg] * n_views_sv3d
|
||||
assert (
|
||||
len(elevations_deg) == n_views_sv3d
|
||||
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
|
||||
if azimuths_deg is None:
|
||||
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
|
||||
assert (
|
||||
len(azimuths_deg) == n_views_sv3d
|
||||
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||
azimuths_rad = np.array(
|
||||
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||
)
|
||||
|
||||
# Sample multi-view images of the first frame using SV3D i.e. images at time 0
|
||||
images_t0 = sample_sv3d(
|
||||
images_v0[0],
|
||||
n_views_sv3d,
|
||||
num_steps,
|
||||
sv3d_version,
|
||||
fps_id,
|
||||
motion_bucket_id,
|
||||
cond_aug,
|
||||
decoding_t,
|
||||
device,
|
||||
polars_rad,
|
||||
azimuths_rad,
|
||||
verbose,
|
||||
)
|
||||
images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame
|
||||
|
||||
# Initialize image matrix
|
||||
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||
for i, v in enumerate(subsampled_views):
|
||||
img_matrix[0][i] = images_t0[v].unsqueeze(0)
|
||||
for t in range(n_frames):
|
||||
img_matrix[t][0] = images_v0[t]
|
||||
|
||||
save_video(
|
||||
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
|
||||
img_matrix[0],
|
||||
)
|
||||
# save_video(
|
||||
# os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
|
||||
# [img_matrix[t][0] for t in range(n_frames)],
|
||||
# )
|
||||
|
||||
# Load SV4D model
|
||||
model, filter = load_model(
|
||||
model_config,
|
||||
device,
|
||||
version_dict["T"],
|
||||
num_steps,
|
||||
verbose,
|
||||
)
|
||||
model = initial_model_load(model)
|
||||
for emb in model.conditioner.embedders:
|
||||
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
|
||||
emb.en_and_decode_n_samples_a_time = encoding_t
|
||||
model.en_and_decode_n_samples_a_time = decoding_t
|
||||
|
||||
# Interleaved sampling for anchor frames
|
||||
t0, v0 = 0, 0
|
||||
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
|
||||
view_indices = np.arange(V) + 1
|
||||
print(f"Sampling anchor frames {frame_indices}")
|
||||
image = img_matrix[t0][v0]
|
||||
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||
samples = run_img2vid(
|
||||
version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
|
||||
)
|
||||
samples = samples.view(T, V, 3, H, W)
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
if img_matrix[t][v] is None:
|
||||
img_matrix[t][v] = samples[i, j][None] * 2 - 1
|
||||
|
||||
# Dense sampling for the rest
|
||||
print(f"Sampling dense frames:")
|
||||
for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)): # [0, 4, 8, 12, 16]
|
||||
frame_indices = t0 + np.arange(T)
|
||||
print(f"Sampling dense frames {frame_indices}")
|
||||
latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
|
||||
|
||||
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||
|
||||
# alternate between forward and backward conditioning
|
||||
forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
|
||||
frame_indices,
|
||||
img_matrix,
|
||||
v0,
|
||||
view_indices,
|
||||
model,
|
||||
version_dict,
|
||||
seed,
|
||||
polars,
|
||||
azims
|
||||
)
|
||||
|
||||
for step in tqdm(range(num_steps)):
|
||||
if step % 2 == 1:
|
||||
c, uc, additional_model_inputs, sampler = forward_inputs
|
||||
frame_indices = forward_frame_indices
|
||||
else:
|
||||
c, uc, additional_model_inputs, sampler = backward_inputs
|
||||
frame_indices = backward_frame_indices
|
||||
noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
|
||||
|
||||
samples = do_sample_per_step(
|
||||
model,
|
||||
sampler,
|
||||
noisy_latents,
|
||||
c,
|
||||
uc,
|
||||
step,
|
||||
additional_model_inputs,
|
||||
)
|
||||
samples = samples.view(T, V, C, H // F, W // F)
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
latent_matrix[t, v] = samples[i, j]
|
||||
|
||||
img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)
|
||||
|
||||
# Save output videos
|
||||
for v in view_indices:
|
||||
vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
|
||||
print(f"Saving {vid_file}")
|
||||
save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)])
|
||||
|
||||
# Save diagonal video
|
||||
diag_frames = [
|
||||
img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames)
|
||||
]
|
||||
vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4")
|
||||
print(f"Saving {vid_file}")
|
||||
save_video(vid_file, diag_frames)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(sample)
|
||||
235
scripts/sampling/simple_video_sample_4d2.py
Executable file
@@ -0,0 +1,235 @@
|
||||
import os
|
||||
import sys
|
||||
from glob import glob
|
||||
from typing import List, Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||
import numpy as np
|
||||
import torch
|
||||
from fire import Fire
|
||||
from scripts.demo.sv4d_helpers import (
|
||||
load_model,
|
||||
preprocess_video,
|
||||
read_video,
|
||||
run_img2vid,
|
||||
save_video,
|
||||
)
|
||||
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
||||
|
||||
sv4d2_configs = {
|
||||
"sv4d2": {
|
||||
"T": 12, # number of frames per sample
|
||||
"V": 4, # number of views per sample
|
||||
"model_config": "scripts/sampling/configs/sv4d2.yaml",
|
||||
"version_dict": {
|
||||
"T": 12 * 4,
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.0,
|
||||
"min_cfg": 2.0,
|
||||
"num_views": 4,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 2,
|
||||
"force_uc_zero_embeddings": [
|
||||
"cond_frames",
|
||||
"cond_frames_without_noise",
|
||||
"cond_view",
|
||||
"cond_motion",
|
||||
],
|
||||
"additional_guider_kwargs": {
|
||||
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"sv4d2_8views": {
|
||||
"T": 5, # number of frames per sample
|
||||
"V": 8, # number of views per sample
|
||||
"model_config": "scripts/sampling/configs/sv4d2_8views.yaml",
|
||||
"version_dict": {
|
||||
"T": 5 * 8,
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.5,
|
||||
"min_cfg": 1.5,
|
||||
"num_views": 8,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 5,
|
||||
"force_uc_zero_embeddings": [
|
||||
"cond_frames",
|
||||
"cond_frames_without_noise",
|
||||
"cond_view",
|
||||
"cond_motion",
|
||||
],
|
||||
"additional_guider_kwargs": {
|
||||
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def sample(
|
||||
input_path: str = "assets/sv4d_videos/camel.gif", # Can either be image file or folder with image files
|
||||
model_path: Optional[str] = "checkpoints/sv4d2.safetensors",
|
||||
output_folder: Optional[str] = "outputs",
|
||||
num_steps: Optional[int] = 50,
|
||||
img_size: int = 576, # image resolution
|
||||
n_frames: int = 21, # number of input and output video frames
|
||||
seed: int = 23,
|
||||
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
device: str = "cuda",
|
||||
elevations_deg: Optional[List[float]] = 0.0,
|
||||
azimuths_deg: Optional[List[float]] = None,
|
||||
image_frame_ratio: Optional[float] = 0.9,
|
||||
verbose: Optional[bool] = False,
|
||||
remove_bg: bool = False,
|
||||
):
|
||||
"""
|
||||
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
|
||||
"""
|
||||
# Set model config
|
||||
assert os.path.basename(model_path) in [
|
||||
"sv4d2.safetensors",
|
||||
"sv4d2_8views.safetensors",
|
||||
]
|
||||
sv4d2_model = os.path.splitext(os.path.basename(model_path))[0]
|
||||
config = sv4d2_configs[sv4d2_model]
|
||||
print(sv4d2_model, config)
|
||||
T = config["T"]
|
||||
V = config["V"]
|
||||
model_config = config["model_config"]
|
||||
version_dict = config["version_dict"]
|
||||
F = 8 # vae factor to downsize image->latent
|
||||
C = 4
|
||||
H, W = img_size, img_size
|
||||
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||
subsampled_views = np.arange(n_views)
|
||||
version_dict["H"] = H
|
||||
version_dict["W"] = W
|
||||
version_dict["C"] = C
|
||||
version_dict["f"] = F
|
||||
version_dict["options"]["num_steps"] = num_steps
|
||||
|
||||
torch.manual_seed(seed)
|
||||
output_folder = os.path.join(output_folder, sv4d2_model)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Read input video frames i.e. images at view 0
|
||||
print(f"Reading {input_path}")
|
||||
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // n_views
|
||||
processed_input_path = preprocess_video(
|
||||
input_path,
|
||||
remove_bg=remove_bg,
|
||||
n_frames=n_frames,
|
||||
W=W,
|
||||
H=H,
|
||||
output_folder=output_folder,
|
||||
image_frame_ratio=image_frame_ratio,
|
||||
base_count=base_count,
|
||||
)
|
||||
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
|
||||
images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)
|
||||
|
||||
# Get camera viewpoints
|
||||
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||
elevations_deg = [elevations_deg] * n_views
|
||||
assert (
|
||||
len(elevations_deg) == n_views
|
||||
), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
|
||||
if azimuths_deg is None:
|
||||
# azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
|
||||
azimuths_deg = (
|
||||
np.array([0, 60, 120, 180, 240])
|
||||
if sv4d2_model == "sv4d2"
|
||||
else np.array([0, 30, 75, 120, 165, 210, 255, 300, 330])
|
||||
)
|
||||
assert (
|
||||
len(azimuths_deg) == n_views
|
||||
), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||
azimuths_rad = np.array(
|
||||
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||
)
|
||||
|
||||
# Initialize image matrix
|
||||
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||
for i, v in enumerate(subsampled_views):
|
||||
img_matrix[0][i] = images_t0[v].unsqueeze(0)
|
||||
for t in range(n_frames):
|
||||
img_matrix[t][0] = images_v0[t]
|
||||
|
||||
# Load SV4D++ model
|
||||
model, _ = load_model(
|
||||
model_config,
|
||||
device,
|
||||
version_dict["T"],
|
||||
num_steps,
|
||||
verbose,
|
||||
model_path,
|
||||
)
|
||||
model.en_and_decode_n_samples_a_time = decoding_t
|
||||
for emb in model.conditioner.embedders:
|
||||
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
|
||||
emb.en_and_decode_n_samples_a_time = encoding_t
|
||||
|
||||
# Sampling novel-view videos
|
||||
v0 = 0
|
||||
view_indices = np.arange(V) + 1
|
||||
t0_list = (
|
||||
range(0, n_frames, T-1)
|
||||
if sv4d2_model == "sv4d2"
|
||||
else range(0, n_frames - T + 1, T - 1)
|
||||
)
|
||||
for t0 in tqdm(t0_list):
|
||||
if t0 + T > n_frames:
|
||||
t0 = n_frames - T
|
||||
frame_indices = t0 + np.arange(T)
|
||||
print(f"Sampling frames {frame_indices}")
|
||||
image = img_matrix[t0][v0]
|
||||
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
|
||||
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||
cond_mv = False if t0 == 0 else True
|
||||
samples = run_img2vid(
|
||||
version_dict,
|
||||
model,
|
||||
image,
|
||||
seed,
|
||||
polars,
|
||||
azims,
|
||||
cond_motion,
|
||||
cond_view,
|
||||
decoding_t,
|
||||
cond_mv=cond_mv,
|
||||
)
|
||||
samples = samples.view(T, V, 3, H, W)
|
||||
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
img_matrix[t][v] = samples[i, j][None] * 2 - 1
|
||||
|
||||
# Save output videos
|
||||
for v in view_indices:
|
||||
vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
|
||||
print(f"Saving {vid_file}")
|
||||
save_video(
|
||||
vid_file,
|
||||
[img_matrix[t][v] for t in range(n_frames) if img_matrix[t][v] is not None],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(sample)
|
||||
198
scripts/sampling/simple_video_sample_sp4d.py
Executable file
@@ -0,0 +1,198 @@
|
||||
import os
|
||||
import sys
|
||||
from glob import glob
|
||||
from typing import List, Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||
import numpy as np
|
||||
import torch
|
||||
from fire import Fire
|
||||
from scripts.demo.sv4d_helpers import (
|
||||
load_model,
|
||||
preprocess_video,
|
||||
read_video,
|
||||
run_img2vid,
|
||||
save_video,
|
||||
)
|
||||
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
||||
|
||||
sp4d_configs = {
|
||||
"sp4d": {
|
||||
"T": 4, # number of frames per sample
|
||||
"V": 12, # number of views per sample
|
||||
"model_config": "scripts/sampling/configs/sp4d.yaml",
|
||||
"version_dict": {
|
||||
"T": 48,
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 3.0,
|
||||
"min_cfg": 1.5,
|
||||
"num_views": 12,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 2,
|
||||
"force_uc_zero_embeddings": [
|
||||
"cond_frames",
|
||||
"cond_frames_without_noise",
|
||||
"cond_view",
|
||||
"cond_motion",
|
||||
],
|
||||
"additional_guider_kwargs": {
|
||||
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def sample(
|
||||
input_path: str = "assets/sv4d_videos/camel.gif", # Can either be image file or folder with image files
|
||||
model_path: Optional[str] = "checkpoints/sp4d.safetensors",
|
||||
output_folder: Optional[str] = "outputs",
|
||||
num_steps: Optional[int] = 50,
|
||||
img_size: int = 512, # image resolution
|
||||
n_frames: int = 4, # number of input and output video frames
|
||||
seed: int = 23,
|
||||
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
device: str = "cuda",
|
||||
elevations_deg: Optional[List[float]] = 0.0,
|
||||
azimuths_deg: Optional[List[float]] = None,
|
||||
image_frame_ratio: Optional[float] = 0.9,
|
||||
verbose: Optional[bool] = False,
|
||||
remove_bg: bool = False,
|
||||
):
|
||||
"""
|
||||
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
|
||||
"""
|
||||
# Set model config
|
||||
assert os.path.basename(model_path) in [
|
||||
"sp4d.safetensors",
|
||||
]
|
||||
sp4d_model = os.path.splitext(os.path.basename(model_path))[0]
|
||||
config = sp4d_configs[sp4d_model]
|
||||
print(sp4d_model, config)
|
||||
T = config["T"]
|
||||
V = config["V"]
|
||||
model_config = config["model_config"]
|
||||
version_dict = config["version_dict"]
|
||||
F = 8 # vae factor to downsize image->latent
|
||||
C = 4
|
||||
H, W = img_size, img_size
|
||||
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||
subsampled_views = np.arange(n_views)
|
||||
version_dict["H"] = H
|
||||
version_dict["W"] = W
|
||||
version_dict["C"] = C
|
||||
version_dict["f"] = F
|
||||
version_dict["options"]["num_steps"] = num_steps
|
||||
|
||||
torch.manual_seed(seed)
|
||||
output_folder = os.path.join(output_folder, sp4d_model)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Read input video frames i.e. images at view 0
|
||||
print(f"Reading {input_path}")
|
||||
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // (n_frames + 1)
|
||||
processed_input_path = preprocess_video(
|
||||
input_path,
|
||||
remove_bg=remove_bg,
|
||||
n_frames=n_frames,
|
||||
W=W,
|
||||
H=H,
|
||||
output_folder=output_folder,
|
||||
image_frame_ratio=image_frame_ratio,
|
||||
base_count=base_count,
|
||||
)
|
||||
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
|
||||
images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)
|
||||
|
||||
# Get camera viewpoints
|
||||
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||
elevations_deg = [elevations_deg] * n_views
|
||||
assert (
|
||||
len(elevations_deg) == n_views
|
||||
), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
|
||||
if azimuths_deg is None:
|
||||
azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
|
||||
assert (
|
||||
len(azimuths_deg) == n_views
|
||||
), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||
azimuths_rad = np.array(
|
||||
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||
)
|
||||
|
||||
# Initialize image matrix
|
||||
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||
for i, v in enumerate(subsampled_views):
|
||||
img_matrix[0][i] = images_t0[v].unsqueeze(0)
|
||||
for t in range(n_frames):
|
||||
img_matrix[t][0] = images_v0[t]
|
||||
|
||||
# Load SV4D++ model
|
||||
model, _ = load_model(
|
||||
model_config,
|
||||
device,
|
||||
version_dict["T"],
|
||||
num_steps,
|
||||
verbose,
|
||||
model_path,
|
||||
)
|
||||
model.en_and_decode_n_samples_a_time = decoding_t
|
||||
for emb in model.conditioner.embedders:
|
||||
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
|
||||
emb.en_and_decode_n_samples_a_time = encoding_t
|
||||
|
||||
# Sampling novel-view videos
|
||||
v0 = 0
|
||||
view_indices = np.arange(V) + 1
|
||||
t0_list = range(0, n_frames - T + 1, T - 1)
|
||||
for t0 in tqdm(t0_list):
|
||||
if t0 + T > n_frames:
|
||||
t0 = n_frames - T
|
||||
frame_indices = t0 + np.arange(T)
|
||||
print(f"Sampling frames {frame_indices}")
|
||||
image = img_matrix[t0][v0]
|
||||
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||
polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
|
||||
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||
samples = run_img2vid(
|
||||
version_dict,
|
||||
model,
|
||||
image,
|
||||
seed,
|
||||
polars,
|
||||
azims,
|
||||
cond_motion,
|
||||
cond_view,
|
||||
decoding_t,
|
||||
cond_mv=False,
|
||||
part_maps=True,
|
||||
)
|
||||
samples = samples.view(T, V, 3, H, -1)
|
||||
|
||||
for i, t in enumerate(frame_indices):
|
||||
for j, v in enumerate(view_indices):
|
||||
img_matrix[t][v] = samples[i, j][None] * 2 - 1
|
||||
|
||||
# Save output videos
|
||||
for t in frame_indices:
|
||||
vid_file = os.path.join(output_folder, f"{base_count:06d}_{t:03d}.mp4")
|
||||
print(f"Saving {vid_file}")
|
||||
save_video(
|
||||
vid_file,
|
||||
[img_matrix[t][v] for v in range(1, n_views) if img_matrix[t][v] is not None],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(sample)
|
||||
@@ -38,6 +38,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
no_cond_log: bool = False,
|
||||
compile_model: bool = False,
|
||||
en_and_decode_n_samples_a_time: Optional[int] = None,
|
||||
dual_concat: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_keys = log_keys
|
||||
@@ -47,7 +48,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
)
|
||||
model = instantiate_from_config(network_config)
|
||||
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
||||
model, compile_model=compile_model
|
||||
model, compile_model=compile_model, dual_concat=dual_concat
|
||||
)
|
||||
|
||||
self.denoiser = instantiate_from_config(denoiser_config)
|
||||
|
||||
@@ -94,7 +94,7 @@ class LinearPredictionGuider(Guider):
|
||||
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
||||
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
||||
else:
|
||||
assert c[k] == uc[k]
|
||||
# assert c[k] == uc[k]
|
||||
c_out[k] = c[k]
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
@@ -105,7 +105,7 @@ class TrianglePredictionGuider(LinearPredictionGuider):
|
||||
max_scale: float,
|
||||
num_frames: int,
|
||||
min_scale: float = 1.0,
|
||||
period: float | List[float] = 1.0,
|
||||
period: Union[float, List[float]] = 1.0,
|
||||
period_fusing: Literal["mean", "multiply", "max"] = "max",
|
||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
@@ -129,3 +129,47 @@ class TrianglePredictionGuider(LinearPredictionGuider):
|
||||
|
||||
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
|
||||
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||
|
||||
|
||||
class TrapezoidPredictionGuider(LinearPredictionGuider):
|
||||
def __init__(
|
||||
self,
|
||||
max_scale: float,
|
||||
num_frames: int,
|
||||
min_scale: float = 1.0,
|
||||
edge_perc: float = 0.1,
|
||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
|
||||
|
||||
rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc))
|
||||
fall_steps = torch.flip(rise_steps, [0])
|
||||
self.scale = torch.cat(
|
||||
[
|
||||
rise_steps,
|
||||
torch.ones(num_frames - 2 * int(num_frames * edge_perc)),
|
||||
fall_steps,
|
||||
]
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
class SpatiotemporalPredictionGuider(LinearPredictionGuider):
|
||||
def __init__(
|
||||
self,
|
||||
max_scale: float,
|
||||
num_frames: int,
|
||||
num_views: int = 1,
|
||||
min_scale: float = 1.0,
|
||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
|
||||
V = num_views
|
||||
T = num_frames // V
|
||||
scale = torch.zeros(num_frames).view(T, V)
|
||||
scale += torch.linspace(0, 1, T)[:,None] * 0.5
|
||||
scale += self.triangle_wave(torch.linspace(0, 1, V))[None,:] * 0.5
|
||||
scale = scale.flatten()
|
||||
self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
|
||||
|
||||
def triangle_wave(self, values: torch.Tensor, period=1) -> torch.Tensor:
|
||||
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||
@@ -746,3 +746,170 @@ class Decoder(nn.Module):
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class DecoderDual(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logpy.info(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
make_resblock_cls = self._make_resblock()
|
||||
make_conv_cls = self._make_conv()
|
||||
|
||||
# z to block_in (处理单个 latent)
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn_cls(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = make_conv_cls(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def _make_attn(self) -> Callable:
|
||||
return make_attn
|
||||
|
||||
def _make_resblock(self) -> Callable:
|
||||
return ResnetBlock
|
||||
|
||||
def _make_conv(self) -> Callable:
|
||||
return torch.nn.Conv2d
|
||||
|
||||
def get_last_layer(self, **kwargs):
|
||||
return self.conv_out.weight
|
||||
|
||||
def forward(self, z, **kwargs):
|
||||
"""
|
||||
输入 z 的形状应为 (B, 2 * z_channels, H, W)
|
||||
- 其中前一半通道为第一个 latent,后一半通道为第二个 latent
|
||||
- 分离后分别解码,最终在 W 维度拼接
|
||||
"""
|
||||
# 断言检查,确保输入的通道数是 2 倍的 z_channels
|
||||
assert (
|
||||
z.shape[1] == 2 * self.z_shape[1]
|
||||
), f"Expected {2 * self.z_shape[1]} channels, got {z.shape[1]}"
|
||||
|
||||
# 分割 latent 为两个部分
|
||||
z1, z2 = torch.chunk(z, 2, dim=1) # 按照通道维度 (C) 切分
|
||||
|
||||
# 分别解码
|
||||
img1 = self.decode_single(z1, **kwargs)
|
||||
img2 = self.decode_single(z2, **kwargs)
|
||||
|
||||
# 沿着 W 维度拼接
|
||||
output = torch.cat([img1, img2], dim=-1) # 在 width 维度拼接
|
||||
|
||||
return output
|
||||
|
||||
def decode_single(self, z, **kwargs):
|
||||
"""解码单个 latent 到一张图像"""
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, None, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, None, **kwargs)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, None, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
|
||||
return h
|
||||
|
||||
@@ -74,21 +74,62 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
cam: Optional[th.Tensor] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
cond_view: Optional[th.Tensor] = None,
|
||||
cond_motion: Optional[th.Tensor] = None,
|
||||
time_context: Optional[int] = None,
|
||||
num_video_frames: Optional[int] = None,
|
||||
time_step: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
from ...modules.diffusionmodules.video_model import VideoResBlock
|
||||
from ...modules.diffusionmodules.video_model import VideoResBlock, PostHocResBlockWithTime
|
||||
from ...modules.spacetime_attention import (
|
||||
BasicTransformerTimeMixBlock,
|
||||
PostHocSpatialTransformerWithTimeMixing,
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
||||
)
|
||||
|
||||
for layer in self:
|
||||
module = layer
|
||||
|
||||
if isinstance(module, TimestepBlock) and not isinstance(
|
||||
module, VideoResBlock
|
||||
if isinstance(
|
||||
module,
|
||||
(
|
||||
BasicTransformerTimeMixBlock,
|
||||
PostHocSpatialTransformerWithTimeMixing,
|
||||
),
|
||||
):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(module, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
x = layer(
|
||||
x,
|
||||
context,
|
||||
emb,
|
||||
time_context,
|
||||
num_video_frames,
|
||||
image_only_indicator,
|
||||
cond_view,
|
||||
cond_motion,
|
||||
time_step,
|
||||
name,
|
||||
)
|
||||
elif isinstance(
|
||||
module,
|
||||
(
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
||||
),
|
||||
):
|
||||
x = layer(
|
||||
x,
|
||||
context,
|
||||
emb,
|
||||
time_context,
|
||||
num_video_frames,
|
||||
image_only_indicator,
|
||||
cond_view,
|
||||
cond_motion,
|
||||
time_step,
|
||||
name,
|
||||
)
|
||||
elif isinstance(module, SpatialVideoTransformer):
|
||||
x = layer(
|
||||
x,
|
||||
@@ -96,7 +137,16 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
time_context,
|
||||
num_video_frames,
|
||||
image_only_indicator,
|
||||
# time_step,
|
||||
)
|
||||
elif isinstance(module, PostHocResBlockWithTime):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(module, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(module, TimestepBlock) and not isinstance(
|
||||
module, VideoResBlock
|
||||
):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(module, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
|
||||
from typing import Optional, Union
|
||||
from ...util import default, instantiate_from_config
|
||||
|
||||
|
||||
@@ -29,3 +29,10 @@ class DiscreteSampling:
|
||||
torch.randint(0, self.num_idx, (n_samples,)),
|
||||
)
|
||||
return self.idx_to_sigma(idx)
|
||||
|
||||
|
||||
class ZeroSampler:
|
||||
def __call__(
|
||||
self, n_samples: int, rand: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5
|
||||
|
||||
@@ -17,6 +17,36 @@ import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_alpha(
|
||||
merge_strategy: str,
|
||||
mix_factor: Optional[torch.Tensor],
|
||||
image_only_indicator: torch.Tensor,
|
||||
apply_sigmoid: bool = True,
|
||||
is_attn: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if merge_strategy == "fixed" or merge_strategy == "learned":
|
||||
alpha = mix_factor
|
||||
elif merge_strategy == "learned_with_images":
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(mix_factor, "... -> ... 1"),
|
||||
)
|
||||
if is_attn:
|
||||
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
||||
else:
|
||||
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
||||
elif merge_strategy == "fixed_with_images":
|
||||
alpha = image_only_indicator
|
||||
if is_attn:
|
||||
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
||||
else:
|
||||
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return torch.sigmoid(alpha) if apply_sigmoid else alpha
|
||||
|
||||
|
||||
def make_beta_schedule(
|
||||
schedule,
|
||||
n_timestep,
|
||||
|
||||
@@ -5,9 +5,15 @@ from einops import rearrange
|
||||
|
||||
from ...modules.diffusionmodules.openaimodel import *
|
||||
from ...modules.video_attention import SpatialVideoTransformer
|
||||
from ...modules.spacetime_attention import (
|
||||
BasicTransformerTimeMixBlock,
|
||||
PostHocSpatialTransformerWithTimeMixing,
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
||||
)
|
||||
from ...util import default
|
||||
from .util import AlphaBlender
|
||||
from .util import AlphaBlender, get_alpha
|
||||
|
||||
import torch
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
def __init__(
|
||||
@@ -491,3 +497,913 @@ class VideoUNet(nn.Module):
|
||||
)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class PostHocAttentionBlockWithTimeMixing(AttentionBlock):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
n_heads: int,
|
||||
d_head: int,
|
||||
use_checkpoint: bool = False,
|
||||
use_new_attention_order: bool = False,
|
||||
dropout: float = 0.0,
|
||||
use_spatial_context: bool = False,
|
||||
merge_strategy: bool = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
apply_sigmoid_to_merge: bool = True,
|
||||
ff_in: bool = False,
|
||||
attn_mode: str = "softmax",
|
||||
disable_temporal_crossattention: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
inner_dim = n_heads * d_head
|
||||
|
||||
self.time_mix_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerTimeMixBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
checkpoint=use_checkpoint,
|
||||
ff_in=ff_in,
|
||||
attn_mode=attn_mode,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
)
|
||||
]
|
||||
)
|
||||
self.in_channels = in_channels
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.time_mix_time_embed = nn.Sequential(
|
||||
linear(self.in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.use_spatial_context = use_spatial_context
|
||||
|
||||
if merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
|
||||
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||
self.register_parameter(
|
||||
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
|
||||
)
|
||||
elif merge_strategy == "fixed_with_images":
|
||||
self.mix_factor = None
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||
|
||||
self.get_alpha_fn = functools.partial(
|
||||
get_alpha,
|
||||
merge_strategy,
|
||||
self.mix_factor,
|
||||
apply_sigmoid=apply_sigmoid_to_merge,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
# cam: Optional[th.Tensor] = None,
|
||||
time_context: Optional[th.Tensor] = None,
|
||||
timesteps: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
conv_view: Optional[th.Tensor] = None,
|
||||
conv_motion: Optional[th.Tensor] = None,
|
||||
):
|
||||
if time_context is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
_, _, h, w = x.shape
|
||||
if exists(context):
|
||||
context = rearrange(context, "b t ... -> (b t) ...")
|
||||
if self.use_spatial_context:
|
||||
time_context = repeat(context[:, 0], "b ... -> (b n) ...", n=h * w)
|
||||
|
||||
x = super().forward(
|
||||
x,
|
||||
)
|
||||
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
x_mix = x
|
||||
|
||||
num_frames = th.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.time_mix_time_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
x_mix = self.time_mix_blocks[0](
|
||||
x_mix, context=time_context, timesteps=timesteps
|
||||
)
|
||||
|
||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
return x
|
||||
|
||||
|
||||
class PostHocResBlockWithTime(ResBlock):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
emb_channels: int,
|
||||
dropout: float,
|
||||
time_kernel_size: Union[int, List[int]] = 3,
|
||||
merge_strategy: bool = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
apply_sigmoid_to_merge: bool = True,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
use_scale_shift_norm: bool = False,
|
||||
dims: int = 2,
|
||||
use_checkpoint: bool = False,
|
||||
up: bool = False,
|
||||
down: bool = False,
|
||||
time_mix_legacy: bool = True,
|
||||
replicate_bug: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=out_channels,
|
||||
use_conv=use_conv,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
up=up,
|
||||
down=down,
|
||||
)
|
||||
|
||||
self.time_mix_blocks = ResBlock(
|
||||
default(out_channels, channels),
|
||||
emb_channels,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
out_channels=default(out_channels, channels),
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=time_kernel_size,
|
||||
use_checkpoint=use_checkpoint,
|
||||
exchange_temb_dims=True,
|
||||
)
|
||||
self.time_mix_legacy = time_mix_legacy
|
||||
if self.time_mix_legacy:
|
||||
if merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
|
||||
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||
self.register_parameter(
|
||||
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
|
||||
)
|
||||
elif merge_strategy == "fixed_with_images":
|
||||
self.mix_factor = None
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||
|
||||
self.get_alpha_fn = functools.partial(
|
||||
get_alpha,
|
||||
merge_strategy,
|
||||
self.mix_factor,
|
||||
apply_sigmoid=apply_sigmoid_to_merge,
|
||||
)
|
||||
else:
|
||||
if False: # replicate_bug:
|
||||
logpy.warning(
|
||||
"*****************************************************************************************\n"
|
||||
"GRAVE WARNING: YOU'RE USING THE BUGGY LEGACY ALPHABLENDER!!! ARE YOU SURE YOU WANT THIS?!\n"
|
||||
"*****************************************************************************************"
|
||||
)
|
||||
self.time_mixer = LegacyAlphaBlenderWithBug(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
rearrange_pattern="b t -> b 1 t 1 1",
|
||||
)
|
||||
else:
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
rearrange_pattern="b t -> b 1 t 1 1",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
num_video_frames: int,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
cond_view: Optional[th.Tensor] = None,
|
||||
cond_motion: Optional[th.Tensor] = None,
|
||||
) -> th.Tensor:
|
||||
x = super().forward(x, emb)
|
||||
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
|
||||
x = self.time_mix_blocks(
|
||||
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
||||
)
|
||||
|
||||
if self.time_mix_legacy:
|
||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator*0.0)
|
||||
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||
else:
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator*0.0
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class SpatialUNetModelWithTime(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
attention_resolutions: int,
|
||||
dropout: float = 0.0,
|
||||
channel_mult: List[int] = (1, 2, 4, 8),
|
||||
conv_resample: bool = True,
|
||||
dims: int = 2,
|
||||
num_classes: Optional[int] = None,
|
||||
use_checkpoint: bool = False,
|
||||
num_heads: int = -1,
|
||||
num_head_channels: int = -1,
|
||||
num_heads_upsample: int = -1,
|
||||
use_scale_shift_norm: bool = False,
|
||||
resblock_updown: bool = False,
|
||||
use_new_attention_order: bool = False,
|
||||
use_spatial_transformer: bool = False,
|
||||
transformer_depth: Union[List[int], int] = 1,
|
||||
transformer_depth_middle: Optional[int] = None,
|
||||
context_dim: Optional[int] = None,
|
||||
time_downup: bool = False,
|
||||
time_context_dim: Optional[int] = None,
|
||||
view_context_dim: Optional[int] = None,
|
||||
motion_context_dim: Optional[int] = None,
|
||||
extra_ff_mix_layer: bool = False,
|
||||
use_spatial_context: bool = False,
|
||||
time_block_merge_strategy: str = "fixed",
|
||||
time_block_merge_factor: float = 0.5,
|
||||
view_block_merge_factor: float = 0.5,
|
||||
motion_block_merge_factor: float = 0.5,
|
||||
spatial_transformer_attn_type: str = "softmax",
|
||||
time_kernel_size: Union[int, List[int]] = 3,
|
||||
use_linear_in_transformer: bool = False,
|
||||
legacy: bool = True,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
use_temporal_resblock: bool = True,
|
||||
disable_temporal_crossattention: bool = False,
|
||||
time_mix_legacy: bool = True,
|
||||
max_ddpm_temb_period: int = 10000,
|
||||
replicate_time_mix_bug: bool = False,
|
||||
use_motion_attention: bool = False,
|
||||
use_camera_emb: bool = False,
|
||||
use_3d_attention: bool = False,
|
||||
separate_motion_merge_factor: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
transformer_depth_middle = default(
|
||||
transformer_depth_middle, transformer_depth[-1]
|
||||
)
|
||||
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.use_temporal_resblocks = use_temporal_resblock
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "timestep":
|
||||
self.label_emb = nn.Sequential(
|
||||
Timestep(model_channels),
|
||||
nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
),
|
||||
)
|
||||
|
||||
elif self.num_classes == "sequential":
|
||||
assert adm_in_channels is not None
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
linear(adm_in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
def get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=1,
|
||||
context_dim=None,
|
||||
use_checkpoint=False,
|
||||
disabled_sa=False,
|
||||
):
|
||||
if not use_spatial_transformer:
|
||||
return PostHocAttentionBlockWithTimeMixing(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
merge_strategy=time_block_merge_strategy,
|
||||
merge_factor=time_block_merge_factor,
|
||||
attn_mode=spatial_transformer_attn_type,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
)
|
||||
|
||||
elif use_motion_attention:
|
||||
return PostHocSpatialTransformerWithTimeMixingAndMotion(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
motion_context_dim=motion_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
use_camera_emb=use_camera_emb,
|
||||
use_3d_attention=use_3d_attention,
|
||||
separate_motion_merge_factor=separate_motion_merge_factor,
|
||||
adm_in_channels=adm_in_channels,
|
||||
merge_strategy=time_block_merge_strategy,
|
||||
merge_factor=time_block_merge_factor,
|
||||
merge_factor_motion=motion_block_merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
attn_mode=spatial_transformer_attn_type,
|
||||
disable_self_attn=disabled_sa,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
time_mix_legacy=time_mix_legacy,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
)
|
||||
|
||||
else:
|
||||
return PostHocSpatialTransformerWithTimeMixing(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
merge_strategy=time_block_merge_strategy,
|
||||
merge_factor=time_block_merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
attn_mode=spatial_transformer_attn_type,
|
||||
disable_self_attn=disabled_sa,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
time_mix_legacy=time_mix_legacy,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
time_block_merge_factor,
|
||||
time_block_merge_strategy,
|
||||
time_kernel_size,
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_ch,
|
||||
dims,
|
||||
use_checkpoint,
|
||||
use_scale_shift_norm,
|
||||
down=False,
|
||||
up=False,
|
||||
):
|
||||
if self.use_temporal_resblocks:
|
||||
return PostHocResBlockWithTime(
|
||||
merge_factor=time_block_merge_factor,
|
||||
merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
time_mix_legacy=time_mix_legacy,
|
||||
replicate_bug=replicate_time_mix_bug,
|
||||
)
|
||||
else:
|
||||
return ResBlock(
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
)
|
||||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
get_resblock(
|
||||
time_block_merge_factor=time_block_merge_factor,
|
||||
time_block_merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
dim_head = (
|
||||
ch // num_heads
|
||||
if use_spatial_transformer
|
||||
else num_head_channels
|
||||
)
|
||||
|
||||
layers.append(
|
||||
get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth[level],
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disabled_sa=False,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
ds *= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
get_resblock(
|
||||
time_block_merge_factor=time_block_merge_factor,
|
||||
time_block_merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch,
|
||||
conv_resample,
|
||||
dims=dims,
|
||||
out_channels=out_ch,
|
||||
third_down=time_downup,
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
get_resblock(
|
||||
time_block_merge_factor=time_block_merge_factor,
|
||||
time_block_merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
out_ch=None,
|
||||
dropout=dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth_middle,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
),
|
||||
get_resblock(
|
||||
time_block_merge_factor=time_block_merge_factor,
|
||||
time_block_merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
ch=ch,
|
||||
out_ch=None,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
get_resblock(
|
||||
time_block_merge_factor=time_block_merge_factor,
|
||||
time_block_merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
ch=ch + ich,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
dim_head = (
|
||||
ch // num_heads
|
||||
if use_spatial_transformer
|
||||
else num_head_channels
|
||||
)
|
||||
|
||||
layers.append(
|
||||
get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth[level],
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disabled_sa=False,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
ds //= 2
|
||||
layers.append(
|
||||
get_resblock(
|
||||
time_block_merge_factor=time_block_merge_factor,
|
||||
time_block_merge_strategy=time_block_merge_strategy,
|
||||
time_kernel_size=time_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(
|
||||
ch,
|
||||
conv_resample,
|
||||
dims=dims,
|
||||
out_channels=out_ch,
|
||||
third_up=time_downup,
|
||||
)
|
||||
)
|
||||
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
timesteps: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
y: Optional[th.Tensor] = None,
|
||||
cam: Optional[th.Tensor] = None,
|
||||
time_context: Optional[th.Tensor] = None,
|
||||
num_video_frames: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
cond_view: Optional[th.Tensor] = None,
|
||||
cond_motion: Optional[th.Tensor] = None,
|
||||
time_step: Optional[int] = None,
|
||||
):
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # 21 x 320
|
||||
emb = self.time_embed(t_emb) # 21 x 1280
|
||||
time = str(timesteps[0].data.cpu().numpy())
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y) # 21 x 1280
|
||||
|
||||
h = x # 21 x 8 x 64 x 64
|
||||
for i, module in enumerate(self.input_blocks):
|
||||
h = module(
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
cam=cam,
|
||||
image_only_indicator=image_only_indicator,
|
||||
cond_view=cond_view,
|
||||
cond_motion=cond_motion,
|
||||
time_context=time_context,
|
||||
num_video_frames=num_video_frames,
|
||||
time_step=time_step,
|
||||
name='encoder_{}_{}'.format(time, i)
|
||||
)
|
||||
hs.append(h)
|
||||
h = self.middle_block(
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
cam=cam,
|
||||
image_only_indicator=image_only_indicator,
|
||||
cond_view=cond_view,
|
||||
cond_motion=cond_motion,
|
||||
time_context=time_context,
|
||||
num_video_frames=num_video_frames,
|
||||
time_step=time_step,
|
||||
name='middle_{}_0'.format(time, i)
|
||||
)
|
||||
for i, module in enumerate(self.output_blocks):
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
cam=cam,
|
||||
image_only_indicator=image_only_indicator,
|
||||
cond_view=cond_view,
|
||||
cond_motion=cond_motion,
|
||||
time_context=time_context,
|
||||
num_video_frames=num_video_frames,
|
||||
time_step=time_step,
|
||||
name='decoder_{}_{}'.format(time, i)
|
||||
)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class CrossNetworkLayer(nn.Module):
|
||||
def __init__(self, feature_dim: int):
|
||||
super().__init__()
|
||||
self.fusion_conv = nn.Sequential(
|
||||
nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, h1: torch.Tensor, h2: torch.Tensor):
|
||||
"""
|
||||
h1, h2: (B, C, H, W)
|
||||
return: (out1, out2), (B, C, H, W)
|
||||
"""
|
||||
fused_input = torch.cat([h1, h2], dim=1) # (B, 2C, H, W)
|
||||
fused_output = self.fusion_conv(fused_input) # (B, C, H, W)
|
||||
out1 = fused_output + h1
|
||||
out2 = fused_output + h2
|
||||
return out1, out2
|
||||
|
||||
|
||||
class DualSpatialUNetWithCrossComm(nn.Module):
|
||||
def __init__(self, unet_config):
|
||||
super().__init__()
|
||||
self.num_classes = unet_config["num_classes"]
|
||||
self.model_channels = unet_config["model_channels"]
|
||||
|
||||
self.net1 = SpatialUNetModelWithTime(**unet_config)
|
||||
self.net2 = SpatialUNetModelWithTime(**unet_config)
|
||||
|
||||
self.input_cross_layers = nn.ModuleList()
|
||||
for block in self.net1.input_blocks:
|
||||
out_ch = self._get_block_out_channels(block)
|
||||
self.input_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
|
||||
|
||||
middle_out_ch = self._get_block_out_channels(self.net1.middle_block)
|
||||
self.middle_cross = CrossNetworkLayer(feature_dim=middle_out_ch)
|
||||
|
||||
self.output_cross_layers = nn.ModuleList()
|
||||
for block in self.net1.output_blocks:
|
||||
out_ch = self._get_block_out_channels(block)
|
||||
self.output_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
|
||||
|
||||
def _get_block_out_channels(self, block: nn.Module) -> int:
|
||||
mod_list = list(block.children())
|
||||
for m in reversed(mod_list):
|
||||
if hasattr(m, "out_channels"):
|
||||
return m.out_channels
|
||||
|
||||
if isinstance(
|
||||
m,
|
||||
(SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion),
|
||||
):
|
||||
return m.in_channels
|
||||
|
||||
if isinstance(m, nn.Conv2d):
|
||||
return m.out_channels
|
||||
|
||||
raise ValueError(f"Cannot determine out_channels from block: {block}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
timesteps: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
y: Optional[th.Tensor] = None,
|
||||
cam: Optional[th.Tensor] = None,
|
||||
time_context: Optional[th.Tensor] = None,
|
||||
num_video_frames: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
cond_view: Optional[th.Tensor] = None,
|
||||
cond_motion: Optional[th.Tensor] = None,
|
||||
time_step: Optional[int] = None,
|
||||
):
|
||||
|
||||
# ============ encoder ============
|
||||
h1, h2 = x[:, : x.shape[1] // 2], x[:, x.shape[1] // 2 :]
|
||||
|
||||
encoder_feats1 = []
|
||||
encoder_feats2 = []
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
||||
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False
|
||||
) # 21 x 320
|
||||
|
||||
emb = self.net1.time_embed(t_emb) # 21 x 1280
|
||||
time = str(timesteps[0].data.cpu().numpy())
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == h1.shape[0]
|
||||
emb = emb + self.net1.label_emb(y) # 21 x 1280
|
||||
|
||||
filtered_args = {
|
||||
"emb": emb,
|
||||
"context": context,
|
||||
"cam": cam,
|
||||
"cond_view": cond_view,
|
||||
"cond_motion": cond_motion,
|
||||
"time_context": time_context,
|
||||
"num_video_frames": num_video_frames,
|
||||
"image_only_indicator": image_only_indicator,
|
||||
"time_step": time_step,
|
||||
}
|
||||
|
||||
for i, (block1, block2) in enumerate(
|
||||
zip(self.net1.input_blocks, self.net2.input_blocks)
|
||||
):
|
||||
h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args)
|
||||
h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args)
|
||||
|
||||
# cross
|
||||
h1, h2 = self.input_cross_layers[i](h1, h2)
|
||||
|
||||
encoder_feats1.append(h1)
|
||||
encoder_feats2.append(h2)
|
||||
|
||||
# ============ middle block ============
|
||||
h1 = self.net1.middle_block(
|
||||
h1, name="middle_{}_0".format(time, i), **filtered_args
|
||||
)
|
||||
h2 = self.net2.middle_block(
|
||||
h2, name="middle_{}_0".format(time, i), **filtered_args
|
||||
)
|
||||
|
||||
# cross
|
||||
h1, h2 = self.middle_cross(h1, h2)
|
||||
|
||||
# ============ decoder ============
|
||||
for i, (block1, block2) in enumerate(
|
||||
zip(self.net1.output_blocks, self.net2.output_blocks)
|
||||
):
|
||||
skip1 = encoder_feats1.pop()
|
||||
skip2 = encoder_feats2.pop()
|
||||
h1 = torch.cat([h1, skip1], dim=1)
|
||||
h2 = torch.cat([h2, skip2], dim=1)
|
||||
|
||||
h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args)
|
||||
h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args)
|
||||
|
||||
# cross
|
||||
h1, h2 = self.output_cross_layers[i](h1, h2)
|
||||
|
||||
# ============ output ============
|
||||
out1 = self.net1.out(h1) # shape: (B, out_channels, H, W)
|
||||
out2 = self.net2.out(h2) # same shape
|
||||
out = torch.cat([out1, out2], dim=1)
|
||||
|
||||
return out
|
||||
@@ -6,7 +6,7 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
|
||||
|
||||
|
||||
class IdentityWrapper(nn.Module):
|
||||
def __init__(self, diffusion_model, compile_model: bool = False):
|
||||
def __init__(self, diffusion_model, compile_model: bool = False, dual_concat: bool = False):
|
||||
super().__init__()
|
||||
compile = (
|
||||
torch.compile
|
||||
@@ -15,6 +15,7 @@ class IdentityWrapper(nn.Module):
|
||||
else lambda x: x
|
||||
)
|
||||
self.diffusion_model = compile(diffusion_model)
|
||||
self.dual_concat = dual_concat
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.diffusion_model(*args, **kwargs)
|
||||
@@ -24,11 +25,29 @@ class OpenAIWrapper(IdentityWrapper):
|
||||
def forward(
|
||||
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
||||
) -> torch.Tensor:
|
||||
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
||||
return self.diffusion_model(
|
||||
x,
|
||||
timesteps=t,
|
||||
context=c.get("crossattn", None),
|
||||
y=c.get("vector", None),
|
||||
**kwargs,
|
||||
)
|
||||
if self.dual_concat:
|
||||
x_1 = x[:, : x.shape[1] // 2]
|
||||
x_2 = x[:, x.shape[1] // 2 :]
|
||||
x_1 = torch.cat((x_1, c.get("concat", torch.Tensor([]).type_as(x_1))), dim=1)
|
||||
x_2 = torch.cat((x_2, c.get("concat", torch.Tensor([]).type_as(x_2))), dim=1)
|
||||
x = torch.cat((x_1, x_2), dim=1)
|
||||
else:
|
||||
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
||||
if "cond_view" in c:
|
||||
return self.diffusion_model(
|
||||
x,
|
||||
timesteps=t,
|
||||
context=c.get("crossattn", None),
|
||||
y=c.get("vector", None),
|
||||
cond_view=c.get("cond_view", None),
|
||||
cond_motion=c.get("cond_motion", None),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.diffusion_model(
|
||||
x,
|
||||
timesteps=t,
|
||||
context=c.get("crossattn", None),
|
||||
y=c.get("vector", None),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -69,8 +69,8 @@ class AbstractEmbModel(nn.Module):
|
||||
|
||||
|
||||
class GeneralConditioner(nn.Module):
|
||||
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
|
||||
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
|
||||
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"} # , 5: "concat"}
|
||||
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1}
|
||||
|
||||
def __init__(self, emb_models: Union[List, ListConfig]):
|
||||
super().__init__()
|
||||
@@ -138,7 +138,11 @@ class GeneralConditioner(nn.Module):
|
||||
if not isinstance(emb_out, (list, tuple)):
|
||||
emb_out = [emb_out]
|
||||
for emb in emb_out:
|
||||
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
|
||||
if embedder.input_key in ["cond_view", "cond_motion"]:
|
||||
out_key = embedder.input_key
|
||||
else:
|
||||
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
|
||||
|
||||
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
|
||||
emb = (
|
||||
expand_dims_like(
|
||||
@@ -994,7 +998,10 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
|
||||
sigmas = self.sigma_sampler(b).to(vid.device)
|
||||
if self.sigma_cond is not None:
|
||||
sigma_cond = self.sigma_cond(sigmas)
|
||||
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
|
||||
if self.n_cond_frames == 1:
|
||||
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
|
||||
else:
|
||||
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_cond_frames) # For SV4D
|
||||
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
|
||||
noise = torch.randn_like(vid)
|
||||
vid = vid + noise * append_dims(sigmas, vid.ndim)
|
||||
@@ -1017,8 +1024,9 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
|
||||
vid = torch.cat(all_out, dim=0)
|
||||
vid *= self.scale_factor
|
||||
|
||||
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
|
||||
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
|
||||
if self.n_cond_frames == 1:
|
||||
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
|
||||
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
|
||||
|
||||
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
|
||||
|
||||
|
||||
625
sgm/modules/spacetime_attention.py
Normal file
@@ -0,0 +1,625 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..modules.attention import *
|
||||
from ..modules.diffusionmodules.util import (
|
||||
AlphaBlender,
|
||||
get_alpha,
|
||||
linear,
|
||||
mixed_checkpoint,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
class TimeMixSequential(nn.Sequential):
|
||||
def forward(self, x, context=None, timesteps=None):
|
||||
for layer in self:
|
||||
x = layer(x, context, timesteps)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerTimeMixBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention,
|
||||
"softmax-xformers": MemoryEfficientCrossAttention,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
timesteps=None,
|
||||
ff_in=False,
|
||||
inner_dim=None,
|
||||
attn_mode="softmax",
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
switch_temporal_ca_to_sa=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
|
||||
self.ff_in = ff_in or inner_dim is not None
|
||||
if inner_dim is None:
|
||||
inner_dim = dim
|
||||
|
||||
assert int(n_heads * d_head) == inner_dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = nn.LayerNorm(dim)
|
||||
self.ff_in = FeedForward(
|
||||
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
|
||||
)
|
||||
|
||||
self.timesteps = timesteps
|
||||
self.disable_self_attn = disable_self_attn
|
||||
if self.disable_self_attn:
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=inner_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=context_dim,
|
||||
dropout=dropout,
|
||||
) # is a cross-attention
|
||||
else:
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
if switch_temporal_ca_to_sa:
|
||||
raise ValueError
|
||||
else:
|
||||
self.attn2 = None
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(inner_dim)
|
||||
if switch_temporal_ca_to_sa:
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
else:
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=inner_dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
) # is self-attn if context is none
|
||||
|
||||
self.norm1 = nn.LayerNorm(inner_dim)
|
||||
self.norm3 = nn.LayerNorm(inner_dim)
|
||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||
|
||||
self.checkpoint = checkpoint
|
||||
if self.checkpoint:
|
||||
logpy.info(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
|
||||
) -> torch.Tensor:
|
||||
if self.checkpoint:
|
||||
return checkpoint(self._forward, x, context, timesteps)
|
||||
else:
|
||||
return self._forward(x, context, timesteps=timesteps)
|
||||
|
||||
def _forward(self, x, context=None, timesteps=None):
|
||||
assert self.timesteps or timesteps
|
||||
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
|
||||
timesteps = self.timesteps or timesteps
|
||||
B, S, C = x.shape
|
||||
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
x = self.ff_in(self.norm_in(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
if self.disable_self_attn:
|
||||
x = self.attn1(self.norm1(x), context=context) + x
|
||||
else:
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
|
||||
if self.attn2 is not None:
|
||||
if self.switch_temporal_ca_to_sa:
|
||||
x = self.attn2(self.norm2(x)) + x
|
||||
else:
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x_skip = x
|
||||
x = self.ff(self.norm3(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
x = rearrange(
|
||||
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.ff.net[-1].weight
|
||||
|
||||
|
||||
class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
use_linear=False,
|
||||
context_dim=None,
|
||||
use_spatial_context=False,
|
||||
timesteps=None,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
apply_sigmoid_to_merge: bool = True,
|
||||
time_context_dim=None,
|
||||
ff_in=False,
|
||||
checkpoint=False,
|
||||
time_depth=1,
|
||||
attn_mode="softmax",
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
time_mix_legacy: bool = True,
|
||||
max_time_embed_period: int = 10000,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=depth,
|
||||
dropout=dropout,
|
||||
attn_type=attn_mode,
|
||||
use_checkpoint=checkpoint,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
self.depth = depth
|
||||
self.max_time_embed_period = max_time_embed_period
|
||||
|
||||
time_mix_d_head = d_head
|
||||
n_time_mix_heads = n_heads
|
||||
|
||||
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||
|
||||
inner_dim = n_heads * d_head
|
||||
if use_spatial_context:
|
||||
time_context_dim = context_dim
|
||||
|
||||
self.time_mix_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerTimeMixBlock(
|
||||
inner_dim,
|
||||
n_time_mix_heads,
|
||||
time_mix_d_head,
|
||||
dropout=dropout,
|
||||
context_dim=time_context_dim,
|
||||
timesteps=timesteps,
|
||||
checkpoint=checkpoint,
|
||||
ff_in=ff_in,
|
||||
inner_dim=time_mix_inner_dim,
|
||||
attn_mode=attn_mode,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
assert len(self.time_mix_blocks) == len(self.transformer_blocks)
|
||||
|
||||
self.use_spatial_context = use_spatial_context
|
||||
self.in_channels = in_channels
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.time_mix_time_embed = nn.Sequential(
|
||||
linear(self.in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.time_mix_legacy = time_mix_legacy
|
||||
if self.time_mix_legacy:
|
||||
if merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
|
||||
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
|
||||
)
|
||||
elif merge_strategy == "fixed_with_images":
|
||||
self.mix_factor = None
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||
|
||||
self.get_alpha_fn = partial(
|
||||
get_alpha,
|
||||
merge_strategy,
|
||||
self.mix_factor,
|
||||
apply_sigmoid=apply_sigmoid_to_merge,
|
||||
is_attn=True,
|
||||
)
|
||||
else:
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
# cam: Optional[torch.Tensor] = None,
|
||||
time_context: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[int] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
cond_view: Optional[torch.Tensor] = None,
|
||||
cond_motion: Optional[torch.Tensor] = None,
|
||||
time_step: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
spatial_context = context
|
||||
|
||||
if self.use_spatial_context:
|
||||
assert (
|
||||
context.ndim == 3
|
||||
), f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||
|
||||
time_context = context
|
||||
time_context_first_timestep = time_context[::timesteps]
|
||||
time_context = repeat(
|
||||
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
|
||||
)
|
||||
elif time_context is not None and not self.use_spatial_context:
|
||||
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
|
||||
if time_context.ndim == 2:
|
||||
time_context = rearrange(time_context, "b c -> b 1 c")
|
||||
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
|
||||
if self.time_mix_legacy:
|
||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(
|
||||
num_frames,
|
||||
self.in_channels,
|
||||
repeat_only=False,
|
||||
max_period=self.max_time_embed_period,
|
||||
)
|
||||
emb = self.time_mix_time_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
for it_, (block, mix_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.time_mix_blocks)
|
||||
):
|
||||
# spatial attention
|
||||
x = block(
|
||||
x,
|
||||
context=spatial_context,
|
||||
time_step=time_step,
|
||||
name=name + '_' + str(it_)
|
||||
)
|
||||
|
||||
x_mix = x
|
||||
x_mix = x_mix + emb
|
||||
|
||||
# temporal attention
|
||||
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
|
||||
if self.time_mix_legacy:
|
||||
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||
else:
|
||||
x = self.time_mixer(
|
||||
x_spatial=x,
|
||||
x_temporal=x_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
use_linear=False,
|
||||
context_dim=None,
|
||||
use_spatial_context=False,
|
||||
use_camera_emb=False,
|
||||
use_3d_attention=False,
|
||||
separate_motion_merge_factor=False,
|
||||
adm_in_channels=None,
|
||||
timesteps=None,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
merge_factor_motion: float = 0.5,
|
||||
apply_sigmoid_to_merge: bool = True,
|
||||
time_context_dim=None,
|
||||
motion_context_dim=None,
|
||||
ff_in=False,
|
||||
checkpoint=False,
|
||||
time_depth=1,
|
||||
attn_mode="softmax",
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
time_mix_legacy: bool = True,
|
||||
max_time_embed_period: int = 10000,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=depth,
|
||||
dropout=dropout,
|
||||
attn_type=attn_mode,
|
||||
use_checkpoint=checkpoint,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
self.depth = depth
|
||||
self.max_time_embed_period = max_time_embed_period
|
||||
self.use_camera_emb = use_camera_emb
|
||||
self.motion_context_dim = motion_context_dim
|
||||
self.use_3d_attention = use_3d_attention
|
||||
self.separate_motion_merge_factor = separate_motion_merge_factor
|
||||
|
||||
time_mix_d_head = d_head
|
||||
n_time_mix_heads = n_heads
|
||||
|
||||
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||
|
||||
inner_dim = n_heads * d_head
|
||||
if use_spatial_context:
|
||||
time_context_dim = context_dim
|
||||
|
||||
# Camera attention layer
|
||||
self.time_mix_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerTimeMixBlock(
|
||||
inner_dim,
|
||||
n_time_mix_heads,
|
||||
time_mix_d_head,
|
||||
dropout=dropout,
|
||||
context_dim=time_context_dim,
|
||||
timesteps=timesteps,
|
||||
checkpoint=checkpoint,
|
||||
ff_in=ff_in,
|
||||
inner_dim=time_mix_inner_dim,
|
||||
attn_mode=attn_mode,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
# Motion attention layer
|
||||
self.motion_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerTimeMixBlock(
|
||||
inner_dim,
|
||||
n_time_mix_heads,
|
||||
time_mix_d_head,
|
||||
dropout=dropout,
|
||||
context_dim=motion_context_dim,
|
||||
timesteps=timesteps,
|
||||
checkpoint=checkpoint,
|
||||
ff_in=ff_in,
|
||||
inner_dim=time_mix_inner_dim,
|
||||
attn_mode=attn_mode,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
assert len(self.time_mix_blocks) == len(self.transformer_blocks)
|
||||
|
||||
self.use_spatial_context = use_spatial_context
|
||||
self.in_channels = in_channels
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
time_embed_channels = adm_in_channels if self.use_camera_emb else self.in_channels
|
||||
# Camera view embedding
|
||||
self.time_mix_time_embed = nn.Sequential(
|
||||
linear(time_embed_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
# Motion time embedding
|
||||
self.time_mix_motion_embed = nn.Sequential(
|
||||
linear(self.in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.time_mix_legacy = time_mix_legacy
|
||||
if self.time_mix_legacy:
|
||||
if merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
|
||||
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
|
||||
)
|
||||
elif merge_strategy == "fixed_with_images":
|
||||
self.mix_factor = None
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||
|
||||
self.get_alpha_fn = partial(
|
||||
get_alpha,
|
||||
merge_strategy,
|
||||
self.mix_factor,
|
||||
apply_sigmoid=apply_sigmoid_to_merge,
|
||||
is_attn=True,
|
||||
)
|
||||
else:
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor, merge_strategy=merge_strategy
|
||||
)
|
||||
if self.separate_motion_merge_factor:
|
||||
self.time_mixer_motion = AlphaBlender(
|
||||
alpha=merge_factor_motion, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
cam: Optional[torch.Tensor] = None,
|
||||
time_context: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[int] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
cond_view: Optional[torch.Tensor] = None,
|
||||
cond_motion: Optional[torch.Tensor] = None,
|
||||
time_step: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
# context: b t 1024
|
||||
# cond_view: b*v 4 h w
|
||||
# cond_motion: b*t 4 h w
|
||||
# image_only_indicator: b t*v
|
||||
b, t, d1 = context.shape # CLIP
|
||||
v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE
|
||||
_, c, h, w = x.shape
|
||||
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
spatial_context = context
|
||||
|
||||
cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w
|
||||
spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1
|
||||
camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1
|
||||
motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2
|
||||
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
|
||||
if self.time_mix_legacy:
|
||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||
|
||||
num_frames = torch.arange(t, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=b)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(
|
||||
num_frames,
|
||||
self.in_channels,
|
||||
repeat_only=False,
|
||||
max_period=self.max_time_embed_period,
|
||||
)
|
||||
emb_time = self.time_mix_motion_embed(t_emb)
|
||||
emb_time = emb_time[:, None, :] # b*t 1 c
|
||||
|
||||
if self.use_camera_emb:
|
||||
emb_view = self.time_mix_time_embed(cam.view(b,t,v,-1)[:,0].reshape(b*v,-1))
|
||||
emb_view = emb_view[:, None, :]
|
||||
else:
|
||||
num_views = torch.arange(v, device=x.device)
|
||||
num_views = repeat(num_views, "t -> b t", b=b)
|
||||
num_views = rearrange(num_views, "b t -> (b t)")
|
||||
v_emb = timestep_embedding(
|
||||
num_views,
|
||||
self.in_channels,
|
||||
repeat_only=False,
|
||||
max_period=self.max_time_embed_period,
|
||||
)
|
||||
emb_view = self.time_mix_time_embed(v_emb)
|
||||
emb_view = emb_view[:, None, :] # b*v 1 c
|
||||
|
||||
if self.use_3d_attention:
|
||||
emb_view = emb_view.repeat(1, h*w, 1).view(-1,1,c) # b*v*h*w 1 c
|
||||
|
||||
for it_, (block, time_block, mot_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)
|
||||
):
|
||||
# Spatial attention
|
||||
x = block(
|
||||
x,
|
||||
context=spatial_context,
|
||||
)
|
||||
|
||||
# Camera attention
|
||||
if self.use_3d_attention:
|
||||
x = x.view(b, t, v, h*w, c).permute(0,2,3,1,4).reshape(-1,t,c) # b*v*h*w t c
|
||||
else:
|
||||
x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c
|
||||
x_mix = x + emb_view
|
||||
x_mix = time_block(x_mix, context=camera_context, timesteps=v)
|
||||
if self.time_mix_legacy:
|
||||
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||
else:
|
||||
x = self.time_mixer(
|
||||
x_spatial=x,
|
||||
x_temporal=x_mix,
|
||||
image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),
|
||||
)
|
||||
|
||||
# Motion attention
|
||||
if self.use_3d_attention:
|
||||
x = x.view(b, v, h*w, t, c).permute(0,3,1,2,4).reshape(b*t,-1,c) # b*t v*h*w c
|
||||
else:
|
||||
x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c
|
||||
x_mix = x + emb_time
|
||||
x_mix = mot_block(x_mix, context=motion_context, timesteps=t)
|
||||
if self.time_mix_legacy:
|
||||
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||
else:
|
||||
motion_mixer = self.time_mixer_motion if self.separate_motion_merge_factor else self.time_mixer
|
||||
x = motion_mixer(
|
||||
x_spatial=x,
|
||||
x_temporal=x_mix,
|
||||
image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),
|
||||
)
|
||||
|
||||
x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c
|
||||
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||