Compare commits: sp4d ... sv3d_gradi
3 commits

| Author | SHA1 | Date |
|---|---|---|
|  | c4a9d1f865 |  |
|  | 0ebca5c662 |  |
|  | 97db8edf64 |  |
2 changes: .gitignore (vendored)

@@ -12,5 +12,3 @@
 /outputs
 /build
 /src
-/.vscode
-**/__pycache__/
79 changes: README.md (Executable file → Normal file)

@@ -4,84 +4,6 @@: this hunk removes the News entries for SP4D (Nov 4, 2025), SV4D 2.0 (May 20, 2025), and SV4D (July 24, 2024); the "## News" heading and the March 18, 2024 SV3D entry are kept unchanged.

Removed content:

**Nov 4, 2025**
- We are releasing **[Stable Part Diffusion 4D (SP4D)](https://huggingface.co/stabilityai/sp4d)**, a video-to-4D diffusion model for multi-view part video synthesis and animatable 3D asset generation. For research purposes:
  - **SP4D** was trained to generate 48 RGB frames and part segmentation maps (4 video frames x 12 camera views) at 576x576 resolution, given a 4-frame input video of the same size, ideally consisting of white-background images of a moving object.
  - Based on our previous 4D model [SV4D 2.0](https://huggingface.co/stabilityai/sv4d2.0), **SP4D** can simultaneously generate multi-view RGB videos as well as the corresponding kinematic part segmentations that are consistent across time and camera views.
  - The generated part videos can then be used to create animation-ready 3D assets with part-aware rigging capabilities.
  - Please check our [project page](https://stablepartdiffusion4d.github.io/), [arxiv paper](https://arxiv.org/pdf/2509.10687) and [video summary](https://www.youtube.com/watch?v=FXEFeh8tf0k) for more details.

**QUICKSTART**:
- Set up the environment following the SV4D instructions and download [sp4d.safetensors](https://huggingface.co/stabilityai/sp4d) from HuggingFace into `checkpoints/`
- Run `python scripts/sampling/simple_video_sample_sp4d.py --input_path assets/sv4d_videos/cows.gif --output_folder outputs` to generate multi-view part videos given the sample input.

**May 20, 2025**
- We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:
  - **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.
  - Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on reference multi-views of the first frame generated by SV3D, making it more robust to self-occlusions.
  - To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames (see the sketch after this hunk).
  - Please check our [project page](https://sv4d20.github.io), [arxiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details.

**QUICKSTART**:
- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from HuggingFace into `checkpoints/`)

To run **SV4D 2.0** on a single input video of 21 frames:
- Download the SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints`
- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`
  - `input_path`: The input video `<path/to/video>` can be
    - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or
    - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
    - a file name pattern matching images of video frames.
  - `num_steps`: default is 50; decrease it to shorten sampling time.
  - `elevations_deg`: specified elevations (relative to the input view); default is 0.0 (same as the input view).
  - **Background removal**: For input videos with a plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos with a noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
  - **Low VRAM environment**: To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time), or a lower video resolution such as `--img_size=512`.

Notes:
- We also train an 8-view model that generates 5 frames x 8 views at a time (same as SV4D).
  - Download the model from HuggingFace: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints`
  - Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs`
  - The 5x8 model takes 5 input frames at a time, but the inference scripts for both models take a 21-frame video as input by default (same as SV3D and SV4D); we run the model autoregressively until 21 frames are generated.
- Install dependencies before running:
```
python3.10 -m venv .generativemodels
source .generativemodels/bin/activate
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version
pip3 install -r requirements/pt2.txt
pip3 install .
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```



**July 24, 2024**
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
  - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video) and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
  - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
  - To run the community-built gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.
  - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.

**QUICKSTART**: `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)

To run **SV4D** on a single input video of 21 frames:
- Download the SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and the SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
  - `input_path`: The input video `<path/to/video>` can be
    - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or
    - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
    - a file name pattern matching images of video frames.
  - `num_steps`: default is 20; increase to 50 for better quality at the cost of longer sampling time.
  - `sv3d_version`: To specify the SV3D model used to generate the reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
  - `elevations_deg`: To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
  - **Background removal**: For input videos with a plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos with a noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
  - **Low VRAM environment**: To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time), or a lower video resolution such as `--img_size=512`.



Kept context (unchanged):

**March 18, 2024**
- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
  - **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.

@@ -216,7 +138,6 @@ This is assuming you have navigated to the `generative-models` root after cloning: this hunk removes `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118` from the pypi install block; the surrounding `python3 -m venv .pt2`, `source .pt2/bin/activate`, and `pip3 install -r requirements/pt2.txt` lines are unchanged.
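
The SV4D 2.0 note above about generating longer novel-view videos autoregressively (12 frames at a time, conditioning on the previous generation) boils down to sliding a window over the 21 input frames. A minimal index sketch; the window and overlap sizes here are assumed for illustration rather than taken from the sampling script:

```python
import numpy as np

n_frames, chunk, overlap = 21, 12, 1  # overlap size is an assumption for illustration
windows, start = [], 0
while start < n_frames:
    end = min(start + chunk, n_frames)
    windows.append(np.arange(start, end))
    if end == n_frames:
        break
    start = end - overlap  # re-use the tail of the previous chunk as conditioning
print(windows)  # two windows: frames 0..11 and 11..20
```
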
Deleted binary assets:

BIN assets/sv4d.gif (Before: 8.0 MiB)
BIN assets/sv4d2.gif (Before: 9.7 MiB)
BIN 15 further image assets (file names not shown in this view; before-sizes range from 446 KiB to 2.4 MiB)
Requirements file (name not shown in this view): removes `imageio[ffmpeg]`, `imageio[pyav]`, and `onnxruntime`, and relaxes the numpy pin.

@@ -5,16 +5,13 @@ einops>=0.6.1
 fairscale>=0.4.13
 fire>=0.5.0
 fsspec>=2023.6.0
-imageio[ffmpeg]
-imageio[pyav]
 invisible-watermark>=0.2.0
 kornia==0.6.9
 matplotlib>=3.7.2
 natsort>=8.4.0
 ninja>=1.11.1
-numpy==2.1
+numpy>=1.24.4
 omegaconf>=2.3.0
-onnxruntime
 open-clip-torch>=2.20.0
 opencv-python==4.6.0.66
 pandas>=2.0.3
Deleted file, 496 lines (@@ -1,496 +0,0 @@). The file name is not shown in this view; from its content it is the SV4D gradio demo that the README above launches with `python -m scripts.demo.gradio_app_sv4d`. Indentation below is restored from the code structure, since the compare view flattened it:

# Adding this at the very top of app.py to make 'generative-models' directory discoverable
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))

from glob import glob
from typing import Optional

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from typing import List, Optional, Union
import torchvision

from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
    decode_latents,
    load_model,
    initial_model_load,
    read_video,
    run_img2vid,
    prepare_inputs,
    do_sample_per_step,
    sample_sv3d,
    save_video,
    preprocess_video,
)


# the tmp path, if /tmp/gradio is not writable, change it to a writable path
# os.environ["GRADIO_TEMP_DIR"] = "gradio_tmp"

version = "sv4d"  # replace with 'sv3d_p' or 'sv3d_u' for other models

# Define the repo, local directory and filename
repo_id = "stabilityai/sv4d"
filename = f"{version}.safetensors"  # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = os.path.join(local_dir, filename)

# Check if the file already exists
if not os.path.exists(local_ckpt_path):
    # If the file doesn't exist, download it
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    print("File downloaded. (sv4d)")
else:
    print("File already exists. No need to download. (sv4d)")

device = "cuda"
max_64_bit_int = 2**63 - 1

num_frames = 21
num_steps = 20
model_config = f"scripts/sampling/configs/{version}.yaml"

# Set model config
T = 5  # number of frames per sample
V = 8  # number of views per sample
F = 8  # vae factor to downsize image->latent
C = 4
H, W = 576, 576
n_frames = 21  # number of input and output video frames
n_views = V + 1  # number of output video views (1 input view + 8 novel views)
n_views_sv3d = 21
subsampled_views = np.array(
    [0, 2, 5, 7, 9, 12, 14, 16, 19]
)  # subsample (V+1=)9 (uniform) views from 21 SV3D views

version_dict = {
    "T": T * V,
    "H": H,
    "W": W,
    "C": C,
    "f": F,
    "options": {
        "discretization": 1,
        "cfg": 3,
        "sigma_min": 0.002,
        "sigma_max": 700.0,
        "rho": 7.0,
        "guider": 5,
        "num_steps": num_steps,
        "force_uc_zero_embeddings": [
            "cond_frames",
            "cond_frames_without_noise",
            "cond_view",
            "cond_motion",
        ],
        "additional_guider_kwargs": {
            "additional_cond_keys": ["cond_view", "cond_motion"]
        },
    },
}

# Load SV4D model
model, filter = load_model(
    model_config,
    device,
    version_dict["T"],
    num_steps,
)
model = initial_model_load(model)

# -----------sv3d config and model loading----------------
# if version == "sv3d_u":
sv3d_model_config = "scripts/sampling/configs/sv3d_u.yaml"
# elif version == "sv3d_p":
#     sv3d_model_config = "scripts/sampling/configs/sv3d_p.yaml"
# else:
#     raise ValueError(f"Version {version} does not exist.")

# Define the repo, local directory and filename
repo_id = "stabilityai/sv3d"
filename = f"sv3d_u.safetensors"  # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = os.path.join(local_dir, filename)

# Check if the file already exists
if not os.path.exists(local_ckpt_path):
    # If the file doesn't exist, download it
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    print("File downloaded. (sv3d)")
else:
    print("File already exists. No need to download. (sv3d)")

# load sv3d model
sv3d_model, filter = load_model(
    sv3d_model_config,
    device,
    21,
    num_steps,
    verbose=False,
)
sv3d_model = initial_model_load(sv3d_model)
# ------------------


def sample_anchor(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    seed: Optional[int] = None,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    num_steps: int = 20,
    sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 1e-5,
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
    verbose: Optional[bool] = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """
    output_folder = os.path.dirname(input_path)

    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    images_v0 = read_video(
        input_path,
        n_frames=n_frames,
        device=device,
    )

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Sample multi-view images of the first frame using SV3D i.e. images at time 0
    sv3d_model.sampler.num_steps = num_steps
    print("sv3d_model.sampler.num_steps", sv3d_model.sampler.num_steps)
    images_t0 = sample_sv3d(
        images_v0[0],
        n_views_sv3d,
        num_steps,
        sv3d_version,
        fps_id,
        motion_bucket_id,
        cond_aug,
        decoding_t,
        device,
        polars_rad,
        azimuths_rad,
        verbose,
        sv3d_model,
    )
    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame

    sv3d_file = os.path.join(output_folder, "t000.mp4")
    save_video(sv3d_file, images_t0.unsqueeze(1))

    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t
    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # Interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1
    print(f"Sampling anchor frames {frame_indices}")
    image = img_matrix[t0][v0]
    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
    model.sampler.num_steps = num_steps
    version_dict["options"]["num_steps"] = num_steps
    samples = run_img2vid(
        version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
    )
    samples = samples.view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # concat video
    grid_list = []
    for t in frame_indices:
        imgs_view = torch.cat(img_matrix[t])
        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
    # save output videos
    anchor_vis_file = os.path.join(output_folder, "anchor_vis.mp4")
    save_video(anchor_vis_file, grid_list, fps=3)
    anchor_file = os.path.join(output_folder, "anchor.mp4")
    image_list = samples.view(T * V, 3, H, W).unsqueeze(1) * 2 - 1
    save_video(anchor_file, image_list)

    return sv3d_file, anchor_vis_file, anchor_file


def sample_all(
    input_path: str = "inputs/test_video1.mp4",  # Can either be video file or folder with image files
    sv3d_path: str = "outputs/sv4d/000000_t000.mp4",
    anchor_path: str = "outputs/sv4d/000000_anchor.mp4",
    seed: Optional[int] = None,
    num_steps: int = 20,
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """
    output_folder = os.path.dirname(input_path)
    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    images_v0 = read_video(
        input_path,
        n_frames=n_frames,
        device=device,
    )

    images_t0 = read_video(
        sv3d_path,
        n_frames=n_views_sv3d,
        device=device,
    )

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v]
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # load interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1

    anchor_frames = read_video(
        anchor_path,
        n_frames=T * V,
        device=device,
    )
    anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = anchor_frames[i, j][None]

    # Dense sampling for the rest
    print(f"Sampling dense frames:")
    for t0 in np.arange(0, n_frames - 1, T - 1):  # [0, 4, 8, 12, 16]
        frame_indices = t0 + np.arange(T)
        print(f"Sampling dense frames {frame_indices}")
        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")

        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)

        # alternate between forward and backward conditioning
        forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
            frame_indices,
            img_matrix,
            v0,
            view_indices,
            model,
            version_dict,
            seed,
            polars,
            azims
        )

        for step in range(num_steps):
            if step % 2 == 1:
                c, uc, additional_model_inputs, sampler = forward_inputs
                frame_indices = forward_frame_indices
            else:
                c, uc, additional_model_inputs, sampler = backward_inputs
                frame_indices = backward_frame_indices
            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)

            samples = do_sample_per_step(
                model,
                sampler,
                noisy_latents,
                c,
                uc,
                step,
                additional_model_inputs,
            )
            samples = samples.view(T, V, C, H // F, W // F)
            for i, t in enumerate(frame_indices):
                for j, v in enumerate(view_indices):
                    latent_matrix[t, v] = samples[i, j]

        img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)

    # concat video
    grid_list = []
    for t in range(n_frames):
        imgs_view = torch.cat(img_matrix[t])
        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
    # save output videos
    vid_file = os.path.join(output_folder, "sv4d_final.mp4")
    save_video(vid_file, grid_list)

    return vid_file, seed


with gr.Blocks() as demo:
    gr.Markdown(
        """# Demo for SV4D from Stability AI ([model](https://huggingface.co/stabilityai/sv4d), [news](https://stability.ai/news/stable-video-4d))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv4d/blob/main/LICENSE.md)): generate 8 novel view videos from a single-view video (with white background).
#### It takes ~45s to generate anchor frames and another ~160s to generate full results (21 frames).
#### Hints for improving performance:
- Use a white background;
- Make the object in the center of the image;
- The SV4D process the first 21 frames of the uploaded video. Gradio provides a nice option of trimming the uploaded video if needed.
"""
    )
    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Upload your video")
            generate_btn = gr.Button("Step 1: generate 8 novel view videos (5 anchor frames each)")
            interpolate_btn = gr.Button("Step 2: Extend novel view videos to 21 frames")
        with gr.Column():
            anchor_video = gr.Video(label="SV4D outputs (anchor frames)")
            sv3d_video = gr.Video(label="SV3D outputs", interactive=False)
        with gr.Column():
            sv4d_interpolated_video = gr.Video(label="SV4D outputs (21 frames)")

    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            value=23,
            # randomize=True,
            minimum=0,
            maximum=100,
            step=1,
        )
        encoding_t = gr.Slider(
            label="Encode n frames at a time",
            info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.",
            value=8,
            minimum=1,
            maximum=40,
        )
        decoding_t = gr.Slider(
            label="Decode n frames at a time",
            info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
            value=4,
            minimum=1,
            maximum=14,
        )
        denoising_steps = gr.Slider(
            label="Number of denoising steps",
            info="Increase will improve the performance but needs more time.",
            value=20,
            minimum=10,
            maximum=50,
            step=1,
        )
        remove_bg = gr.Checkbox(
            label="Remove background",
            info="We use rembg. Users can check the alternative way: SAM2 (https://github.com/facebookresearch/segment-anything-2)",
        )

    input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False)

    with gr.Row(visible=False):
        anchor_frames = gr.Video()

    generate_btn.click(
        fn=sample_anchor,
        inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],
        outputs=[sv3d_video, anchor_video, anchor_frames],
        api_name="SV4D output (5 frames)",
    )

    interpolate_btn.click(
        fn=sample_all,
        inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps],
        outputs=[sv4d_interpolated_video, seed],
        api_name="SV4D interpolation (21 frames)",
    )

    examples = gr.Examples(
        fn=preprocess_video,
        examples=[
            "./assets/sv4d_videos/test_video1.mp4",
            "./assets/sv4d_videos/test_video2.mp4",
            "./assets/sv4d_videos/green_robot.mp4",
            "./assets/sv4d_videos/dolphin.mp4",
            "./assets/sv4d_videos/lucia_v000.mp4",
            "./assets/sv4d_videos/snowboard_v000.mp4",
            "./assets/sv4d_videos/stroller_v000.mp4",
            "./assets/sv4d_videos/human5.mp4",
            "./assets/sv4d_videos/bunnyman.mp4",
            "./assets/sv4d_videos/hiphop_parrot.mp4",
            "./assets/sv4d_videos/guppie_v0.mp4",
            "./assets/sv4d_videos/wave_hello.mp4",
            "./assets/sv4d_videos/pistol_v0.mp4",
            "./assets/sv4d_videos/human7.mp4",
            "./assets/sv4d_videos/monkey.mp4",
            "./assets/sv4d_videos/train_v0.mp4",
        ],
        inputs=[input_video],
        run_on_click=True,
        outputs=[input_video],
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)
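
The two-stage flow in the deleted demo above (sample_anchor, then sample_all) hinges on the frame-index bookkeeping with T = 5 and n_frames = 21; a small standalone check of those indices, using only numpy:

```python
import numpy as np

T, n_frames = 5, 21
anchor_frames = np.arange(T - 1, n_frames, T - 1)  # [ 4  8 12 16 20]
dense_chunks = [t0 + np.arange(T) for t0 in np.arange(0, n_frames - 1, T - 1)]
print(anchor_frames)
for chunk in dense_chunks:
    # every 5-frame chunk is bounded by frames already filled in img_matrix
    # (frame 0 comes from SV3D, the chunk end points are anchor frames)
    print(chunk)
```
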
Modified helper script (file name not shown in this view; it defines `gen_dynamic_loop` and `plot_3D`, which the new gradio scripts below import from `scripts.demo.sv3d_helpers`):

@@ -2,6 +2,7 @@ import os
 
 import matplotlib.pyplot as plt
 import numpy as np
+from PIL import Image
 
 
 def generate_dynamic_cycle_xy_values(
@@ -74,8 +75,9 @@ def gen_dynamic_loop(length=21, elev_deg=0):
     return np.roll(azim_rad, -1), np.roll(elev_rad, -1)
 
 
-def plot_3D(azim, polar, save_path, dynamic=True):
-    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+def plot_3D(azim, polar, save_path=None, dynamic=True):
+    if save_path is not None:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
     elev = np.deg2rad(90) - polar
     fig = plt.figure(figsize=(5, 5))
     ax = fig.add_subplot(projection="3d")
@@ -98,7 +100,20 @@ def plot_3D(azim, polar, save_path, dynamic=True):
         ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])
     ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
     ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
-    ax.view_init(elev=30, azim=-20, roll=0)
-    plt.savefig(save_path, bbox_inches="tight")
+    ax.view_init(elev=40, azim=-20, roll=0)
+    ax.xaxis.set_ticklabels([])
+    ax.yaxis.set_ticklabels([])
+    ax.zaxis.set_ticklabels([])
+    if save_path is None:
+        fig.canvas.draw()
+        lst = list(fig.canvas.get_width_height())
+        lst.append(3)
+        image = Image.fromarray(
+            np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(lst)
+        )
+    else:
+        plt.savefig(save_path, bbox_inches="tight")
     plt.clf()
     plt.close()
+    if save_path is None:
+        return image
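
With `save_path=None`, `plot_3D` now returns the camera-orbit visualization as a `PIL.Image` instead of writing it to disk, which is what lets the gradio scripts below feed it straight into a `gr.Image`. A minimal usage sketch; the module path follows the import in `sv3d_p_gradio.py`, and the orbit values mirror the static orbit defined there:

```python
import numpy as np
from scripts.demo.sv3d_helpers import plot_3D  # module path as imported by sv3d_p_gradio.py

num_frames = 21
azim = np.linspace(0, 2 * np.pi, num_frames + 1)[1:]    # one full static orbit
polar = np.array([np.deg2rad(90 - 10.0)] * num_frames)  # 10 degrees elevation

img = plot_3D(azim=azim, polar=polar, save_path=None, dynamic=False)  # returns a PIL.Image
img.save("orbit_preview.png")  # or hand it directly to a gr.Image component
```
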
340 additions: scripts/demo/sv3d_p_gradio.py (new file, @@ -0,0 +1,340 @@). Indentation below is restored from the code structure, since the compare view flattened it:

# Adding this at the very top of app.py to make 'generative-models' directory discoverable
import os
import sys

sys.path.append(os.path.dirname(__file__))

import random
from glob import glob
from pathlib import Path
from typing import List, Optional

import cv2
import gradio as gr
import imageio
import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import hf_hub_download
from PIL import Image
from rembg import remove
from scripts.demo.sv3d_helpers import gen_dynamic_loop, plot_3D
from scripts.sampling.simple_video_sample import (
    get_batch,
    get_unique_embedder_keys_from_conditioner,
    load_model,
)
from sgm.inference.helpers import embed_watermark
from torchvision.transforms import ToTensor

version = "sv3d_p"  # replace with 'sv3d_p' or 'sv3d_u' for other models

# Define the repo, local directory and filename
repo_id = "stabilityai/sv3d"
filename = f"{version}.safetensors"  # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = os.path.join(local_dir, filename)

# Check if the file already exists
if not os.path.exists(local_ckpt_path):
    # If the file doesn't exist, download it
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    print("File downloaded.")
else:
    print("File already exists. No need to download.")

device = "cuda"
max_64_bit_int = 2**63 - 1

num_frames = 21
num_steps = 50
model_config = f"scripts/sampling/configs/{version}.yaml"

model, filter = load_model(
    model_config,
    device,
    num_frames,
    num_steps,
)

polars_rad = np.array([np.deg2rad(90 - 10.0)] * num_frames)
azimuths_rad = np.linspace(0, 2 * np.pi, num_frames + 1)[1:]


def gen_orbit(orbit, elev_deg):
    if orbit == "dynamic":
        azim_rad, elev_rad = gen_dynamic_loop(length=num_frames, elev_deg=elev_deg)
        polars_rad = np.deg2rad(90) - elev_rad
        azimuths_rad = azim_rad
    else:
        polars_rad = np.array([np.deg2rad(90 - elev_deg)] * num_frames)
        azimuths_rad = np.linspace(0, 2 * np.pi, num_frames + 1)[1:]

    plot = plot_3D(
        azim=azimuths_rad,
        polar=polars_rad,
        save_path=None,
        dynamic=(orbit == "dynamic"),
    )
    return plot


def sample(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    seed: Optional[int] = None,
    randomize_seed: bool = True,
    decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: str = None,
    image_frame_ratio: Optional[float] = None,
):
    """
    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """
    if randomize_seed:
        seed = random.randint(0, max_64_bit_int)

    torch.manual_seed(seed)

    path = Path(input_path)
    all_img_paths = []
    if path.is_file():
        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
            all_img_paths = [input_path]
        else:
            raise ValueError("Path is not valid image file.")
    elif path.is_dir():
        all_img_paths = sorted(
            [
                f
                for f in path.iterdir()
                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
            ]
        )
        if len(all_img_paths) == 0:
            raise ValueError("Folder does not contain any images.")
    else:
        raise ValueError

    for input_img_path in all_img_paths:

        image = Image.open(input_img_path)
        if image.mode == "RGBA":
            pass
        else:
            # remove bg
            image.thumbnail([768, 768], Image.Resampling.LANCZOS)
            image = remove(image.convert("RGBA"), alpha_matting=True)

        # resize object in frame
        image_arr = np.array(image)
        in_w, in_h = image_arr.shape[:2]
        ret, mask = cv2.threshold(
            np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
        )
        x, y, w, h = cv2.boundingRect(mask)
        max_size = max(w, h)
        side_len = (
            int(max_size / image_frame_ratio) if image_frame_ratio is not None else in_w
        )
        padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
        center = side_len // 2
        padded_image[
            center - h // 2 : center - h // 2 + h,
            center - w // 2 : center - w // 2 + w,
        ] = image_arr[y : y + h, x : x + w]
        # resize frame to 576x576
        rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
        # white bg
        rgba_arr = np.array(rgba) / 255.0
        rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
        input_image = Image.fromarray((rgb * 255).astype(np.uint8))

        image = ToTensor()(input_image)
        image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        if (H, W) != (576, 576) and "sv3d" in version:
            print(
                "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
            )

        cond_aug = 1e-5

        value_dict = {}
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = image
        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
        value_dict["cond_aug"] = cond_aug

        value_dict["polars_rad"] = polars_rad
        value_dict["azimuths_rad"] = azimuths_rad

        output_folder = output_folder or f"outputs/gradio/{version}"
        cond_aug = 1e-5

        with torch.no_grad():
            with torch.autocast(device):
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [1, num_frames],
                    T=num_frames,
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )

                for k in ["crossattn", "concat"]:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

                randn = torch.randn(shape, device=device)

                additional_model_inputs = {}
                additional_model_inputs["image_only_indicator"] = torch.zeros(
                    2, num_frames
                ).to(device)
                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                model.en_and_decode_n_samples_a_time = decoding_t
                samples_x = model.decode_first_stage(samples_z)
                samples_x[-1:] = value_dict["cond_frames_without_noise"]
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                os.makedirs(output_folder, exist_ok=True)
                base_count = len(glob(os.path.join(output_folder, "*.mp4")))

                imageio.imwrite(
                    os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
                )

                samples = embed_watermark(samples)
                samples = filter(samples)
                vid = (
                    (rearrange(samples, "t c h w -> t h w c") * 255)
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                )
                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
                imageio.mimwrite(video_path, vid)

        return video_path, seed


def resize_image(image_path, output_size=(576, 576)):
    image = Image.open(image_path)
    # Calculate aspect ratios
    target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
    image_aspect = image.width / image.height  # Aspect ratio of the original image

    # Resize then crop if the original image is larger
    if image_aspect > target_aspect:
        # Resize the image to match the target height, maintaining aspect ratio
        new_height = output_size[1]
        new_width = int(new_height * image_aspect)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
        # Calculate coordinates for cropping
        left = (new_width - output_size[0]) / 2
        top = 0
        right = (new_width + output_size[0]) / 2
        bottom = output_size[1]
    else:
        # Resize the image to match the target width, maintaining aspect ratio
        new_width = output_size[0]
        new_height = int(new_width / image_aspect)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
        # Calculate coordinates for cropping
        left = 0
        top = (new_height - output_size[1]) / 2
        right = output_size[0]
        bottom = (new_height + output_size[1]) / 2

    # Crop the image
    cropped_image = resized_image.crop((left, top, right, bottom))

    return cropped_image


with gr.Blocks() as demo:
    gr.Markdown(
        """# Demo for SV3D_p from Stability AI ([model](https://huggingface.co/stabilityai/sv3d), [news](https://stability.ai/news/introducing-stable-video-3d))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv3d/blob/main/LICENSE)): generate 21 frames orbital video from a single image, at variable elevation and azimuth.
Generation takes ~40s (for 50 steps) in an A100.
"""
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="filepath")
            generate_btn = gr.Button("Generate")
        video = gr.Video()
    with gr.Row():
        with gr.Column():
            orbit = gr.Dropdown(
                ["same elevation", "dynamic"],
                label="Orbit",
                info="Choose with orbit to generate",
            )
            elev_deg = gr.Slider(
                label="Elevation (in degrees)",
                info="Elevation of the camera in the conditioning image, in degrees.",
                value=10.0,
                minimum=-10,
                maximum=30,
            )
        plot_image = gr.Image()
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            value=23,
            randomize=True,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        decoding_t = gr.Slider(
            label="Decode n frames at a time",
            info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
            value=7,
            minimum=1,
            maximum=14,
        )

    image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)

    orbit.change(gen_orbit, [orbit, elev_deg], plot_image)
    elev_deg.change(gen_orbit, [orbit, elev_deg], plot_image)
    # seed.change(gen_orbit, [orbit, elev_deg], plot_image)

    generate_btn.click(
        fn=sample,
        inputs=[image, seed, randomize_seed, decoding_t],
        outputs=[video, seed],
        api_name="video",
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)
295 additions: scripts/demo/sv3d_u_gradio.py (new file, @@ -0,0 +1,295 @@). Indentation below is restored from the code structure, since the compare view flattened it:

# Adding this at the very top of app.py to make 'generative-models' directory discoverable
import os
import sys

sys.path.append(os.path.dirname(__file__))

import random
from glob import glob
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import imageio
import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import hf_hub_download
from PIL import Image
from rembg import remove
from scripts.sampling.simple_video_sample import (
    get_batch,
    get_unique_embedder_keys_from_conditioner,
    load_model,
)
from sgm.inference.helpers import embed_watermark
from torchvision.transforms import ToTensor

version = "sv3d_u"  # replace with 'sv3d_p' or 'sv3d_u' for other models

# Define the repo, local directory and filename
repo_id = "stabilityai/sv3d"
filename = f"{version}.safetensors"  # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = os.path.join(local_dir, filename)

# Check if the file already exists
if not os.path.exists(local_ckpt_path):
    # If the file doesn't exist, download it
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    print("File downloaded.")
else:
    print("File already exists. No need to download.")

device = "cuda"
max_64_bit_int = 2**63 - 1

num_frames = 21
num_steps = 50
model_config = f"scripts/sampling/configs/{version}.yaml"

model, filter = load_model(
    model_config,
    device,
    num_frames,
    num_steps,
)


def sample(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    seed: Optional[int] = None,
    randomize_seed: bool = True,
    decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: str = None,
    image_frame_ratio: Optional[float] = None,
):
    """
    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """
    if randomize_seed:
        seed = random.randint(0, max_64_bit_int)

    torch.manual_seed(seed)

    path = Path(input_path)
    all_img_paths = []
    if path.is_file():
        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
            all_img_paths = [input_path]
        else:
            raise ValueError("Path is not valid image file.")
    elif path.is_dir():
        all_img_paths = sorted(
            [
                f
                for f in path.iterdir()
                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
            ]
        )
        if len(all_img_paths) == 0:
            raise ValueError("Folder does not contain any images.")
    else:
        raise ValueError

    for input_img_path in all_img_paths:

        image = Image.open(input_img_path)
        if image.mode == "RGBA":
            pass
        else:
            # remove bg
            image.thumbnail([768, 768], Image.Resampling.LANCZOS)
            image = remove(image.convert("RGBA"), alpha_matting=True)

        # resize object in frame
        image_arr = np.array(image)
        in_w, in_h = image_arr.shape[:2]
        ret, mask = cv2.threshold(
            np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
        )
        x, y, w, h = cv2.boundingRect(mask)
        max_size = max(w, h)
        side_len = (
            int(max_size / image_frame_ratio) if image_frame_ratio is not None else in_w
        )
        padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
        center = side_len // 2
        padded_image[
            center - h // 2 : center - h // 2 + h,
            center - w // 2 : center - w // 2 + w,
        ] = image_arr[y : y + h, x : x + w]
        # resize frame to 576x576
        rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
        # white bg
        rgba_arr = np.array(rgba) / 255.0
        rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
        input_image = Image.fromarray((rgb * 255).astype(np.uint8))

        image = ToTensor()(input_image)
        image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        if (H, W) != (576, 576) and "sv3d" in version:
            print(
                "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
            )

        cond_aug = 1e-5

        value_dict = {}
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = image
        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
        value_dict["cond_aug"] = cond_aug

        output_folder = output_folder or f"outputs/gradio/{version}"
        cond_aug = 1e-5

        with torch.no_grad():
            with torch.autocast(device):
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [1, num_frames],
                    T=num_frames,
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )

                for k in ["crossattn", "concat"]:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

                randn = torch.randn(shape, device=device)

                additional_model_inputs = {}
                additional_model_inputs["image_only_indicator"] = torch.zeros(
                    2, num_frames
                ).to(device)
                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                model.en_and_decode_n_samples_a_time = decoding_t
                samples_x = model.decode_first_stage(samples_z)
                samples_x[-1:] = value_dict["cond_frames_without_noise"]
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                os.makedirs(output_folder, exist_ok=True)
                base_count = len(glob(os.path.join(output_folder, "*.mp4")))

                imageio.imwrite(
                    os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
                )

                samples = embed_watermark(samples)
                samples = filter(samples)
                vid = (
                    (rearrange(samples, "t c h w -> t h w c") * 255)
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                )
                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
                imageio.mimwrite(video_path, vid)

        return video_path, seed


def resize_image(image_path, output_size=(576, 576)):
    image = Image.open(image_path)
    # Calculate aspect ratios
    target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
    image_aspect = image.width / image.height  # Aspect ratio of the original image

    # Resize then crop if the original image is larger
    if image_aspect > target_aspect:
        # Resize the image to match the target height, maintaining aspect ratio
        new_height = output_size[1]
        new_width = int(new_height * image_aspect)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
        # Calculate coordinates for cropping
        left = (new_width - output_size[0]) / 2
        top = 0
        right = (new_width + output_size[0]) / 2
        bottom = output_size[1]
    else:
        # Resize the image to match the target width, maintaining aspect ratio
        new_width = output_size[0]
        new_height = int(new_width / image_aspect)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
        # Calculate coordinates for cropping
        left = 0
        top = (new_height - output_size[1]) / 2
        right = output_size[0]
        bottom = (new_height + output_size[1]) / 2

    # Crop the image
    cropped_image = resized_image.crop((left, top, right, bottom))

    return cropped_image


with gr.Blocks() as demo:
    gr.Markdown(
        """# Demo for SV3D_u from Stability AI ([model](https://huggingface.co/stabilityai/sv3d), [news](https://stability.ai/news/introducing-stable-video-3d))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv3d/blob/main/LICENSE)): generate 21 frames orbital video from a single image, at the same elevation.
Generation takes ~40s (for 50 steps) in an A100.
"""
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="filepath")
            generate_btn = gr.Button("Generate")
        video = gr.Video()
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            value=23,
            randomize=True,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        decoding_t = gr.Slider(
            label="Decode n frames at a time",
            info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",

(the compare view is truncated here)
|
||||||
|
value=7,
|
||||||
|
minimum=1,
|
||||||
|
maximum=14,
|
||||||
|
)
|
||||||
|
|
||||||
|
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
|
||||||
|
generate_btn.click(
|
||||||
|
fn=sample,
|
||||||
|
inputs=[image, seed, randomize_seed, decoding_t],
|
||||||
|
outputs=[video, seed],
|
||||||
|
api_name="video",
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo.queue(max_size=20)
|
||||||
|
demo.launch(share=True)
|
||||||
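Since the demo registers its click handler with `api_name="video"`, the endpoint can also be called programmatically. The following is a minimal sketch, assuming the demo above is running locally on Gradio's default port and that the `gradio_client` package is installed; the URL and the input image path are placeholders.

```python
# Minimal sketch: call the "/video" endpoint exposed by generate_btn.click(..., api_name="video").
# The URL and the input image path below are placeholders, not part of the diff above.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
video_path, used_seed = client.predict(
    "my_object.png",  # image (filepath)
    23,               # seed
    True,             # randomize_seed
    7,                # decoding_t
    api_name="/video",
)
print(video_path, used_seed)
```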
@@ -1,210 +0,0 @@
N_TIME: 4
N_VIEW: 12
N_FRAMES: 48

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 8
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sp4d.safetensors
    dual_concat: True
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.DualSpatialUNetWithCrossComm
      params:
        unet_config:
          adm_in_channels: 1280
          attention_resolutions: [4, 2, 1]
          channel_mult: [1, 2, 4, 4]
          context_dim: 1024
          motion_context_dim: 4
          extra_ff_mix_layer: True
          in_channels: 8
          legacy: False
          model_channels: 320
          num_classes: sequential
          num_head_channels: 64
          num_res_blocks: 2
          out_channels: 4
          replicate_time_mix_bug: True
          spatial_transformer_attn_type: softmax-xformers
          time_block_merge_factor: 0.0
          time_block_merge_strategy: learned_with_images
          time_kernel_size: [3, 1, 1]
          time_mix_legacy: False
          transformer_depth: 1
          use_checkpoint: False
          use_linear_in_transformer: True
          use_spatial_context: True
          use_spatial_transformer: True
          separate_motion_merge_factor: True
          use_motion_attention: True
          use_3d_attention: True
          use_camera_emb: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.DecoderDual
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
          params:
            max_scale: 1.5
            min_scale: 1.5
            num_frames: ${N_FRAMES}
            num_views: ${N_VIEW}
            additional_cond_keys: [ cond_view, cond_motion ]
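As a side note, configs in this format are typically consumed through OmegaConf, whose variable interpolation resolves the `${N_TIME}`, `${N_VIEW}` and `${N_FRAMES}` references at load time. A minimal sketch, assuming the config above is saved under `scripts/sampling/configs/` and the `sgm` package from this repo is importable:

```python
# Minimal sketch, assuming the config above is saved at this path and `sgm` is importable.
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/sampling/configs/sp4d.yaml")
OmegaConf.resolve(config)   # resolves ${N_TIME}, ${N_VIEW}, ${N_FRAMES}
print(config.N_FRAMES)      # 48 = 4 video frames x 12 camera views
model = instantiate_from_config(config.model)
```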
@@ -1,203 +0,0 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 7
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv4d.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
      params:
        adm_in_channels: 1280
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        context_dim: 1024
        motion_context_dim: 4
        extra_ff_mix_layer: True
        in_channels: 8
        legacy: False
        model_channels: 320
        num_classes: sequential
        num_head_channels: 64
        num_res_blocks: 2
        out_channels: 4
        replicate_time_mix_bug: True
        spatial_transformer_attn_type: softmax-xformers
        time_block_merge_factor: 0.0
        time_block_merge_strategy: learned_with_images
        time_kernel_size: [3, 1, 1]
        time_mix_legacy: False
        transformer_depth: 1
        use_checkpoint: False
        use_linear_in_transformer: True
        use_spatial_context: True
        use_spatial_transformer: True
        use_motion_attention: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 2.5
            num_frames: ${N_FRAMES}
            additional_cond_keys: [ cond_view, cond_motion ]
@@ -1,208 +0,0 @@
N_TIME: 12
N_VIEW: 4
N_FRAMES: 48

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 8
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv4d2.safetensors
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
      params:
        adm_in_channels: 1280
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        context_dim: 1024
        motion_context_dim: 4
        extra_ff_mix_layer: True
        in_channels: 8
        legacy: False
        model_channels: 320
        num_classes: sequential
        num_head_channels: 64
        num_res_blocks: 2
        out_channels: 4
        replicate_time_mix_bug: True
        spatial_transformer_attn_type: softmax-xformers
        time_block_merge_factor: 0.0
        time_block_merge_strategy: learned_with_images
        time_kernel_size: [3, 1, 1]
        time_mix_legacy: False
        transformer_depth: 1
        use_checkpoint: False
        use_linear_in_transformer: True
        use_spatial_context: True
        use_spatial_transformer: True
        separate_motion_merge_factor: True
        use_motion_attention: True
        use_3d_attention: True
        use_camera_emb: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
          params:
            max_scale: 1.5
            min_scale: 1.5
            num_frames: ${N_FRAMES}
            num_views: ${N_VIEW}
            additional_cond_keys: [ cond_view, cond_motion ]
@@ -1,208 +0,0 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 8
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv4d2_8views.safetensors
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
      params:
        adm_in_channels: 1280
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        context_dim: 1024
        motion_context_dim: 4
        extra_ff_mix_layer: True
        in_channels: 8
        legacy: False
        model_channels: 320
        num_classes: sequential
        num_head_channels: 64
        num_res_blocks: 2
        out_channels: 4
        replicate_time_mix_bug: True
        spatial_transformer_attn_type: softmax-xformers
        time_block_merge_factor: 0.0
        time_block_merge_strategy: learned_with_images
        time_kernel_size: [3, 1, 1]
        time_mix_legacy: False
        transformer_depth: 1
        use_checkpoint: False
        use_linear_in_transformer: True
        use_spatial_context: True
        use_spatial_transformer: True
        separate_motion_merge_factor: True
        use_motion_attention: True
        use_3d_attention: False
        use_camera_emb: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
          params:
            max_scale: 2.0
            min_scale: 1.5
            num_frames: ${N_FRAMES}
            num_views: ${N_VIEW}
            additional_cond_keys: [ cond_view, cond_motion ]
@@ -100,7 +100,7 @@ def sample(
         device,
         num_frames,
         num_steps,
-        verbose,
+        verbose=verbose,
     )
     torch.manual_seed(seed)
 
@@ -163,7 +163,7 @@ def sample(
         else:
             with Image.open(input_img_path) as image:
                 if image.mode == "RGBA":
-                    image = image.convert("RGB")
+                    input_image = image.convert("RGB")
                 w, h = image.size
 
             if h % 64 != 0 or w % 64 != 0:
@@ -172,7 +172,6 @@ def sample(
                 print(
                     f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                 )
-                input_image = np.array(image)
 
         image = ToTensor()(input_image)
         image = image * 2.0 - 1.0
@@ -1,259 +0,0 @@
import os
import sys
from glob import glob
from typing import List, Optional, Union

from tqdm import tqdm

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import numpy as np
import torch
from fire import Fire

from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
    decode_latents,
    load_model,
    initial_model_load,
    read_video,
    run_img2vid,
    prepare_sampling,
    prepare_inputs,
    do_sample_per_step,
    sample_sv3d,
    save_video,
    preprocess_video,
)


def sample(
    input_path: str = "assets/sv4d_videos/test_video1.mp4",  # Can either be image file or folder with image files
    output_folder: Optional[str] = "outputs/sv4d",
    num_steps: Optional[int] = 20,
    sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
    img_size: int = 576,  # image resolution
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 1e-5,
    seed: int = 23,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
    image_frame_ratio: Optional[float] = 0.917,
    verbose: Optional[bool] = False,
    remove_bg: bool = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
    """
    # Set model config
    T = 5  # number of frames per sample
    V = 8  # number of views per sample
    F = 8  # vae factor to downsize image->latent
    C = 4
    H, W = img_size, img_size
    n_frames = 21  # number of input and output video frames
    n_views = V + 1  # number of output video views (1 input view + 8 novel views)
    n_views_sv3d = 21
    subsampled_views = np.array(
        [0, 2, 5, 7, 9, 12, 14, 16, 19]
    )  # subsample (V+1=)9 (uniform) views from 21 SV3D views

    model_config = "scripts/sampling/configs/sv4d.yaml"
    version_dict = {
        "T": T * V,
        "H": H,
        "W": W,
        "C": C,
        "f": F,
        "options": {
            "discretization": 1,
            "cfg": 2.0,
            "num_views": V,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 5,
            "num_steps": num_steps,
            "force_uc_zero_embeddings": [
                "cond_frames",
                "cond_frames_without_noise",
                "cond_view",
                "cond_motion",
            ],
            "additional_guider_kwargs": {
                "additional_cond_keys": ["cond_view", "cond_motion"]
            },
        },
    }

    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
    processed_input_path = preprocess_video(
        input_path,
        remove_bg=remove_bg,
        n_frames=n_frames,
        W=W,
        H=H,
        output_folder=output_folder,
        image_frame_ratio=image_frame_ratio,
        base_count=base_count,
    )
    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Sample multi-view images of the first frame using SV3D i.e. images at time 0
    images_t0 = sample_sv3d(
        images_v0[0],
        n_views_sv3d,
        num_steps,
        sv3d_version,
        fps_id,
        motion_bucket_id,
        cond_aug,
        decoding_t,
        device,
        polars_rad,
        azimuths_rad,
        verbose,
    )
    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    save_video(
        os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
        img_matrix[0],
    )
    # save_video(
    #     os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
    #     [img_matrix[t][0] for t in range(n_frames)],
    # )

    # Load SV4D model
    model, filter = load_model(
        model_config,
        device,
        version_dict["T"],
        num_steps,
        verbose,
    )
    model = initial_model_load(model)
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t

    # Interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1
    print(f"Sampling anchor frames {frame_indices}")
    image = img_matrix[t0][v0]
    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
    samples = run_img2vid(
        version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
    )
    samples = samples.view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # Dense sampling for the rest
    print(f"Sampling dense frames:")
    for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)):  # [0, 4, 8, 12, 16]
        frame_indices = t0 + np.arange(T)
        print(f"Sampling dense frames {frame_indices}")
        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")

        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)

        # alternate between forward and backward conditioning
        forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
            frame_indices,
            img_matrix,
            v0,
            view_indices,
            model,
            version_dict,
            seed,
            polars,
            azims
        )

        for step in tqdm(range(num_steps)):
            if step % 2 == 1:
                c, uc, additional_model_inputs, sampler = forward_inputs
                frame_indices = forward_frame_indices
            else:
                c, uc, additional_model_inputs, sampler = backward_inputs
                frame_indices = backward_frame_indices
            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)

            samples = do_sample_per_step(
                model,
                sampler,
                noisy_latents,
                c,
                uc,
                step,
                additional_model_inputs,
            )
            samples = samples.view(T, V, C, H // F, W // F)
            for i, t in enumerate(frame_indices):
                for j, v in enumerate(view_indices):
                    latent_matrix[t, v] = samples[i, j]

        img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)

    # Save output videos
    for v in view_indices:
        vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
        print(f"Saving {vid_file}")
        save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)])

    # Save diagonal video
    diag_frames = [
        img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames)
    ]
    vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4")
    print(f"Saving {vid_file}")
    save_video(vid_file, diag_frames)


if __name__ == "__main__":
    Fire(sample)
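The anchor/dense scheduling in the deleted script above hinges on the index arithmetic around `np.arange(T - 1, n_frames, T - 1)`. A quick standalone check of those schedules for the defaults used there (T=5, n_frames=21), just to make the chunking concrete:

```python
# Quick check of the anchor/dense frame schedules used above (T=5, n_frames=21).
import numpy as np

T, n_frames = 5, 21
print(np.arange(T - 1, n_frames, T - 1))   # anchor frames: [ 4  8 12 16 20]
print(np.arange(0, n_frames - 1, T - 1))   # dense chunk starts: [ 0  4  8 12 16]
print(0 + np.arange(T))                    # first dense chunk covers frames 0..4
```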
@@ -1,235 +0,0 @@
import os
import sys
from glob import glob
from typing import List, Optional

from tqdm import tqdm

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import numpy as np
import torch
from fire import Fire
from scripts.demo.sv4d_helpers import (
    load_model,
    preprocess_video,
    read_video,
    run_img2vid,
    save_video,
)
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder

sv4d2_configs = {
    "sv4d2": {
        "T": 12,  # number of frames per sample
        "V": 4,  # number of views per sample
        "model_config": "scripts/sampling/configs/sv4d2.yaml",
        "version_dict": {
            "T": 12 * 4,
            "options": {
                "discretization": 1,
                "cfg": 2.0,
                "min_cfg": 2.0,
                "num_views": 4,
                "sigma_min": 0.002,
                "sigma_max": 700.0,
                "rho": 7.0,
                "guider": 2,
                "force_uc_zero_embeddings": [
                    "cond_frames",
                    "cond_frames_without_noise",
                    "cond_view",
                    "cond_motion",
                ],
                "additional_guider_kwargs": {
                    "additional_cond_keys": ["cond_view", "cond_motion"]
                },
            },
        },
    },
    "sv4d2_8views": {
        "T": 5,  # number of frames per sample
        "V": 8,  # number of views per sample
        "model_config": "scripts/sampling/configs/sv4d2_8views.yaml",
        "version_dict": {
            "T": 5 * 8,
            "options": {
                "discretization": 1,
                "cfg": 2.5,
                "min_cfg": 1.5,
                "num_views": 8,
                "sigma_min": 0.002,
                "sigma_max": 700.0,
                "rho": 7.0,
                "guider": 5,
                "force_uc_zero_embeddings": [
                    "cond_frames",
                    "cond_frames_without_noise",
                    "cond_view",
                    "cond_motion",
                ],
                "additional_guider_kwargs": {
                    "additional_cond_keys": ["cond_view", "cond_motion"]
                },
            },
        },
    },
}


def sample(
    input_path: str = "assets/sv4d_videos/camel.gif",  # Can either be image file or folder with image files
    model_path: Optional[str] = "checkpoints/sv4d2.safetensors",
    output_folder: Optional[str] = "outputs",
    num_steps: Optional[int] = 50,
    img_size: int = 576,  # image resolution
    n_frames: int = 21,  # number of input and output video frames
    seed: int = 23,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    elevations_deg: Optional[List[float]] = 0.0,
    azimuths_deg: Optional[List[float]] = None,
    image_frame_ratio: Optional[float] = 0.9,
    verbose: Optional[bool] = False,
    remove_bg: bool = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
    """
    # Set model config
    assert os.path.basename(model_path) in [
        "sv4d2.safetensors",
        "sv4d2_8views.safetensors",
    ]
    sv4d2_model = os.path.splitext(os.path.basename(model_path))[0]
    config = sv4d2_configs[sv4d2_model]
    print(sv4d2_model, config)
    T = config["T"]
    V = config["V"]
    model_config = config["model_config"]
    version_dict = config["version_dict"]
    F = 8  # vae factor to downsize image->latent
    C = 4
    H, W = img_size, img_size
    n_views = V + 1  # number of output video views (1 input view + 8 novel views)
    subsampled_views = np.arange(n_views)
    version_dict["H"] = H
    version_dict["W"] = W
    version_dict["C"] = C
    version_dict["f"] = F
    version_dict["options"]["num_steps"] = num_steps

    torch.manual_seed(seed)
    output_folder = os.path.join(output_folder, sv4d2_model)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // n_views
    processed_input_path = preprocess_video(
        input_path,
        remove_bg=remove_bg,
        n_frames=n_frames,
        W=W,
        H=H,
        output_folder=output_folder,
        image_frame_ratio=image_frame_ratio,
        base_count=base_count,
    )
    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
    images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views
    assert (
        len(elevations_deg) == n_views
    ), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        # azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
        azimuths_deg = (
            np.array([0, 60, 120, 180, 240])
            if sv4d2_model == "sv4d2"
            else np.array([0, 30, 75, 120, 165, 210, 255, 300, 330])
        )
    assert (
        len(azimuths_deg) == n_views
    ), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # Load SV4D++ model
    model, _ = load_model(
        model_config,
        device,
        version_dict["T"],
        num_steps,
        verbose,
        model_path,
    )
    model.en_and_decode_n_samples_a_time = decoding_t
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t

    # Sampling novel-view videos
    v0 = 0
    view_indices = np.arange(V) + 1
    t0_list = (
        range(0, n_frames, T-1)
        if sv4d2_model == "sv4d2"
        else range(0, n_frames - T + 1, T - 1)
    )
    for t0 in tqdm(t0_list):
        if t0 + T > n_frames:
            t0 = n_frames - T
        frame_indices = t0 + np.arange(T)
        print(f"Sampling frames {frame_indices}")
        image = img_matrix[t0][v0]
        cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
        cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
        cond_mv = False if t0 == 0 else True
        samples = run_img2vid(
            version_dict,
            model,
            image,
            seed,
            polars,
            azims,
            cond_motion,
            cond_view,
            decoding_t,
            cond_mv=cond_mv,
        )
        samples = samples.view(T, V, 3, H, W)

        for i, t in enumerate(frame_indices):
            for j, v in enumerate(view_indices):
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # Save output videos
    for v in view_indices:
        vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
        print(f"Saving {vid_file}")
        save_video(
            vid_file,
            [img_matrix[t][v] for t in range(n_frames) if img_matrix[t][v] is not None],
        )


if __name__ == "__main__":
    Fire(sample)
@@ -1,198 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
from glob import glob
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from fire import Fire
|
|
||||||
from scripts.demo.sv4d_helpers import (
|
|
||||||
load_model,
|
|
||||||
preprocess_video,
|
|
||||||
read_video,
|
|
||||||
run_img2vid,
|
|
||||||
save_video,
|
|
||||||
)
|
|
||||||
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
|
||||||
|
|
||||||
sp4d_configs = {
|
|
||||||
"sp4d": {
|
|
||||||
"T": 4, # number of frames per sample
        "V": 12,  # number of views per sample
        "model_config": "scripts/sampling/configs/sp4d.yaml",
        "version_dict": {
            "T": 48,
            "options": {
                "discretization": 1,
                "cfg": 3.0,
                "min_cfg": 1.5,
                "num_views": 12,
                "sigma_min": 0.002,
                "sigma_max": 700.0,
                "rho": 7.0,
                "guider": 2,
                "force_uc_zero_embeddings": [
                    "cond_frames",
                    "cond_frames_without_noise",
                    "cond_view",
                    "cond_motion",
                ],
                "additional_guider_kwargs": {
                    "additional_cond_keys": ["cond_view", "cond_motion"]
                },
            },
        },
    },
}


def sample(
    input_path: str = "assets/sv4d_videos/camel.gif",  # Can either be image file or folder with image files
    model_path: Optional[str] = "checkpoints/sp4d.safetensors",
    output_folder: Optional[str] = "outputs",
    num_steps: Optional[int] = 50,
    img_size: int = 512,  # image resolution
    n_frames: int = 4,  # number of input and output video frames
    seed: int = 23,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    elevations_deg: Optional[List[float]] = 0.0,
    azimuths_deg: Optional[List[float]] = None,
    image_frame_ratio: Optional[float] = 0.9,
    verbose: Optional[bool] = False,
    remove_bg: bool = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
    """
    # Set model config
    assert os.path.basename(model_path) in [
        "sp4d.safetensors",
    ]
    sp4d_model = os.path.splitext(os.path.basename(model_path))[0]
    config = sp4d_configs[sp4d_model]
    print(sp4d_model, config)
    T = config["T"]
    V = config["V"]
    model_config = config["model_config"]
    version_dict = config["version_dict"]
    F = 8  # vae factor to downsize image->latent
    C = 4
    H, W = img_size, img_size
    n_views = V + 1  # number of output video views (1 input view + V novel views)
    subsampled_views = np.arange(n_views)
    version_dict["H"] = H
    version_dict["W"] = W
    version_dict["C"] = C
    version_dict["f"] = F
    version_dict["options"]["num_steps"] = num_steps

    torch.manual_seed(seed)
    output_folder = os.path.join(output_folder, sp4d_model)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames, i.e. images at view 0
    print(f"Reading {input_path}")
    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // (n_frames + 1)
    processed_input_path = preprocess_video(
        input_path,
        remove_bg=remove_bg,
        n_frames=n_frames,
        W=W,
        H=H,
        output_folder=output_folder,
        image_frame_ratio=image_frame_ratio,
        base_count=base_count,
    )
    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
    images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views
    assert (
        len(elevations_deg) == n_views
    ), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views
    ), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # Load SP4D model
    model, _ = load_model(
        model_config,
        device,
        version_dict["T"],
        num_steps,
        verbose,
        model_path,
    )
    model.en_and_decode_n_samples_a_time = decoding_t
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t

    # Sampling novel-view videos
    v0 = 0
    view_indices = np.arange(V) + 1
    t0_list = range(0, n_frames - T + 1, T - 1)
    for t0 in tqdm(t0_list):
        if t0 + T > n_frames:
            t0 = n_frames - T
        frame_indices = t0 + np.arange(T)
        print(f"Sampling frames {frame_indices}")
        image = img_matrix[t0][v0]
        cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
        cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
        samples = run_img2vid(
            version_dict,
            model,
            image,
            seed,
            polars,
            azims,
            cond_motion,
            cond_view,
            decoding_t,
            cond_mv=False,
            part_maps=True,
        )
        samples = samples.view(T, V, 3, H, -1)

        for i, t in enumerate(frame_indices):
            for j, v in enumerate(view_indices):
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

        # Save output videos
        for t in frame_indices:
            vid_file = os.path.join(output_folder, f"{base_count:06d}_{t:03d}.mp4")
            print(f"Saving {vid_file}")
            save_video(
                vid_file,
                [img_matrix[t][v] for v in range(1, n_views) if img_matrix[t][v] is not None],
            )


if __name__ == "__main__":
    Fire(sample)
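The removed sampling script above converts the elevation/azimuth lists into the polar and azimuth angles that condition the model, then re-centres them on the conditioning view before sampling. The following is a minimal sketch of that angle convention only, using hypothetical values (13 views, zero elevation) and omitting the repeat over the T time frames done in the loop above:

```python
import numpy as np

# Hypothetical setup mirroring the script above: 13 views (1 input + 12 novel), zero elevation.
n_views = 13
elevations_deg = [0.0] * n_views
azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360

# Polar angle is measured from the vertical axis, so elevation e maps to (90 - e) degrees.
polars_rad = np.deg2rad(90 - np.array(elevations_deg))
# Azimuths are shifted so the last (reference) view sits at 0 before converting to radians.
azimuths_rad = np.deg2rad((azimuths_deg - azimuths_deg[-1]) % 360)

# The sampler then re-centres both angles on the conditioning view v0 = 0.
v0 = 0
polars = (polars_rad[1:] - polars_rad[v0] + np.pi / 2) % (2 * np.pi)
azims = (azimuths_rad[1:] - azimuths_rad[v0]) % (2 * np.pi)
print(polars.shape, azims.shape)  # (12,) (12,)
```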
@@ -38,7 +38,6 @@ class DiffusionEngine(pl.LightningModule):
|
|||||||
no_cond_log: bool = False,
|
no_cond_log: bool = False,
|
||||||
compile_model: bool = False,
|
compile_model: bool = False,
|
||||||
en_and_decode_n_samples_a_time: Optional[int] = None,
|
en_and_decode_n_samples_a_time: Optional[int] = None,
|
||||||
dual_concat: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.log_keys = log_keys
|
self.log_keys = log_keys
|
||||||
@@ -48,7 +47,7 @@ class DiffusionEngine(pl.LightningModule):
|
|||||||
)
|
)
|
||||||
model = instantiate_from_config(network_config)
|
model = instantiate_from_config(network_config)
|
||||||
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
||||||
model, compile_model=compile_model, dual_concat=dual_concat
|
model, compile_model=compile_model
|
||||||
)
|
)
|
||||||
|
|
||||||
self.denoiser = instantiate_from_config(denoiser_config)
|
self.denoiser = instantiate_from_config(denoiser_config)
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class LinearPredictionGuider(Guider):
|
|||||||
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
||||||
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
||||||
else:
|
else:
|
||||||
# assert c[k] == uc[k]
|
assert c[k] == uc[k]
|
||||||
c_out[k] = c[k]
|
c_out[k] = c[k]
|
||||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||||
|
|
||||||
@@ -105,7 +105,7 @@ class TrianglePredictionGuider(LinearPredictionGuider):
|
|||||||
max_scale: float,
|
max_scale: float,
|
||||||
num_frames: int,
|
num_frames: int,
|
||||||
min_scale: float = 1.0,
|
min_scale: float = 1.0,
|
||||||
period: Union[float, List[float]] = 1.0,
|
period: float | List[float] = 1.0,
|
||||||
period_fusing: Literal["mean", "multiply", "max"] = "max",
|
period_fusing: Literal["mean", "multiply", "max"] = "max",
|
||||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||||
):
|
):
|
||||||
@@ -129,47 +129,3 @@ class TrianglePredictionGuider(LinearPredictionGuider):
|
|||||||
|
|
||||||
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
|
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
|
||||||
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||||
|
|
||||||
|
|
||||||
class TrapezoidPredictionGuider(LinearPredictionGuider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        edge_perc: float = 0.1,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)

        rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc))
        fall_steps = torch.flip(rise_steps, [0])
        self.scale = torch.cat(
            [
                rise_steps,
                torch.ones(num_frames - 2 * int(num_frames * edge_perc)),
                fall_steps,
            ]
        ).unsqueeze(0)


class SpatiotemporalPredictionGuider(LinearPredictionGuider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        num_views: int = 1,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
        V = num_views
        T = num_frames // V
        scale = torch.zeros(num_frames).view(T, V)
        scale += torch.linspace(0, 1, T)[:, None] * 0.5
        scale += self.triangle_wave(torch.linspace(0, 1, V))[None, :] * 0.5
        scale = scale.flatten()
        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)

    def triangle_wave(self, values: torch.Tensor, period=1) -> torch.Tensor:
        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
@@ -746,170 +746,3 @@ class Decoder(nn.Module):
|
|||||||
if self.tanh_out:
|
if self.tanh_out:
|
||||||
h = torch.tanh(h)
|
h = torch.tanh(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class DecoderDual(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
ch,
|
|
||||||
out_ch,
|
|
||||||
ch_mult=(1, 2, 4, 8),
|
|
||||||
num_res_blocks,
|
|
||||||
attn_resolutions,
|
|
||||||
dropout=0.0,
|
|
||||||
resamp_with_conv=True,
|
|
||||||
in_channels,
|
|
||||||
resolution,
|
|
||||||
z_channels,
|
|
||||||
give_pre_end=False,
|
|
||||||
tanh_out=False,
|
|
||||||
use_linear_attn=False,
|
|
||||||
attn_type="vanilla",
|
|
||||||
**ignorekwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
if use_linear_attn:
|
|
||||||
attn_type = "linear"
|
|
||||||
self.ch = ch
|
|
||||||
self.temb_ch = 0
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.give_pre_end = give_pre_end
|
|
||||||
self.tanh_out = tanh_out
|
|
||||||
|
|
||||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
|
||||||
logpy.info(
|
|
||||||
"Working with z of shape {} = {} dimensions.".format(
|
|
||||||
self.z_shape, np.prod(self.z_shape)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
make_attn_cls = self._make_attn()
|
|
||||||
make_resblock_cls = self._make_resblock()
|
|
||||||
make_conv_cls = self._make_conv()
|
|
||||||
|
|
||||||
# z to block_in (handles a single latent)
|
|
||||||
self.conv_in = torch.nn.Conv2d(
|
|
||||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
self.mid = nn.Module()
|
|
||||||
self.mid.block_1 = make_resblock_cls(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
|
|
||||||
self.mid.block_2 = make_resblock_cls(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
self.up = nn.ModuleList()
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
block = nn.ModuleList()
|
|
||||||
attn = nn.ModuleList()
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
block.append(
|
|
||||||
make_resblock_cls(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_out,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
block_in = block_out
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
attn.append(make_attn_cls(block_in, attn_type=attn_type))
|
|
||||||
up = nn.Module()
|
|
||||||
up.block = block
|
|
||||||
up.attn = attn
|
|
||||||
if i_level != 0:
|
|
||||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
self.up.insert(0, up) # prepend to get consistent order
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = Normalize(block_in)
|
|
||||||
self.conv_out = make_conv_cls(
|
|
||||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_attn(self) -> Callable:
|
|
||||||
return make_attn
|
|
||||||
|
|
||||||
def _make_resblock(self) -> Callable:
|
|
||||||
return ResnetBlock
|
|
||||||
|
|
||||||
def _make_conv(self) -> Callable:
|
|
||||||
return torch.nn.Conv2d
|
|
||||||
|
|
||||||
def get_last_layer(self, **kwargs):
|
|
||||||
return self.conv_out.weight
|
|
||||||
|
|
||||||
def forward(self, z, **kwargs):
|
|
||||||
"""
|
|
||||||
The input z should have shape (B, 2 * z_channels, H, W):
|
|
||||||
- the first half of the channels holds the first latent, the second half holds the second latent
|
|
||||||
- the halves are split apart, decoded separately, and finally concatenated along the W dimension
|
|
||||||
"""
|
|
||||||
# Sanity check: make sure the input has 2 * z_channels channels
|
|
||||||
assert (
|
|
||||||
z.shape[1] == 2 * self.z_shape[1]
|
|
||||||
), f"Expected {2 * self.z_shape[1]} channels, got {z.shape[1]}"
|
|
||||||
|
|
||||||
# Split the latent into two parts
|
|
||||||
z1, z2 = torch.chunk(z, 2, dim=1)  # split along the channel dimension (C)
|
|
||||||
|
|
||||||
# Decode each part separately
|
|
||||||
img1 = self.decode_single(z1, **kwargs)
|
|
||||||
img2 = self.decode_single(z2, **kwargs)
|
|
||||||
|
|
||||||
# Concatenate along the W dimension
|
|
||||||
output = torch.cat([img1, img2], dim=-1)  # concatenate along the width dimension
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def decode_single(self, z, **kwargs):
    """Decode a single latent into one image."""
|
|
||||||
self.last_z_shape = z.shape
|
|
||||||
|
|
||||||
# z to block_in
|
|
||||||
h = self.conv_in(z)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
h = self.mid.block_1(h, None, **kwargs)
|
|
||||||
h = self.mid.attn_1(h, **kwargs)
|
|
||||||
h = self.mid.block_2(h, None, **kwargs)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
h = self.up[i_level].block[i_block](h, None, **kwargs)
|
|
||||||
if len(self.up[i_level].attn) > 0:
|
|
||||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
|
||||||
if i_level != 0:
|
|
||||||
h = self.up[i_level].upsample(h)
|
|
||||||
|
|
||||||
# end
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv_out(h, **kwargs)
|
|
||||||
if self.tanh_out:
|
|
||||||
h = torch.tanh(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
|
|
||||||
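The translated docstring above describes the contract of `DecoderDual.forward`: split a stacked latent in half along the channel axis, decode each half, and place the two decoded images side by side. Below is a shape-level sketch of that contract with a toy stand-in decoder; the 4-channel latent and 8x upsampling factor are assumptions for illustration, not the real VAE:

```python
import torch
import torch.nn.functional as F


def dual_decode(z: torch.Tensor, decode_single, z_channels: int) -> torch.Tensor:
    """Shape-level sketch of DecoderDual.forward: split, decode, concatenate on width."""
    assert z.shape[1] == 2 * z_channels, f"expected {2 * z_channels} channels, got {z.shape[1]}"
    z1, z2 = torch.chunk(z, 2, dim=1)       # (B, z_channels, h, w) each
    img1, img2 = decode_single(z1), decode_single(z2)
    return torch.cat([img1, img2], dim=-1)  # decoded images side by side along width


# Toy stand-in for the real decoder: keep 3 channels, 8x spatial upsampling (assumed factors).
toy = lambda z: F.interpolate(z[:, :3], scale_factor=8.0)
out = dual_decode(torch.randn(2, 8, 72, 72), toy, z_channels=4)
print(out.shape)  # torch.Size([2, 3, 576, 1152])
```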
@@ -74,62 +74,21 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||||||
x: th.Tensor,
|
x: th.Tensor,
|
||||||
emb: th.Tensor,
|
emb: th.Tensor,
|
||||||
context: Optional[th.Tensor] = None,
|
context: Optional[th.Tensor] = None,
|
||||||
cam: Optional[th.Tensor] = None,
|
|
||||||
image_only_indicator: Optional[th.Tensor] = None,
|
image_only_indicator: Optional[th.Tensor] = None,
|
||||||
cond_view: Optional[th.Tensor] = None,
|
|
||||||
cond_motion: Optional[th.Tensor] = None,
|
|
||||||
time_context: Optional[int] = None,
|
time_context: Optional[int] = None,
|
||||||
num_video_frames: Optional[int] = None,
|
num_video_frames: Optional[int] = None,
|
||||||
time_step: Optional[int] = None,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
from ...modules.diffusionmodules.video_model import VideoResBlock, PostHocResBlockWithTime
|
from ...modules.diffusionmodules.video_model import VideoResBlock
|
||||||
from ...modules.spacetime_attention import (
|
|
||||||
BasicTransformerTimeMixBlock,
|
|
||||||
PostHocSpatialTransformerWithTimeMixing,
|
|
||||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self:
|
for layer in self:
|
||||||
module = layer
|
module = layer
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(module, TimestepBlock) and not isinstance(
|
||||||
module,
|
module, VideoResBlock
|
||||||
(
|
|
||||||
BasicTransformerTimeMixBlock,
|
|
||||||
PostHocSpatialTransformerWithTimeMixing,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
x = layer(
|
x = layer(x, emb)
|
||||||
x,
|
elif isinstance(module, VideoResBlock):
|
||||||
context,
|
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||||
emb,
|
|
||||||
time_context,
|
|
||||||
num_video_frames,
|
|
||||||
image_only_indicator,
|
|
||||||
cond_view,
|
|
||||||
cond_motion,
|
|
||||||
time_step,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
elif isinstance(
|
|
||||||
module,
|
|
||||||
(
|
|
||||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
x = layer(
|
|
||||||
x,
|
|
||||||
context,
|
|
||||||
emb,
|
|
||||||
time_context,
|
|
||||||
num_video_frames,
|
|
||||||
image_only_indicator,
|
|
||||||
cond_view,
|
|
||||||
cond_motion,
|
|
||||||
time_step,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
elif isinstance(module, SpatialVideoTransformer):
|
elif isinstance(module, SpatialVideoTransformer):
|
||||||
x = layer(
|
x = layer(
|
||||||
x,
|
x,
|
||||||
@@ -137,16 +96,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||||||
time_context,
|
time_context,
|
||||||
num_video_frames,
|
num_video_frames,
|
||||||
image_only_indicator,
|
image_only_indicator,
|
||||||
# time_step,
|
|
||||||
)
|
)
|
||||||
elif isinstance(module, PostHocResBlockWithTime):
|
|
||||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
|
||||||
elif isinstance(module, VideoResBlock):
|
|
||||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
|
||||||
elif isinstance(module, TimestepBlock) and not isinstance(
|
|
||||||
module, VideoResBlock
|
|
||||||
):
|
|
||||||
x = layer(x, emb)
|
|
||||||
elif isinstance(module, SpatialTransformer):
|
elif isinstance(module, SpatialTransformer):
|
||||||
x = layer(x, context)
|
x = layer(x, context)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Optional, Union
|
|
||||||
from ...util import default, instantiate_from_config
|
from ...util import default, instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
@@ -29,10 +29,3 @@ class DiscreteSampling:
|
|||||||
torch.randint(0, self.num_idx, (n_samples,)),
|
torch.randint(0, self.num_idx, (n_samples,)),
|
||||||
)
|
)
|
||||||
return self.idx_to_sigma(idx)
|
return self.idx_to_sigma(idx)
|
||||||
|
|
||||||
|
|
||||||
class ZeroSampler:
|
|
||||||
def __call__(
|
|
||||||
self, n_samples: int, rand: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5
|
|
||||||
|
|||||||
@@ -17,36 +17,6 @@ import torch.nn as nn
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
def get_alpha(
|
|
||||||
merge_strategy: str,
|
|
||||||
mix_factor: Optional[torch.Tensor],
|
|
||||||
image_only_indicator: torch.Tensor,
|
|
||||||
apply_sigmoid: bool = True,
|
|
||||||
is_attn: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if merge_strategy == "fixed" or merge_strategy == "learned":
|
|
||||||
alpha = mix_factor
|
|
||||||
elif merge_strategy == "learned_with_images":
|
|
||||||
alpha = torch.where(
|
|
||||||
image_only_indicator.bool(),
|
|
||||||
torch.ones(1, 1, device=image_only_indicator.device),
|
|
||||||
rearrange(mix_factor, "... -> ... 1"),
|
|
||||||
)
|
|
||||||
if is_attn:
|
|
||||||
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
|
||||||
else:
|
|
||||||
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
|
||||||
elif merge_strategy == "fixed_with_images":
|
|
||||||
alpha = image_only_indicator
|
|
||||||
if is_attn:
|
|
||||||
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
|
||||||
else:
|
|
||||||
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
return torch.sigmoid(alpha) if apply_sigmoid else alpha
|
|
||||||
|
|
||||||
|
|
||||||
def make_beta_schedule(
|
def make_beta_schedule(
|
||||||
schedule,
|
schedule,
|
||||||
n_timestep,
|
n_timestep,
|
||||||
|
|||||||
@@ -5,15 +5,9 @@ from einops import rearrange
|
|||||||
|
|
||||||
from ...modules.diffusionmodules.openaimodel import *
|
from ...modules.diffusionmodules.openaimodel import *
|
||||||
from ...modules.video_attention import SpatialVideoTransformer
|
from ...modules.video_attention import SpatialVideoTransformer
|
||||||
from ...modules.spacetime_attention import (
|
|
||||||
BasicTransformerTimeMixBlock,
|
|
||||||
PostHocSpatialTransformerWithTimeMixing,
|
|
||||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
|
||||||
)
|
|
||||||
from ...util import default
|
from ...util import default
|
||||||
from .util import AlphaBlender, get_alpha
|
from .util import AlphaBlender
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
class VideoResBlock(ResBlock):
|
class VideoResBlock(ResBlock):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -497,913 +491,3 @@ class VideoUNet(nn.Module):
|
|||||||
)
|
)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
return self.out(h)
|
return self.out(h)
|
||||||
|
|
||||||
|
|
||||||
class PostHocAttentionBlockWithTimeMixing(AttentionBlock):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
n_heads: int,
|
|
||||||
d_head: int,
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
use_new_attention_order: bool = False,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
use_spatial_context: bool = False,
|
|
||||||
merge_strategy: str = "fixed",
|
|
||||||
merge_factor: float = 0.5,
|
|
||||||
apply_sigmoid_to_merge: bool = True,
|
|
||||||
ff_in: bool = False,
|
|
||||||
attn_mode: str = "softmax",
|
|
||||||
disable_temporal_crossattention: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
in_channels,
|
|
||||||
n_heads,
|
|
||||||
d_head,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
)
|
|
||||||
inner_dim = n_heads * d_head
|
|
||||||
|
|
||||||
self.time_mix_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
BasicTransformerTimeMixBlock(
|
|
||||||
inner_dim,
|
|
||||||
n_heads,
|
|
||||||
d_head,
|
|
||||||
dropout=dropout,
|
|
||||||
checkpoint=use_checkpoint,
|
|
||||||
ff_in=ff_in,
|
|
||||||
attn_mode=attn_mode,
|
|
||||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
time_embed_dim = self.in_channels * 4
|
|
||||||
self.time_mix_time_embed = nn.Sequential(
|
|
||||||
linear(self.in_channels, time_embed_dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
linear(time_embed_dim, self.in_channels),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.use_spatial_context = use_spatial_context
|
|
||||||
|
|
||||||
if merge_strategy == "fixed":
|
|
||||||
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
|
|
||||||
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
|
||||||
self.register_parameter(
|
|
||||||
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
|
|
||||||
)
|
|
||||||
elif merge_strategy == "fixed_with_images":
|
|
||||||
self.mix_factor = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
|
||||||
|
|
||||||
self.get_alpha_fn = functools.partial(
|
|
||||||
get_alpha,
|
|
||||||
merge_strategy,
|
|
||||||
self.mix_factor,
|
|
||||||
apply_sigmoid=apply_sigmoid_to_merge,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: th.Tensor,
|
|
||||||
context: Optional[th.Tensor] = None,
|
|
||||||
# cam: Optional[th.Tensor] = None,
|
|
||||||
time_context: Optional[th.Tensor] = None,
|
|
||||||
timesteps: Optional[int] = None,
|
|
||||||
image_only_indicator: Optional[th.Tensor] = None,
|
|
||||||
cond_view: Optional[th.Tensor] = None,
|
|
||||||
cond_motion: Optional[th.Tensor] = None,
|
|
||||||
):
|
|
||||||
if time_context is not None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
_, _, h, w = x.shape
|
|
||||||
if exists(context):
|
|
||||||
context = rearrange(context, "b t ... -> (b t) ...")
|
|
||||||
if self.use_spatial_context:
|
|
||||||
time_context = repeat(context[:, 0], "b ... -> (b n) ...", n=h * w)
|
|
||||||
|
|
||||||
x = super().forward(
|
|
||||||
x,
|
|
||||||
)
|
|
||||||
|
|
||||||
x = rearrange(x, "b c h w -> b (h w) c")
|
|
||||||
x_mix = x
|
|
||||||
|
|
||||||
num_frames = th.arange(timesteps, device=x.device)
|
|
||||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
|
||||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
|
||||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
|
||||||
emb = self.time_mix_time_embed(t_emb)
|
|
||||||
emb = emb[:, None, :]
|
|
||||||
x_mix = x_mix + emb
|
|
||||||
|
|
||||||
x_mix = self.time_mix_blocks[0](
|
|
||||||
x_mix, context=time_context, timesteps=timesteps
|
|
||||||
)
|
|
||||||
|
|
||||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
|
||||||
x = alpha * x + (1.0 - alpha) * x_mix
|
|
||||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class PostHocResBlockWithTime(ResBlock):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channels: int,
|
|
||||||
emb_channels: int,
|
|
||||||
dropout: float,
|
|
||||||
time_kernel_size: Union[int, List[int]] = 3,
|
|
||||||
merge_strategy: str = "fixed",
|
|
||||||
merge_factor: float = 0.5,
|
|
||||||
apply_sigmoid_to_merge: bool = True,
|
|
||||||
out_channels: Optional[int] = None,
|
|
||||||
use_conv: bool = False,
|
|
||||||
use_scale_shift_norm: bool = False,
|
|
||||||
dims: int = 2,
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
up: bool = False,
|
|
||||||
down: bool = False,
|
|
||||||
time_mix_legacy: bool = True,
|
|
||||||
replicate_bug: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
channels,
|
|
||||||
emb_channels,
|
|
||||||
dropout,
|
|
||||||
out_channels=out_channels,
|
|
||||||
use_conv=use_conv,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
up=up,
|
|
||||||
down=down,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.time_mix_blocks = ResBlock(
|
|
||||||
default(out_channels, channels),
|
|
||||||
emb_channels,
|
|
||||||
dropout=dropout,
|
|
||||||
dims=3,
|
|
||||||
out_channels=default(out_channels, channels),
|
|
||||||
use_scale_shift_norm=False,
|
|
||||||
use_conv=False,
|
|
||||||
up=False,
|
|
||||||
down=False,
|
|
||||||
kernel_size=time_kernel_size,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
exchange_temb_dims=True,
|
|
||||||
)
|
|
||||||
self.time_mix_legacy = time_mix_legacy
|
|
||||||
if self.time_mix_legacy:
|
|
||||||
if merge_strategy == "fixed":
|
|
||||||
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
|
|
||||||
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
|
||||||
self.register_parameter(
|
|
||||||
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
|
|
||||||
)
|
|
||||||
elif merge_strategy == "fixed_with_images":
|
|
||||||
self.mix_factor = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
|
||||||
|
|
||||||
self.get_alpha_fn = functools.partial(
|
|
||||||
get_alpha,
|
|
||||||
merge_strategy,
|
|
||||||
self.mix_factor,
|
|
||||||
apply_sigmoid=apply_sigmoid_to_merge,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if False: # replicate_bug:
|
|
||||||
logpy.warning(
|
|
||||||
"*****************************************************************************************\n"
|
|
||||||
"GRAVE WARNING: YOU'RE USING THE BUGGY LEGACY ALPHABLENDER!!! ARE YOU SURE YOU WANT THIS?!\n"
|
|
||||||
"*****************************************************************************************"
|
|
||||||
)
|
|
||||||
self.time_mixer = LegacyAlphaBlenderWithBug(
|
|
||||||
alpha=merge_factor,
|
|
||||||
merge_strategy=merge_strategy,
|
|
||||||
rearrange_pattern="b t -> b 1 t 1 1",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.time_mixer = AlphaBlender(
|
|
||||||
alpha=merge_factor,
|
|
||||||
merge_strategy=merge_strategy,
|
|
||||||
rearrange_pattern="b t -> b 1 t 1 1",
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: th.Tensor,
|
|
||||||
emb: th.Tensor,
|
|
||||||
num_video_frames: int,
|
|
||||||
image_only_indicator: Optional[th.Tensor] = None,
|
|
||||||
cond_view: Optional[th.Tensor] = None,
|
|
||||||
cond_motion: Optional[th.Tensor] = None,
|
|
||||||
) -> th.Tensor:
|
|
||||||
x = super().forward(x, emb)
|
|
||||||
|
|
||||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
|
||||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
|
||||||
|
|
||||||
x = self.time_mix_blocks(
|
|
||||||
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.time_mix_legacy:
|
|
||||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator*0.0)
|
|
||||||
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
|
||||||
else:
|
|
||||||
x = self.time_mixer(
|
|
||||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator*0.0
|
|
||||||
)
|
|
||||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialUNetModelWithTime(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
model_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
num_res_blocks: int,
|
|
||||||
attention_resolutions: int,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
channel_mult: List[int] = (1, 2, 4, 8),
|
|
||||||
conv_resample: bool = True,
|
|
||||||
dims: int = 2,
|
|
||||||
num_classes: Optional[int] = None,
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
num_heads: int = -1,
|
|
||||||
num_head_channels: int = -1,
|
|
||||||
num_heads_upsample: int = -1,
|
|
||||||
use_scale_shift_norm: bool = False,
|
|
||||||
resblock_updown: bool = False,
|
|
||||||
use_new_attention_order: bool = False,
|
|
||||||
use_spatial_transformer: bool = False,
|
|
||||||
transformer_depth: Union[List[int], int] = 1,
|
|
||||||
transformer_depth_middle: Optional[int] = None,
|
|
||||||
context_dim: Optional[int] = None,
|
|
||||||
time_downup: bool = False,
|
|
||||||
time_context_dim: Optional[int] = None,
|
|
||||||
view_context_dim: Optional[int] = None,
|
|
||||||
motion_context_dim: Optional[int] = None,
|
|
||||||
extra_ff_mix_layer: bool = False,
|
|
||||||
use_spatial_context: bool = False,
|
|
||||||
time_block_merge_strategy: str = "fixed",
|
|
||||||
time_block_merge_factor: float = 0.5,
|
|
||||||
view_block_merge_factor: float = 0.5,
|
|
||||||
motion_block_merge_factor: float = 0.5,
|
|
||||||
spatial_transformer_attn_type: str = "softmax",
|
|
||||||
time_kernel_size: Union[int, List[int]] = 3,
|
|
||||||
use_linear_in_transformer: bool = False,
|
|
||||||
legacy: bool = True,
|
|
||||||
adm_in_channels: Optional[int] = None,
|
|
||||||
use_temporal_resblock: bool = True,
|
|
||||||
disable_temporal_crossattention: bool = False,
|
|
||||||
time_mix_legacy: bool = True,
|
|
||||||
max_ddpm_temb_period: int = 10000,
|
|
||||||
replicate_time_mix_bug: bool = False,
|
|
||||||
use_motion_attention: bool = False,
|
|
||||||
use_camera_emb: bool = False,
|
|
||||||
use_3d_attention: bool = False,
|
|
||||||
separate_motion_merge_factor: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if use_spatial_transformer:
|
|
||||||
assert context_dim is not None
|
|
||||||
|
|
||||||
if context_dim is not None:
|
|
||||||
assert use_spatial_transformer
|
|
||||||
|
|
||||||
if num_heads_upsample == -1:
|
|
||||||
num_heads_upsample = num_heads
|
|
||||||
|
|
||||||
if num_heads == -1:
|
|
||||||
assert num_head_channels != -1
|
|
||||||
|
|
||||||
if num_head_channels == -1:
|
|
||||||
assert num_heads != -1
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.model_channels = model_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
if isinstance(transformer_depth, int):
|
|
||||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
|
||||||
transformer_depth_middle = default(
|
|
||||||
transformer_depth_middle, transformer_depth[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.attention_resolutions = attention_resolutions
|
|
||||||
self.dropout = dropout
|
|
||||||
self.channel_mult = channel_mult
|
|
||||||
self.conv_resample = conv_resample
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_head_channels = num_head_channels
|
|
||||||
self.num_heads_upsample = num_heads_upsample
|
|
||||||
self.use_temporal_resblocks = use_temporal_resblock
|
|
||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
|
||||||
self.time_embed = nn.Sequential(
|
|
||||||
linear(model_channels, time_embed_dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
linear(time_embed_dim, time_embed_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.num_classes is not None:
|
|
||||||
if isinstance(self.num_classes, int):
|
|
||||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
|
||||||
elif self.num_classes == "continuous":
|
|
||||||
print("setting up linear c_adm embedding layer")
|
|
||||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
|
||||||
elif self.num_classes == "timestep":
|
|
||||||
self.label_emb = nn.Sequential(
|
|
||||||
Timestep(model_channels),
|
|
||||||
nn.Sequential(
|
|
||||||
linear(model_channels, time_embed_dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
linear(time_embed_dim, time_embed_dim),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.num_classes == "sequential":
|
|
||||||
assert adm_in_channels is not None
|
|
||||||
self.label_emb = nn.Sequential(
|
|
||||||
nn.Sequential(
|
|
||||||
linear(adm_in_channels, time_embed_dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
linear(time_embed_dim, time_embed_dim),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError()
|
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
TimestepEmbedSequential(
|
|
||||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self._feature_size = model_channels
|
|
||||||
input_block_chans = [model_channels]
|
|
||||||
ch = model_channels
|
|
||||||
ds = 1
|
|
||||||
|
|
||||||
def get_attention_layer(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
depth=1,
|
|
||||||
context_dim=None,
|
|
||||||
use_checkpoint=False,
|
|
||||||
disabled_sa=False,
|
|
||||||
):
|
|
||||||
if not use_spatial_transformer:
|
|
||||||
return PostHocAttentionBlockWithTimeMixing(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
dropout=dropout,
|
|
||||||
ff_in=extra_ff_mix_layer,
|
|
||||||
use_spatial_context=use_spatial_context,
|
|
||||||
merge_strategy=time_block_merge_strategy,
|
|
||||||
merge_factor=time_block_merge_factor,
|
|
||||||
attn_mode=spatial_transformer_attn_type,
|
|
||||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif use_motion_attention:
|
|
||||||
return PostHocSpatialTransformerWithTimeMixingAndMotion(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
depth=depth,
|
|
||||||
context_dim=context_dim,
|
|
||||||
time_context_dim=time_context_dim,
|
|
||||||
motion_context_dim=motion_context_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
ff_in=extra_ff_mix_layer,
|
|
||||||
use_spatial_context=use_spatial_context,
|
|
||||||
use_camera_emb=use_camera_emb,
|
|
||||||
use_3d_attention=use_3d_attention,
|
|
||||||
separate_motion_merge_factor=separate_motion_merge_factor,
|
|
||||||
adm_in_channels=adm_in_channels,
|
|
||||||
merge_strategy=time_block_merge_strategy,
|
|
||||||
merge_factor=time_block_merge_factor,
|
|
||||||
merge_factor_motion=motion_block_merge_factor,
|
|
||||||
checkpoint=use_checkpoint,
|
|
||||||
use_linear=use_linear_in_transformer,
|
|
||||||
attn_mode=spatial_transformer_attn_type,
|
|
||||||
disable_self_attn=disabled_sa,
|
|
||||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
|
||||||
time_mix_legacy=time_mix_legacy,
|
|
||||||
max_time_embed_period=max_ddpm_temb_period,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return PostHocSpatialTransformerWithTimeMixing(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
depth=depth,
|
|
||||||
context_dim=context_dim,
|
|
||||||
time_context_dim=time_context_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
ff_in=extra_ff_mix_layer,
|
|
||||||
use_spatial_context=use_spatial_context,
|
|
||||||
merge_strategy=time_block_merge_strategy,
|
|
||||||
merge_factor=time_block_merge_factor,
|
|
||||||
checkpoint=use_checkpoint,
|
|
||||||
use_linear=use_linear_in_transformer,
|
|
||||||
attn_mode=spatial_transformer_attn_type,
|
|
||||||
disable_self_attn=disabled_sa,
|
|
||||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
|
||||||
time_mix_legacy=time_mix_legacy,
|
|
||||||
max_time_embed_period=max_ddpm_temb_period,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_resblock(
|
|
||||||
time_block_merge_factor,
|
|
||||||
time_block_merge_strategy,
|
|
||||||
time_kernel_size,
|
|
||||||
ch,
|
|
||||||
time_embed_dim,
|
|
||||||
dropout,
|
|
||||||
out_ch,
|
|
||||||
dims,
|
|
||||||
use_checkpoint,
|
|
||||||
use_scale_shift_norm,
|
|
||||||
down=False,
|
|
||||||
up=False,
|
|
||||||
):
|
|
||||||
if self.use_temporal_resblocks:
|
|
||||||
return PostHocResBlockWithTime(
|
|
||||||
merge_factor=time_block_merge_factor,
|
|
||||||
merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
channels=ch,
|
|
||||||
emb_channels=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
out_channels=out_ch,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
down=down,
|
|
||||||
up=up,
|
|
||||||
time_mix_legacy=time_mix_legacy,
|
|
||||||
replicate_bug=replicate_time_mix_bug,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ResBlock(
|
|
||||||
channels=ch,
|
|
||||||
emb_channels=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
out_channels=out_ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
dims=dims,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
down=down,
|
|
||||||
up=up,
|
|
||||||
)
|
|
||||||
|
|
||||||
for level, mult in enumerate(channel_mult):
|
|
||||||
for _ in range(num_res_blocks):
|
|
||||||
layers = [
|
|
||||||
get_resblock(
|
|
||||||
time_block_merge_factor=time_block_merge_factor,
|
|
||||||
time_block_merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
ch=ch,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
out_ch=mult * model_channels,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
ch = mult * model_channels
|
|
||||||
if ds in attention_resolutions:
|
|
||||||
if num_head_channels == -1:
|
|
||||||
dim_head = ch // num_heads
|
|
||||||
else:
|
|
||||||
num_heads = ch // num_head_channels
|
|
||||||
dim_head = num_head_channels
|
|
||||||
if legacy:
|
|
||||||
dim_head = (
|
|
||||||
ch // num_heads
|
|
||||||
if use_spatial_transformer
|
|
||||||
else num_head_channels
|
|
||||||
)
|
|
||||||
|
|
||||||
layers.append(
|
|
||||||
get_attention_layer(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
depth=transformer_depth[level],
|
|
||||||
context_dim=context_dim,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
disabled_sa=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
|
||||||
self._feature_size += ch
|
|
||||||
input_block_chans.append(ch)
|
|
||||||
if level != len(channel_mult) - 1:
|
|
||||||
ds *= 2
|
|
||||||
out_ch = ch
|
|
||||||
self.input_blocks.append(
|
|
||||||
TimestepEmbedSequential(
|
|
||||||
get_resblock(
|
|
||||||
time_block_merge_factor=time_block_merge_factor,
|
|
||||||
time_block_merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
ch=ch,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
out_ch=out_ch,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
down=True,
|
|
||||||
)
|
|
||||||
if resblock_updown
|
|
||||||
else Downsample(
|
|
||||||
ch,
|
|
||||||
conv_resample,
|
|
||||||
dims=dims,
|
|
||||||
out_channels=out_ch,
|
|
||||||
third_down=time_downup,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
ch = out_ch
|
|
||||||
input_block_chans.append(ch)
|
|
||||||
|
|
||||||
self._feature_size += ch
|
|
||||||
|
|
||||||
if num_head_channels == -1:
|
|
||||||
dim_head = ch // num_heads
|
|
||||||
else:
|
|
||||||
num_heads = ch // num_head_channels
|
|
||||||
dim_head = num_head_channels
|
|
||||||
if legacy:
|
|
||||||
# num_heads = 1
|
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
|
||||||
|
|
||||||
self.middle_block = TimestepEmbedSequential(
|
|
||||||
get_resblock(
|
|
||||||
time_block_merge_factor=time_block_merge_factor,
|
|
||||||
time_block_merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
ch=ch,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
out_ch=None,
|
|
||||||
dropout=dropout,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
),
|
|
||||||
get_attention_layer(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
depth=transformer_depth_middle,
|
|
||||||
context_dim=context_dim,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
),
|
|
||||||
get_resblock(
|
|
||||||
time_block_merge_factor=time_block_merge_factor,
|
|
||||||
time_block_merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
ch=ch,
|
|
||||||
out_ch=None,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._feature_size += ch
|
|
||||||
|
|
||||||
self.output_blocks = nn.ModuleList([])
|
|
||||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
|
||||||
for i in range(num_res_blocks + 1):
|
|
||||||
ich = input_block_chans.pop()
|
|
||||||
layers = [
|
|
||||||
get_resblock(
|
|
||||||
time_block_merge_factor=time_block_merge_factor,
|
|
||||||
time_block_merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
ch=ch + ich,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
out_ch=model_channels * mult,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
ch = model_channels * mult
|
|
||||||
if ds in attention_resolutions:
|
|
||||||
if num_head_channels == -1:
|
|
||||||
dim_head = ch // num_heads
|
|
||||||
else:
|
|
||||||
num_heads = ch // num_head_channels
|
|
||||||
dim_head = num_head_channels
|
|
||||||
if legacy:
|
|
||||||
dim_head = (
|
|
||||||
ch // num_heads
|
|
||||||
if use_spatial_transformer
|
|
||||||
else num_head_channels
|
|
||||||
)
|
|
||||||
|
|
||||||
layers.append(
|
|
||||||
get_attention_layer(
|
|
||||||
ch,
|
|
||||||
num_heads,
|
|
||||||
dim_head,
|
|
||||||
depth=transformer_depth[level],
|
|
||||||
context_dim=context_dim,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
disabled_sa=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if level and i == num_res_blocks:
|
|
||||||
out_ch = ch
|
|
||||||
ds //= 2
|
|
||||||
layers.append(
|
|
||||||
get_resblock(
|
|
||||||
time_block_merge_factor=time_block_merge_factor,
|
|
||||||
time_block_merge_strategy=time_block_merge_strategy,
|
|
||||||
time_kernel_size=time_kernel_size,
|
|
||||||
ch=ch,
|
|
||||||
time_embed_dim=time_embed_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
out_ch=out_ch,
|
|
||||||
dims=dims,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
up=True,
|
|
||||||
)
|
|
||||||
if resblock_updown
|
|
||||||
else Upsample(
|
|
||||||
ch,
|
|
||||||
conv_resample,
|
|
||||||
dims=dims,
|
|
||||||
out_channels=out_ch,
|
|
||||||
third_up=time_downup,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
|
||||||
self._feature_size += ch
|
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
|
||||||
normalization(ch),
|
|
||||||
nn.SiLU(),
|
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: th.Tensor,
|
|
||||||
timesteps: th.Tensor,
|
|
||||||
context: Optional[th.Tensor] = None,
|
|
||||||
y: Optional[th.Tensor] = None,
|
|
||||||
cam: Optional[th.Tensor] = None,
|
|
||||||
time_context: Optional[th.Tensor] = None,
|
|
||||||
num_video_frames: Optional[int] = None,
|
|
||||||
image_only_indicator: Optional[th.Tensor] = None,
|
|
||||||
cond_view: Optional[th.Tensor] = None,
|
|
||||||
cond_motion: Optional[th.Tensor] = None,
|
|
||||||
time_step: Optional[int] = None,
|
|
||||||
):
|
|
||||||
assert (y is not None) == (
|
|
||||||
self.num_classes is not None
|
|
||||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
|
||||||
hs = []
|
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # 21 x 320
|
|
||||||
emb = self.time_embed(t_emb) # 21 x 1280
|
|
||||||
time = str(timesteps[0].data.cpu().numpy())
|
|
||||||
|
|
||||||
if self.num_classes is not None:
|
|
||||||
assert y.shape[0] == x.shape[0]
|
|
||||||
emb = emb + self.label_emb(y) # 21 x 1280
|
|
||||||
|
|
||||||
h = x # 21 x 8 x 64 x 64
|
|
||||||
for i, module in enumerate(self.input_blocks):
|
|
||||||
h = module(
|
|
||||||
h,
|
|
||||||
emb,
|
|
||||||
context=context,
|
|
||||||
cam=cam,
|
|
||||||
image_only_indicator=image_only_indicator,
|
|
||||||
cond_view=cond_view,
|
|
||||||
cond_motion=cond_motion,
|
|
||||||
time_context=time_context,
|
|
||||||
num_video_frames=num_video_frames,
|
|
||||||
time_step=time_step,
|
|
||||||
name='encoder_{}_{}'.format(time, i)
|
|
||||||
)
|
|
||||||
hs.append(h)
|
|
||||||
h = self.middle_block(
|
|
||||||
h,
|
|
||||||
emb,
|
|
||||||
context=context,
|
|
||||||
cam=cam,
|
|
||||||
image_only_indicator=image_only_indicator,
|
|
||||||
cond_view=cond_view,
|
|
||||||
cond_motion=cond_motion,
|
|
||||||
time_context=time_context,
|
|
||||||
num_video_frames=num_video_frames,
|
|
||||||
time_step=time_step,
|
|
||||||
name='middle_{}_0'.format(time)
|
|
||||||
)
|
|
||||||
for i, module in enumerate(self.output_blocks):
|
|
||||||
h = th.cat([h, hs.pop()], dim=1)
|
|
||||||
h = module(
|
|
||||||
h,
|
|
||||||
emb,
|
|
||||||
context=context,
|
|
||||||
cam=cam,
|
|
||||||
image_only_indicator=image_only_indicator,
|
|
||||||
cond_view=cond_view,
|
|
||||||
cond_motion=cond_motion,
|
|
||||||
time_context=time_context,
|
|
||||||
num_video_frames=num_video_frames,
|
|
||||||
time_step=time_step,
|
|
||||||
name='decoder_{}_{}'.format(time, i)
|
|
||||||
)
|
|
||||||
h = h.type(x.dtype)
|
|
||||||
return self.out(h)
|
|
||||||
|
|
||||||
|
|
||||||
class CrossNetworkLayer(nn.Module):
|
|
||||||
def __init__(self, feature_dim: int):
|
|
||||||
super().__init__()
|
|
||||||
self.fusion_conv = nn.Sequential(
|
|
||||||
nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, h1: torch.Tensor, h2: torch.Tensor):
|
|
||||||
"""
|
|
||||||
h1, h2: (B, C, H, W)
|
|
||||||
return: (out1, out2), (B, C, H, W)
|
|
||||||
"""
|
|
||||||
fused_input = torch.cat([h1, h2], dim=1) # (B, 2C, H, W)
|
|
||||||
fused_output = self.fusion_conv(fused_input) # (B, C, H, W)
|
|
||||||
out1 = fused_output + h1
|
|
||||||
out2 = fused_output + h2
|
|
||||||
return out1, out2
|
|
||||||
|
|
||||||
|
|
||||||
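`CrossNetworkLayer` above fuses two same-shaped feature maps by concatenating them on the channel axis, passing them through a 1x1-conv bottleneck, and adding the fused features back to each stream as a residual. A minimal usage sketch (batch and spatial sizes are hypothetical), assuming the class definition above is in scope:

```python
import torch

# Fuse two (B, C, H, W) feature maps; each output keeps its own residual connection.
layer = CrossNetworkLayer(feature_dim=64)
h1 = torch.randn(2, 64, 32, 32)
h2 = torch.randn(2, 64, 32, 32)
out1, out2 = layer(h1, h2)
print(out1.shape, out2.shape)  # torch.Size([2, 64, 32, 32]) twice
```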
class DualSpatialUNetWithCrossComm(nn.Module):
|
|
||||||
def __init__(self, unet_config):
|
|
||||||
super().__init__()
|
|
||||||
self.num_classes = unet_config["num_classes"]
|
|
||||||
self.model_channels = unet_config["model_channels"]
|
|
||||||
|
|
||||||
self.net1 = SpatialUNetModelWithTime(**unet_config)
|
|
||||||
self.net2 = SpatialUNetModelWithTime(**unet_config)
|
|
||||||
|
|
||||||
self.input_cross_layers = nn.ModuleList()
|
|
||||||
for block in self.net1.input_blocks:
|
|
||||||
out_ch = self._get_block_out_channels(block)
|
|
||||||
self.input_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
|
|
||||||
|
|
||||||
middle_out_ch = self._get_block_out_channels(self.net1.middle_block)
|
|
||||||
self.middle_cross = CrossNetworkLayer(feature_dim=middle_out_ch)
|
|
||||||
|
|
||||||
self.output_cross_layers = nn.ModuleList()
|
|
||||||
for block in self.net1.output_blocks:
|
|
||||||
out_ch = self._get_block_out_channels(block)
|
|
||||||
self.output_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
|
|
||||||
|
|
||||||
def _get_block_out_channels(self, block: nn.Module) -> int:
|
|
||||||
mod_list = list(block.children())
|
|
||||||
for m in reversed(mod_list):
|
|
||||||
if hasattr(m, "out_channels"):
|
|
||||||
return m.out_channels
|
|
||||||
|
|
||||||
if isinstance(
|
|
||||||
m,
|
|
||||||
(SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion),
|
|
||||||
):
|
|
||||||
return m.in_channels
|
|
||||||
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
return m.out_channels
|
|
||||||
|
|
||||||
raise ValueError(f"Cannot determine out_channels from block: {block}")
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: th.Tensor,
|
|
||||||
timesteps: th.Tensor,
|
|
||||||
context: Optional[th.Tensor] = None,
|
|
||||||
y: Optional[th.Tensor] = None,
|
|
||||||
cam: Optional[th.Tensor] = None,
|
|
||||||
time_context: Optional[th.Tensor] = None,
|
|
||||||
num_video_frames: Optional[int] = None,
|
|
||||||
image_only_indicator: Optional[th.Tensor] = None,
|
|
||||||
cond_view: Optional[th.Tensor] = None,
|
|
||||||
cond_motion: Optional[th.Tensor] = None,
|
|
||||||
time_step: Optional[int] = None,
|
|
||||||
):
|
|
||||||
|
|
||||||
# ============ encoder ============
|
|
||||||
h1, h2 = x[:, : x.shape[1] // 2], x[:, x.shape[1] // 2 :]
|
|
||||||
|
|
||||||
encoder_feats1 = []
|
|
||||||
encoder_feats2 = []
|
|
||||||
|
|
||||||
assert (y is not None) == (
|
|
||||||
self.num_classes is not None
|
|
||||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
|
||||||
|
|
||||||
t_emb = timestep_embedding(
|
|
||||||
timesteps, self.model_channels, repeat_only=False
|
|
||||||
) # 21 x 320
|
|
||||||
|
|
||||||
emb = self.net1.time_embed(t_emb) # 21 x 1280
|
|
||||||
time = str(timesteps[0].data.cpu().numpy())
|
|
||||||
|
|
||||||
if self.num_classes is not None:
|
|
||||||
assert y.shape[0] == h1.shape[0]
|
|
||||||
emb = emb + self.net1.label_emb(y) # 21 x 1280
|
|
||||||
|
|
||||||
filtered_args = {
|
|
||||||
"emb": emb,
|
|
||||||
"context": context,
|
|
||||||
"cam": cam,
|
|
||||||
"cond_view": cond_view,
|
|
||||||
"cond_motion": cond_motion,
|
|
||||||
"time_context": time_context,
|
|
||||||
"num_video_frames": num_video_frames,
|
|
||||||
"image_only_indicator": image_only_indicator,
|
|
||||||
"time_step": time_step,
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, (block1, block2) in enumerate(
|
|
||||||
zip(self.net1.input_blocks, self.net2.input_blocks)
|
|
||||||
):
|
|
||||||
h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args)
|
|
||||||
h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args)
|
|
||||||
|
|
||||||
# cross
|
|
||||||
h1, h2 = self.input_cross_layers[i](h1, h2)
|
|
||||||
|
|
||||||
encoder_feats1.append(h1)
|
|
||||||
encoder_feats2.append(h2)
|
|
||||||
|
|
||||||
# ============ middle block ============
|
|
||||||
h1 = self.net1.middle_block(
|
|
||||||
h1, name="middle_{}_0".format(time), **filtered_args
|
|
||||||
)
|
|
||||||
h2 = self.net2.middle_block(
|
|
||||||
h2, name="middle_{}_0".format(time), **filtered_args
|
|
||||||
)
|
|
||||||
|
|
||||||
# cross
|
|
||||||
h1, h2 = self.middle_cross(h1, h2)
|
|
||||||
|
|
||||||
# ============ decoder ============
|
|
||||||
for i, (block1, block2) in enumerate(
|
|
||||||
zip(self.net1.output_blocks, self.net2.output_blocks)
|
|
||||||
):
|
|
||||||
skip1 = encoder_feats1.pop()
|
|
||||||
skip2 = encoder_feats2.pop()
|
|
||||||
h1 = torch.cat([h1, skip1], dim=1)
|
|
||||||
h2 = torch.cat([h2, skip2], dim=1)
|
|
||||||
|
|
||||||
h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args)
|
|
||||||
h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args)
|
|
||||||
|
|
||||||
# cross
|
|
||||||
h1, h2 = self.output_cross_layers[i](h1, h2)
|
|
||||||
|
|
||||||
# ============ output ============
|
|
||||||
out1 = self.net1.out(h1) # shape: (B, out_channels, H, W)
|
|
||||||
out2 = self.net2.out(h2) # same shape
|
|
||||||
out = torch.cat([out1, out2], dim=1)
|
|
||||||
|
|
||||||
return out
|
|
||||||
@@ -6,7 +6,7 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
|
|||||||
|
|
||||||
|
|
||||||
class IdentityWrapper(nn.Module):
|
class IdentityWrapper(nn.Module):
|
||||||
def __init__(self, diffusion_model, compile_model: bool = False, dual_concat: bool = False):
|
def __init__(self, diffusion_model, compile_model: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
compile = (
|
compile = (
|
||||||
torch.compile
|
torch.compile
|
||||||
@@ -15,7 +15,6 @@ class IdentityWrapper(nn.Module):
|
|||||||
else lambda x: x
|
else lambda x: x
|
||||||
)
|
)
|
||||||
self.diffusion_model = compile(diffusion_model)
|
self.diffusion_model = compile(diffusion_model)
|
||||||
self.dual_concat = dual_concat
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.diffusion_model(*args, **kwargs)
|
return self.diffusion_model(*args, **kwargs)
|
||||||
@@ -25,29 +24,11 @@ class OpenAIWrapper(IdentityWrapper):
|
|||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.dual_concat:
|
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
||||||
x_1 = x[:, : x.shape[1] // 2]
|
return self.diffusion_model(
|
||||||
x_2 = x[:, x.shape[1] // 2 :]
|
x,
|
||||||
x_1 = torch.cat((x_1, c.get("concat", torch.Tensor([]).type_as(x_1))), dim=1)
|
timesteps=t,
|
||||||
x_2 = torch.cat((x_2, c.get("concat", torch.Tensor([]).type_as(x_2))), dim=1)
|
context=c.get("crossattn", None),
|
||||||
x = torch.cat((x_1, x_2), dim=1)
|
y=c.get("vector", None),
|
||||||
else:
|
**kwargs,
|
||||||
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
)
|
||||||
if "cond_view" in c:
|
|
||||||
return self.diffusion_model(
|
|
||||||
x,
|
|
||||||
timesteps=t,
|
|
||||||
context=c.get("crossattn", None),
|
|
||||||
y=c.get("vector", None),
|
|
||||||
cond_view=c.get("cond_view", None),
|
|
||||||
cond_motion=c.get("cond_motion", None),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.diffusion_model(
|
|
||||||
x,
|
|
||||||
timesteps=t,
|
|
||||||
context=c.get("crossattn", None),
|
|
||||||
y=c.get("vector", None),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|||||||
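In the removed `dual_concat` branch of `OpenAIWrapper.forward` above, the stacked input is split in half, the shared `concat` conditioning is appended to each half, and the halves are re-stacked, so each sub-network receives its own latent plus the conditioning channels. A sketch of that channel bookkeeping with hypothetical channel counts:

```python
import torch

# Channel bookkeeping of the dual_concat branch (channel counts are hypothetical).
B, C_lat, C_cond, H, W = 1, 4, 4, 72, 72
x = torch.randn(B, 2 * C_lat, H, W)   # two latents stacked on the channel axis
cond = torch.randn(B, C_cond, H, W)   # shared "concat" conditioning

x_1, x_2 = x[:, :C_lat], x[:, C_lat:]
x_1 = torch.cat((x_1, cond), dim=1)   # (B, C_lat + C_cond, H, W)
x_2 = torch.cat((x_2, cond), dim=1)
x = torch.cat((x_1, x_2), dim=1)      # (B, 2 * (C_lat + C_cond), H, W)
print(x.shape)  # torch.Size([1, 16, 72, 72])
```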
@@ -69,8 +69,8 @@ class AbstractEmbModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GeneralConditioner(nn.Module):
|
class GeneralConditioner(nn.Module):
|
||||||
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"} # , 5: "concat"}
|
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
|
||||||
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1}
|
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
|
||||||
|
|
||||||
def __init__(self, emb_models: Union[List, ListConfig]):
|
def __init__(self, emb_models: Union[List, ListConfig]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -138,11 +138,7 @@ class GeneralConditioner(nn.Module):
             if not isinstance(emb_out, (list, tuple)):
                 emb_out = [emb_out]
             for emb in emb_out:
-                if embedder.input_key in ["cond_view", "cond_motion"]:
-                    out_key = embedder.input_key
-                else:
-                    out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
-
+                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                 if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                     emb = (
                         expand_dims_like(
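As a rough sketch of what the removed branch changes: the conditioner normally picks the output key from the tensor rank via `OUTPUT_DIM2KEYS`, while the removed code routes `cond_view` / `cond_motion` embedder outputs under their own keys so they can be handed to the UNet separately rather than folded into `concat`. A simplified, hypothetical `route_embedding` helper (not part of the repo) illustrating that dispatch:

```python
import torch

OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"}

def route_embedding(input_key: str, emb: torch.Tensor) -> str:
    # cond_view / cond_motion keep their own key so they become dedicated
    # kwargs for the UNet instead of being concatenated onto the input.
    if input_key in ("cond_view", "cond_motion"):
        return input_key
    # Otherwise the key is chosen purely by tensor rank, as in OUTPUT_DIM2KEYS.
    return OUTPUT_DIM2KEYS[emb.dim()]

print(route_embedding("jpg", torch.zeros(1, 4, 64, 64)))        # concat
print(route_embedding("cond_view", torch.zeros(1, 4, 64, 64)))  # cond_view
```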
@@ -998,10 +994,7 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
         sigmas = self.sigma_sampler(b).to(vid.device)
         if self.sigma_cond is not None:
             sigma_cond = self.sigma_cond(sigmas)
-            if self.n_cond_frames == 1:
-                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
-            else:
-                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_cond_frames) # For SV4D
+            sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
         sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
         noise = torch.randn_like(vid)
         vid = vid + noise * append_dims(sigmas, vid.ndim)
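The difference between the two branches is only in how the per-sample noise-level embedding is broadcast: once per generated copy (`n_copies`) versus once per conditioning frame (`n_cond_frames`, the SV4D-style multi-frame case flagged in the removed comment). A toy einops sketch of that shape change, with made-up sizes:

```python
import torch
from einops import repeat

b, d = 2, 256                     # batch size, embedding width (toy values)
n_copies, n_cond_frames = 12, 5   # toy values, not the real configs

sigma_cond = torch.randn(b, d)

# Single conditioning frame: tile the embedding across the generated copies.
per_copy = repeat(sigma_cond, "b d -> (b t) d", t=n_copies)
# Multi-frame conditioning (SV4D-style): tile it across conditioning frames.
per_frame = repeat(sigma_cond, "b d -> (b t) d", t=n_cond_frames)

print(per_copy.shape, per_frame.shape)  # torch.Size([24, 256]) torch.Size([10, 256])
```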
@@ -1024,9 +1017,8 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
         vid = torch.cat(all_out, dim=0)
         vid *= self.scale_factor

-        if self.n_cond_frames == 1:
-            vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
-            vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
+        vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
+        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

         return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid

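The reshuffle that the removed guard fences off packs all encoded conditioning frames into the channel axis and then tiles the result once per frame to be generated. A small standalone sketch of the same einops pattern, using toy sizes:

```python
import torch
from einops import rearrange, repeat

b, t_cond, n_copies = 2, 1, 12   # toy sizes; t_cond = number of conditioning frames
c, h, w = 4, 72, 72

vid = torch.randn(b * t_cond, c, h, w)  # latents of the conditioning frames
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=t_cond)  # stack frames on channels
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=n_copies)        # one copy per output frame

print(vid.shape)  # torch.Size([24, 4, 72, 72])
```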
@@ -1,625 +0,0 @@
from functools import partial

import torch
import torch.nn.functional as F

from ..modules.attention import *
from ..modules.diffusionmodules.util import (
    AlphaBlender,
    get_alpha,
    linear,
    mixed_checkpoint,
    timestep_embedding,
)


class TimeMixSequential(nn.Sequential):
    def forward(self, x, context=None, timesteps=None):
        for layer in self:
            x = layer(x, context, timesteps)

        return x

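TimeMixSequential is simply an nn.Sequential whose forward also threads `context` and `timesteps` into every layer, which a plain nn.Sequential cannot do. A minimal usage sketch with a hypothetical dummy layer, just to show the calling convention:

```python
import torch
import torch.nn as nn

class AddTimesteps(nn.Module):
    """Dummy layer that only checks it receives the extra arguments."""
    def forward(self, x, context=None, timesteps=None):
        assert timesteps is not None
        return x + 1

class TimeMixSequential(nn.Sequential):
    def forward(self, x, context=None, timesteps=None):
        for layer in self:
            x = layer(x, context, timesteps)
        return x

seq = TimeMixSequential(AddTimesteps(), AddTimesteps())
out = seq(torch.zeros(2, 3), context=None, timesteps=5)
print(out)  # tensor of twos: each layer saw (x, context, timesteps)
```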
class BasicTransformerTimeMixBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        timesteps=None,
        ff_in=False,
        inner_dim=None,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        switch_temporal_ca_to_sa=False,
    ):
        super().__init__()

        attn_cls = self.ATTENTION_MODES[attn_mode]

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        assert int(n_heads * d_head) == inner_dim

        self.is_res = inner_dim == dim

        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim)
            self.ff_in = FeedForward(
                dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
            )

        self.timesteps = timesteps
        self.disable_self_attn = disable_self_attn
        if self.disable_self_attn:
            self.attn1 = attn_cls(
                query_dim=inner_dim,
                heads=n_heads,
                dim_head=d_head,
                context_dim=context_dim,
                dropout=dropout,
            )  # is a cross-attention
        else:
            self.attn1 = attn_cls(
                query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
            )  # is a self-attention

        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None
        else:
            self.norm2 = nn.LayerNorm(inner_dim)
            if switch_temporal_ca_to_sa:
                self.attn2 = attn_cls(
                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
                )  # is a self-attention
            else:
                self.attn2 = attn_cls(
                    query_dim=inner_dim,
                    context_dim=context_dim,
                    heads=n_heads,
                    dim_head=d_head,
                    dropout=dropout,
                )  # is self-attn if context is none

        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

        self.checkpoint = checkpoint
        if self.checkpoint:
            logpy.info(f"{self.__class__.__name__} is using checkpointing")

    def forward(
        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
    ) -> torch.Tensor:
        if self.checkpoint:
            return checkpoint(self._forward, x, context, timesteps)
        else:
            return self._forward(x, context, timesteps=timesteps)

    def _forward(self, x, context=None, timesteps=None):
        assert self.timesteps or timesteps
        assert not (self.timesteps and timesteps) or self.timesteps == timesteps
        timesteps = self.timesteps or timesteps
        B, S, C = x.shape
        x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        if self.disable_self_attn:
            x = self.attn1(self.norm1(x), context=context) + x
        else:
            x = self.attn1(self.norm1(x)) + x

        if self.attn2 is not None:
            if self.switch_temporal_ca_to_sa:
                x = self.attn2(self.norm2(x)) + x
            else:
                x = self.attn2(self.norm2(x), context=context) + x
        x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        x = rearrange(
            x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
        )
        return x

    def get_last_layer(self):
        return self.ff.net[-1].weight

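The core trick in BasicTransformerTimeMixBlock's `_forward` is the axis swap around the attention calls: spatial tokens are folded into the batch so that attention mixes information across frames only. A toy sketch of that reshape (attention itself replaced by a comment for brevity):

```python
import torch
from einops import rearrange

b, t, s, c = 2, 4, 9, 8        # batch, frames, spatial tokens, channels (toy sizes)
x = torch.randn(b * t, s, c)   # how the surrounding transformer hands tokens to the block

x_time = rearrange(x, "(b t) s c -> (b s) t c", t=t)
# ... temporal self-attention would act on the t axis here, per spatial token ...
x_back = rearrange(x_time, "(b s) t c -> (b t) s c", b=b, s=s, t=t)

assert torch.equal(x, x_back)  # the round trip is lossless
print(x_time.shape)            # torch.Size([18, 4, 8])
```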
class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        apply_sigmoid_to_merge: bool = True,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        time_mix_legacy: bool = True,
        max_time_embed_period: int = 10000,
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        self.time_mix_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_mix_blocks) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_mix_time_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )

        self.time_mix_legacy = time_mix_legacy
        if self.time_mix_legacy:
            if merge_strategy == "fixed":
                self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
            elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
                self.register_parameter(
                    "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
                )
            elif merge_strategy == "fixed_with_images":
                self.mix_factor = None
            else:
                raise ValueError(f"unknown merge strategy {merge_strategy}")

            self.get_alpha_fn = partial(
                get_alpha,
                merge_strategy,
                self.mix_factor,
                apply_sigmoid=apply_sigmoid_to_merge,
                is_attn=True,
            )
        else:
            self.time_mixer = AlphaBlender(
                alpha=merge_factor, merge_strategy=merge_strategy
            )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        # cam: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        cond_view: Optional[torch.Tensor] = None,
        cond_motion: Optional[torch.Tensor] = None,
        time_step: Optional[int] = None,
        name: Optional[str] = None,
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            time_context = context
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        if self.time_mix_legacy:
            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)

        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        )
        emb = self.time_mix_time_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_mix_blocks)
        ):
            # spatial attention
            x = block(
                x,
                context=spatial_context,
                time_step=time_step,
                name=name + '_' + str(it_)
            )

            x_mix = x
            x_mix = x_mix + emb

            # temporal attention
            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                x = self.time_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=image_only_indicator,
                )

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out

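After each temporal block the spatial and temporal streams are blended with a scalar mixing factor; in the legacy path this is literally `alpha * x + (1 - alpha) * x_mix`, with `alpha` typically the sigmoid of a learned `mix_factor`. A self-contained sketch of that blend (the real `get_alpha` / `AlphaBlender` also handle image-only batches, which is omitted here, and `ScalarAlphaBlend` is a made-up name):

```python
import torch
import torch.nn as nn

class ScalarAlphaBlend(nn.Module):
    """Simplified stand-in for the learned merge: one scalar, squashed by sigmoid."""
    def __init__(self, merge_factor: float = 0.5):
        super().__init__()
        self.mix_factor = nn.Parameter(torch.tensor([merge_factor]))

    def forward(self, x_spatial: torch.Tensor, x_temporal: torch.Tensor) -> torch.Tensor:
        alpha = torch.sigmoid(self.mix_factor)
        return alpha * x_spatial + (1.0 - alpha) * x_temporal

blend = ScalarAlphaBlend()
x_spatial = torch.randn(8, 9, 320)   # output of the spatial transformer block (toy shape)
x_temporal = torch.randn(8, 9, 320)  # output of the time-mix block
print(blend(x_spatial, x_temporal).shape)  # torch.Size([8, 9, 320])
```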
class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        use_camera_emb=False,
        use_3d_attention=False,
        separate_motion_merge_factor=False,
        adm_in_channels=None,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        merge_factor_motion: float = 0.5,
        apply_sigmoid_to_merge: bool = True,
        time_context_dim=None,
        motion_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        time_mix_legacy: bool = True,
        max_time_embed_period: int = 10000,
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period
        self.use_camera_emb = use_camera_emb
        self.motion_context_dim = motion_context_dim
        self.use_3d_attention = use_3d_attention
        self.separate_motion_merge_factor = separate_motion_merge_factor

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        # Camera attention layer
        self.time_mix_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        # Motion attention layer
        self.motion_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=motion_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_mix_blocks) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        time_embed_channels = adm_in_channels if self.use_camera_emb else self.in_channels
        # Camera view embedding
        self.time_mix_time_embed = nn.Sequential(
            linear(time_embed_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )
        # Motion time embedding
        self.time_mix_motion_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )

        self.time_mix_legacy = time_mix_legacy
        if self.time_mix_legacy:
            if merge_strategy == "fixed":
                self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
            elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
                self.register_parameter(
                    "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
                )
            elif merge_strategy == "fixed_with_images":
                self.mix_factor = None
            else:
                raise ValueError(f"unknown merge strategy {merge_strategy}")

            self.get_alpha_fn = partial(
                get_alpha,
                merge_strategy,
                self.mix_factor,
                apply_sigmoid=apply_sigmoid_to_merge,
                is_attn=True,
            )
        else:
            self.time_mixer = AlphaBlender(
                alpha=merge_factor, merge_strategy=merge_strategy
            )
            if self.separate_motion_merge_factor:
                self.time_mixer_motion = AlphaBlender(
                    alpha=merge_factor_motion, merge_strategy=merge_strategy
                )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        cam: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        cond_view: Optional[torch.Tensor] = None,
        cond_motion: Optional[torch.Tensor] = None,
        time_step: Optional[int] = None,
        name: Optional[str] = None,
    ) -> torch.Tensor:
        # context: b t 1024
        # cond_view: b*v 4 h w
        # cond_motion: b*t 4 h w
        # image_only_indicator: b t*v
        b, t, d1 = context.shape # CLIP
        v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE
        _, c, h, w = x.shape

        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w
        spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1
        camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1
        motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        if self.time_mix_legacy:
            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)

        num_frames = torch.arange(t, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=b)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        )
        emb_time = self.time_mix_motion_embed(t_emb)
        emb_time = emb_time[:, None, :] # b*t 1 c

        if self.use_camera_emb:
            emb_view = self.time_mix_time_embed(cam.view(b,t,v,-1)[:,0].reshape(b*v,-1))
            emb_view = emb_view[:, None, :]
        else:
            num_views = torch.arange(v, device=x.device)
            num_views = repeat(num_views, "t -> b t", b=b)
            num_views = rearrange(num_views, "b t -> (b t)")
            v_emb = timestep_embedding(
                num_views,
                self.in_channels,
                repeat_only=False,
                max_period=self.max_time_embed_period,
            )
            emb_view = self.time_mix_time_embed(v_emb)
            emb_view = emb_view[:, None, :] # b*v 1 c

        if self.use_3d_attention:
            emb_view = emb_view.repeat(1, h*w, 1).view(-1,1,c) # b*v*h*w 1 c

        for it_, (block, time_block, mot_block) in enumerate(
            zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)
        ):
            # Spatial attention
            x = block(
                x,
                context=spatial_context,
            )

            # Camera attention
            if self.use_3d_attention:
                x = x.view(b, t, v, h*w, c).permute(0,2,3,1,4).reshape(-1,t,c) # b*v*h*w t c
            else:
                x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c
            x_mix = x + emb_view
            x_mix = time_block(x_mix, context=camera_context, timesteps=v)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                x = self.time_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),
                )

            # Motion attention
            if self.use_3d_attention:
                x = x.view(b, v, h*w, t, c).permute(0,3,1,2,4).reshape(b*t,-1,c) # b*t v*h*w c
            else:
                x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c
            x_mix = x + emb_time
            x_mix = mot_block(x_mix, context=motion_context, timesteps=t)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                motion_mixer = self.time_mixer_motion if self.separate_motion_merge_factor else self.time_mixer
                x = motion_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),
                )

            x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
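To see what the deleted motion transformer is doing with all the `view` / `permute` calls: the tokens conceptually carry a (batch, time, view, pixels, channels) layout, and each attention pass flattens a different pair of axes into the batch so that attention mixes only across views (camera attention) or only across frames (motion attention). A toy einops rendering of those two groupings (the real code reaches the same layouts in two steps, via the reshapes above plus the time-mix block's internal rearrange); all sizes are made up:

```python
import torch
from einops import rearrange

b, t, v, hw, c = 1, 5, 8, 36, 64   # toy sizes: frames x views x spatial tokens
x = torch.randn(b, t, v, hw, c)

# Camera attention: every (sample, frame, pixel) attends across the v views.
x_cam = rearrange(x, "b t v s c -> (b t s) v c")
# Motion attention: every (sample, view, pixel) attends across the t frames.
x_mot = rearrange(x, "b t v s c -> (b v s) t c")

print(x_cam.shape)  # torch.Size([180, 8, 64])
print(x_mot.shape)  # torch.Size([288, 5, 64])
```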