mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 22:34:22 +01:00
Adds SV4D code
This commit is contained in:
24
README.md
24
README.md
@@ -4,6 +4,30 @@
|
|||||||
|
|
||||||
## News
|
## News
|
||||||
|
|
||||||
|
|
||||||
|
**July 24, 2024**
|
||||||
|
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
|
||||||
|
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
|
||||||
|
- To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
|
||||||
|
- Please check our [project page](), [tech report]() and [video summary]() for more details.
|
||||||
|
|
||||||
|
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [SV4D](https://huggingface.co/stabilityai/sv4d) and [SV3D_u]((https://huggingface.co/stabilityai/sv3d)) from HuggingFace)
|
||||||
|
|
||||||
|
To run **SV4D** on a single input video of 21 frames:
|
||||||
|
- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
|
||||||
|
- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
|
||||||
|
- `input_path` : The input video `<path/to/video>` can be
|
||||||
|
- a single video file in `gif` or `mp4` format, such as `assets/test_video1.mp4`, or
|
||||||
|
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
|
||||||
|
- a file name pattern matching images of video frames.
|
||||||
|
- `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
|
||||||
|
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
|
||||||
|
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
|
||||||
|
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos (with noisy background), try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
**March 18, 2024**
|
**March 18, 2024**
|
||||||
- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
|
- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
|
||||||
- **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
|
- **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
|
||||||
|
|||||||
BIN
assets/hiphop_parrot.mp4
Normal file
BIN
assets/hiphop_parrot.mp4
Normal file
Binary file not shown.
BIN
assets/sv4d.gif
Normal file
BIN
assets/sv4d.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.0 MiB |
BIN
assets/test_video1.mp4
Normal file
BIN
assets/test_video1.mp4
Normal file
Binary file not shown.
BIN
assets/test_video2.mp4
Normal file
BIN
assets/test_video2.mp4
Normal file
Binary file not shown.
1207
scripts/demo/sv4d_helpers.py
Normal file
1207
scripts/demo/sv4d_helpers.py
Normal file
File diff suppressed because it is too large
Load Diff
208
scripts/sampling/configs/sv4d.yaml
Normal file
208
scripts/sampling/configs/sv4d.yaml
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
N_TIME: 5
|
||||||
|
N_VIEW: 8
|
||||||
|
N_FRAMES: 40
|
||||||
|
|
||||||
|
model:
|
||||||
|
target: sgm.models.diffusion.DiffusionEngine
|
||||||
|
params:
|
||||||
|
scale_factor: 0.18215
|
||||||
|
en_and_decode_n_samples_a_time: 7
|
||||||
|
disable_first_stage_autocast: True
|
||||||
|
ckpt_path: checkpoints/sv4d.safetensors
|
||||||
|
|
||||||
|
denoiser_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||||
|
params:
|
||||||
|
scaling_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||||
|
|
||||||
|
network_config:
|
||||||
|
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
|
||||||
|
params:
|
||||||
|
adm_in_channels: 1280
|
||||||
|
attention_resolutions: [4, 2, 1]
|
||||||
|
channel_mult: [1, 2, 4, 4]
|
||||||
|
context_dim: 1024
|
||||||
|
extra_ff_mix_layer: True
|
||||||
|
in_channels: 8
|
||||||
|
legacy: False
|
||||||
|
model_channels: 320
|
||||||
|
num_classes: sequential
|
||||||
|
num_head_channels: 64
|
||||||
|
num_res_blocks: 2
|
||||||
|
out_channels: 4
|
||||||
|
replicate_time_mix_bug: True
|
||||||
|
spatial_transformer_attn_type: softmax-xformers
|
||||||
|
time_block_merge_factor: 0.0
|
||||||
|
time_block_merge_strategy: learned_with_images
|
||||||
|
time_kernel_size: [3, 1, 1]
|
||||||
|
time_mix_legacy: False
|
||||||
|
transformer_depth: 1
|
||||||
|
use_checkpoint: False
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
use_spatial_context: True
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_motion_attention: True
|
||||||
|
|
||||||
|
conditioner_config:
|
||||||
|
target: sgm.modules.GeneralConditioner
|
||||||
|
params:
|
||||||
|
emb_models:
|
||||||
|
|
||||||
|
- input_key: cond_frames_without_noise
|
||||||
|
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||||
|
is_trainable: False
|
||||||
|
params:
|
||||||
|
n_cond_frames: ${N_TIME}
|
||||||
|
n_copies: 1
|
||||||
|
open_clip_embedding_config:
|
||||||
|
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
|
||||||
|
- input_key: cond_frames
|
||||||
|
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||||
|
is_trainable: False
|
||||||
|
params:
|
||||||
|
is_ae: True
|
||||||
|
n_cond_frames: ${N_FRAMES}
|
||||||
|
n_copies: 1
|
||||||
|
encoder_config:
|
||||||
|
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||||
|
params:
|
||||||
|
ddconfig:
|
||||||
|
attn_resolutions: []
|
||||||
|
attn_type: vanilla-xformers
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [1, 2, 4, 4]
|
||||||
|
double_z: True
|
||||||
|
dropout: 0.0
|
||||||
|
in_channels: 3
|
||||||
|
num_res_blocks: 2
|
||||||
|
out_ch: 3
|
||||||
|
resolution: 256
|
||||||
|
z_channels: 4
|
||||||
|
embed_dim: 4
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
monitor: val/rec_loss
|
||||||
|
sigma_cond_config:
|
||||||
|
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||||
|
params:
|
||||||
|
outdim: 256
|
||||||
|
sigma_sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||||
|
|
||||||
|
# - input_key: cond_aug
|
||||||
|
# is_trainable: False
|
||||||
|
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||||
|
# params:
|
||||||
|
# outdim: 256
|
||||||
|
|
||||||
|
- input_key: polar_rad
|
||||||
|
is_trainable: False
|
||||||
|
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||||
|
params:
|
||||||
|
outdim: 512
|
||||||
|
|
||||||
|
- input_key: azimuth_rad
|
||||||
|
is_trainable: False
|
||||||
|
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||||
|
params:
|
||||||
|
outdim: 512
|
||||||
|
|
||||||
|
- input_key: cond_view
|
||||||
|
is_trainable: False
|
||||||
|
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||||
|
params:
|
||||||
|
encoder_config:
|
||||||
|
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
attn_resolutions: []
|
||||||
|
attn_type: vanilla-xformers
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [1, 2, 4, 4]
|
||||||
|
double_z: True
|
||||||
|
dropout: 0.0
|
||||||
|
in_channels: 3
|
||||||
|
num_res_blocks: 2
|
||||||
|
out_ch: 3
|
||||||
|
resolution: 256
|
||||||
|
z_channels: 4
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
is_ae: True
|
||||||
|
n_cond_frames: ${N_VIEW}
|
||||||
|
n_copies: 1
|
||||||
|
sigma_sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||||
|
|
||||||
|
- input_key: cond_motion
|
||||||
|
is_trainable: False
|
||||||
|
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||||
|
params:
|
||||||
|
is_ae: True
|
||||||
|
n_cond_frames: ${N_TIME}
|
||||||
|
n_copies: 1
|
||||||
|
encoder_config:
|
||||||
|
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
attn_resolutions: []
|
||||||
|
attn_type: vanilla-xformers
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [1, 2, 4, 4]
|
||||||
|
double_z: True
|
||||||
|
dropout: 0.0
|
||||||
|
in_channels: 3
|
||||||
|
num_res_blocks: 2
|
||||||
|
out_ch: 3
|
||||||
|
resolution: 256
|
||||||
|
z_channels: 4
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
sigma_sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: sgm.models.autoencoder.AutoencodingEngine
|
||||||
|
params:
|
||||||
|
loss_config:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
regularizer_config:
|
||||||
|
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||||
|
encoder_config:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
decoder_config:
|
||||||
|
target: sgm.modules.diffusionmodules.model.Decoder
|
||||||
|
params:
|
||||||
|
attn_resolutions: []
|
||||||
|
attn_type: vanilla-xformers
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [1, 2, 4, 4]
|
||||||
|
double_z: True
|
||||||
|
dropout: 0.0
|
||||||
|
in_channels: 3
|
||||||
|
num_res_blocks: 2
|
||||||
|
out_ch: 3
|
||||||
|
resolution: 256
|
||||||
|
z_channels: 4
|
||||||
|
|
||||||
|
sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||||
|
params:
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||||
|
params:
|
||||||
|
sigma_max: 500.0
|
||||||
|
guider_config:
|
||||||
|
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
||||||
|
params:
|
||||||
|
max_scale: 2.5
|
||||||
|
num_frames: ${N_FRAMES}
|
||||||
|
additional_cond_keys: [ cond_view, cond_motion ]
|
||||||
236
scripts/sampling/simple_video_sample_4d.py
Normal file
236
scripts/sampling/simple_video_sample_4d.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from glob import glob
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from fire import Fire
|
||||||
|
|
||||||
|
from scripts.demo.sv4d_helpers import (
|
||||||
|
decode_latents,
|
||||||
|
load_model,
|
||||||
|
read_video,
|
||||||
|
run_img2vid,
|
||||||
|
run_img2vid_per_step,
|
||||||
|
sample_sv3d,
|
||||||
|
save_video,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
input_path: str = "assets/test_video.mp4", # Can either be image file or folder with image files
|
||||||
|
output_folder: Optional[str] = "outputs/sv4d",
|
||||||
|
num_steps: Optional[int] = 20,
|
||||||
|
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
|
||||||
|
fps_id: int = 6,
|
||||||
|
motion_bucket_id: int = 127,
|
||||||
|
cond_aug: float = 1e-5,
|
||||||
|
seed: int = 23,
|
||||||
|
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||||
|
device: str = "cuda",
|
||||||
|
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
|
||||||
|
azimuths_deg: Optional[List[float]] = None,
|
||||||
|
image_frame_ratio: Optional[float] = None,
|
||||||
|
verbose: Optional[bool] = False,
|
||||||
|
remove_bg: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||||
|
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
|
||||||
|
"""
|
||||||
|
# Set model config
|
||||||
|
T = 5 # number of frames per sample
|
||||||
|
V = 8 # number of views per sample
|
||||||
|
F = 8 # vae factor to downsize image->latent
|
||||||
|
C = 4
|
||||||
|
H, W = 576, 576
|
||||||
|
n_frames = 21 # number of input and output video frames
|
||||||
|
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||||
|
n_views_sv3d = 21
|
||||||
|
subsampled_views = np.array(
|
||||||
|
[0, 2, 5, 7, 9, 12, 14, 16, 19]
|
||||||
|
) # subsample (V+1=)9 (uniform) views from 21 SV3D views
|
||||||
|
|
||||||
|
model_config = "scripts/sampling/configs/sv4d.yaml"
|
||||||
|
version_dict = {
|
||||||
|
"T": T * V,
|
||||||
|
"H": H,
|
||||||
|
"W": W,
|
||||||
|
"C": C,
|
||||||
|
"f": F,
|
||||||
|
"options": {
|
||||||
|
"discretization": 1,
|
||||||
|
"cfg": 2.5,
|
||||||
|
"sigma_min": 0.002,
|
||||||
|
"sigma_max": 700.0,
|
||||||
|
"rho": 7.0,
|
||||||
|
"guider": 5,
|
||||||
|
"num_steps": num_steps,
|
||||||
|
"force_uc_zero_embeddings": [
|
||||||
|
"cond_frames",
|
||||||
|
"cond_frames_without_noise",
|
||||||
|
"cond_view",
|
||||||
|
"cond_motion",
|
||||||
|
],
|
||||||
|
"additional_guider_kwargs": {
|
||||||
|
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
|
||||||
|
# Read input video frames i.e. images at view 0
|
||||||
|
print(f"Reading {input_path}")
|
||||||
|
images_v0 = read_video(
|
||||||
|
input_path,
|
||||||
|
n_frames=n_frames,
|
||||||
|
W=W,
|
||||||
|
H=H,
|
||||||
|
remove_bg=remove_bg,
|
||||||
|
image_frame_ratio=image_frame_ratio,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get camera viewpoints
|
||||||
|
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||||
|
elevations_deg = [elevations_deg] * n_views_sv3d
|
||||||
|
assert (
|
||||||
|
len(elevations_deg) == n_views_sv3d
|
||||||
|
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
|
||||||
|
if azimuths_deg is None:
|
||||||
|
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
|
||||||
|
assert (
|
||||||
|
len(azimuths_deg) == n_views_sv3d
|
||||||
|
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||||
|
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||||
|
azimuths_rad = np.array(
|
||||||
|
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample multi-view images of the first frame using SV3D i.e. images at time 0
|
||||||
|
images_t0 = sample_sv3d(
|
||||||
|
images_v0[0],
|
||||||
|
n_views_sv3d,
|
||||||
|
num_steps,
|
||||||
|
sv3d_version,
|
||||||
|
fps_id,
|
||||||
|
motion_bucket_id,
|
||||||
|
cond_aug,
|
||||||
|
decoding_t,
|
||||||
|
device,
|
||||||
|
polars_rad,
|
||||||
|
azimuths_rad,
|
||||||
|
verbose,
|
||||||
|
)
|
||||||
|
images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame
|
||||||
|
|
||||||
|
# Initialize image matrix
|
||||||
|
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||||
|
for i, v in enumerate(subsampled_views):
|
||||||
|
img_matrix[0][i] = images_t0[v].unsqueeze(0)
|
||||||
|
for t in range(n_frames):
|
||||||
|
img_matrix[t][0] = images_v0[t]
|
||||||
|
|
||||||
|
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
|
||||||
|
save_video(
|
||||||
|
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
|
||||||
|
img_matrix[0],
|
||||||
|
)
|
||||||
|
save_video(
|
||||||
|
os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
|
||||||
|
[img_matrix[t][0] for t in range(n_frames)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load SV4D model
|
||||||
|
model, filter = load_model(
|
||||||
|
model_config,
|
||||||
|
device,
|
||||||
|
version_dict["T"],
|
||||||
|
num_steps,
|
||||||
|
verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Interleaved sampling for anchor frames
|
||||||
|
t0, v0 = 0, 0
|
||||||
|
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
|
||||||
|
view_indices = np.arange(V) + 1
|
||||||
|
print(f"Sampling anchor frames {frame_indices}")
|
||||||
|
image = img_matrix[t0][v0]
|
||||||
|
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||||
|
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||||
|
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||||
|
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||||
|
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||||
|
samples = run_img2vid(
|
||||||
|
version_dict, model, image, seed, polars, azims, cond_motion, cond_view
|
||||||
|
)
|
||||||
|
samples = samples.view(T, V, 3, H, W)
|
||||||
|
for i, t in enumerate(frame_indices):
|
||||||
|
for j, v in enumerate(view_indices):
|
||||||
|
if img_matrix[t][v] is None:
|
||||||
|
img_matrix[t][v] = samples[i, j][None] * 2 - 1
|
||||||
|
|
||||||
|
# Dense sampling for the rest
|
||||||
|
print(f"Sampling dense frames:")
|
||||||
|
for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)): # [0, 4, 8, 12, 16]
|
||||||
|
frame_indices = t0 + np.arange(T)
|
||||||
|
print(f"Sampling dense frames {frame_indices}")
|
||||||
|
latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
|
||||||
|
for step in tqdm(range(num_steps)):
|
||||||
|
frame_indices = frame_indices[
|
||||||
|
::-1
|
||||||
|
].copy() # alternate between forward and backward conditioning
|
||||||
|
t0 = frame_indices[0]
|
||||||
|
image = img_matrix[t0][v0]
|
||||||
|
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||||
|
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||||
|
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||||
|
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||||
|
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||||
|
noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
|
||||||
|
samples = run_img2vid_per_step(
|
||||||
|
version_dict,
|
||||||
|
model,
|
||||||
|
image,
|
||||||
|
seed,
|
||||||
|
polars,
|
||||||
|
azims,
|
||||||
|
cond_motion,
|
||||||
|
cond_view,
|
||||||
|
step,
|
||||||
|
noisy_latents,
|
||||||
|
)
|
||||||
|
samples = samples.view(T, V, C, H // F, W // F)
|
||||||
|
for i, t in enumerate(frame_indices):
|
||||||
|
for j, v in enumerate(view_indices):
|
||||||
|
latent_matrix[t, v] = samples[i, j]
|
||||||
|
|
||||||
|
for t in frame_indices:
|
||||||
|
for v in view_indices:
|
||||||
|
if t != 0 and v != 0:
|
||||||
|
img = decode_latents(model, latent_matrix[t, v][None], T)
|
||||||
|
img_matrix[t][v] = img * 2 - 1
|
||||||
|
|
||||||
|
# Save output videos
|
||||||
|
for v in view_indices:
|
||||||
|
vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
|
||||||
|
print(f"Saving {vid_file}")
|
||||||
|
save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)])
|
||||||
|
|
||||||
|
# Save diagonal video
|
||||||
|
diag_frames = [
|
||||||
|
img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames)
|
||||||
|
]
|
||||||
|
vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4")
|
||||||
|
print(f"Saving {vid_file}")
|
||||||
|
save_video(vid_file, diag_frames)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Fire(sample)
|
||||||
@@ -94,7 +94,7 @@ class LinearPredictionGuider(Guider):
|
|||||||
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
||||||
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
||||||
else:
|
else:
|
||||||
assert c[k] == uc[k]
|
# assert c[k] == uc[k]
|
||||||
c_out[k] = c[k]
|
c_out[k] = c[k]
|
||||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||||
|
|
||||||
@@ -105,7 +105,7 @@ class TrianglePredictionGuider(LinearPredictionGuider):
|
|||||||
max_scale: float,
|
max_scale: float,
|
||||||
num_frames: int,
|
num_frames: int,
|
||||||
min_scale: float = 1.0,
|
min_scale: float = 1.0,
|
||||||
period: float | List[float] = 1.0,
|
period: Union[float, List[float]] = 1.0,
|
||||||
period_fusing: Literal["mean", "multiply", "max"] = "max",
|
period_fusing: Literal["mean", "multiply", "max"] = "max",
|
||||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||||
):
|
):
|
||||||
@@ -129,3 +129,47 @@ class TrianglePredictionGuider(LinearPredictionGuider):
|
|||||||
|
|
||||||
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
|
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
|
||||||
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||||
|
|
||||||
|
|
||||||
|
class TrapezoidPredictionGuider(LinearPredictionGuider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_scale: float,
|
||||||
|
num_frames: int,
|
||||||
|
min_scale: float = 1.0,
|
||||||
|
edge_perc: float = 0.1,
|
||||||
|
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
|
||||||
|
|
||||||
|
rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc))
|
||||||
|
fall_steps = torch.flip(rise_steps, [0])
|
||||||
|
self.scale = torch.cat(
|
||||||
|
[
|
||||||
|
rise_steps,
|
||||||
|
torch.ones(num_frames - 2 * int(num_frames * edge_perc)),
|
||||||
|
fall_steps,
|
||||||
|
]
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class SpatiotemporalPredictionGuider(LinearPredictionGuider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_scale: float,
|
||||||
|
num_frames: int,
|
||||||
|
num_views: int = 1,
|
||||||
|
min_scale: float = 1.0,
|
||||||
|
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
|
||||||
|
V = num_views
|
||||||
|
T = num_frames // V
|
||||||
|
scale = torch.zeros(num_frames).view(T, V)
|
||||||
|
scale += torch.linspace(0, 1, T)[:,None] * 0.5
|
||||||
|
scale += self.triangle_wave(torch.linspace(0, 1, V))[None,:] * 0.5
|
||||||
|
scale = scale.flatten()
|
||||||
|
self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
|
||||||
|
|
||||||
|
def triangle_wave(self, values: torch.Tensor, period=1) -> torch.Tensor:
|
||||||
|
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||||
@@ -75,20 +75,43 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||||||
emb: th.Tensor,
|
emb: th.Tensor,
|
||||||
context: Optional[th.Tensor] = None,
|
context: Optional[th.Tensor] = None,
|
||||||
image_only_indicator: Optional[th.Tensor] = None,
|
image_only_indicator: Optional[th.Tensor] = None,
|
||||||
|
cond_view: Optional[th.Tensor] = None,
|
||||||
|
cond_motion: Optional[th.Tensor] = None,
|
||||||
time_context: Optional[int] = None,
|
time_context: Optional[int] = None,
|
||||||
num_video_frames: Optional[int] = None,
|
num_video_frames: Optional[int] = None,
|
||||||
|
time_step: Optional[int] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
from ...modules.diffusionmodules.video_model import VideoResBlock
|
from ...modules.diffusionmodules.video_model import VideoResBlock, PostHocResBlockWithTime
|
||||||
|
from ...modules.spacetime_attention import (
|
||||||
|
BasicTransformerTimeMixBlock,
|
||||||
|
PostHocSpatialTransformerWithTimeMixing,
|
||||||
|
PostHocSpatialTransformerWithTimeMixingAndMotion
|
||||||
|
)
|
||||||
|
|
||||||
for layer in self:
|
for layer in self:
|
||||||
module = layer
|
module = layer
|
||||||
|
|
||||||
if isinstance(module, TimestepBlock) and not isinstance(
|
if isinstance(
|
||||||
module, VideoResBlock
|
module,
|
||||||
|
(
|
||||||
|
BasicTransformerTimeMixBlock,
|
||||||
|
PostHocSpatialTransformerWithTimeMixing,
|
||||||
|
PostHocSpatialTransformerWithTimeMixingAndMotion
|
||||||
|
),
|
||||||
):
|
):
|
||||||
x = layer(x, emb)
|
x = layer(
|
||||||
elif isinstance(module, VideoResBlock):
|
x,
|
||||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
context,
|
||||||
|
# cam,
|
||||||
|
time_context,
|
||||||
|
num_video_frames,
|
||||||
|
image_only_indicator,
|
||||||
|
cond_view,
|
||||||
|
cond_motion,
|
||||||
|
time_step,
|
||||||
|
name,
|
||||||
|
)
|
||||||
elif isinstance(module, SpatialVideoTransformer):
|
elif isinstance(module, SpatialVideoTransformer):
|
||||||
x = layer(
|
x = layer(
|
||||||
x,
|
x,
|
||||||
@@ -96,7 +119,16 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||||||
time_context,
|
time_context,
|
||||||
num_video_frames,
|
num_video_frames,
|
||||||
image_only_indicator,
|
image_only_indicator,
|
||||||
|
# time_step,
|
||||||
)
|
)
|
||||||
|
elif isinstance(module, PostHocResBlockWithTime):
|
||||||
|
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||||
|
elif isinstance(module, VideoResBlock):
|
||||||
|
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||||
|
elif isinstance(module, TimestepBlock) and not isinstance(
|
||||||
|
module, VideoResBlock
|
||||||
|
):
|
||||||
|
x = layer(x, emb)
|
||||||
elif isinstance(module, SpatialTransformer):
|
elif isinstance(module, SpatialTransformer):
|
||||||
x = layer(x, context)
|
x = layer(x, context)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from typing import Optional, Union
|
||||||
from ...util import default, instantiate_from_config
|
from ...util import default, instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
@@ -29,3 +29,10 @@ class DiscreteSampling:
|
|||||||
torch.randint(0, self.num_idx, (n_samples,)),
|
torch.randint(0, self.num_idx, (n_samples,)),
|
||||||
)
|
)
|
||||||
return self.idx_to_sigma(idx)
|
return self.idx_to_sigma(idx)
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroSampler:
|
||||||
|
def __call__(
|
||||||
|
self, n_samples: int, rand: Optional[torch.Tensor] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5
|
||||||
|
|||||||
@@ -17,6 +17,36 @@ import torch.nn as nn
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
|
def get_alpha(
|
||||||
|
merge_strategy: str,
|
||||||
|
mix_factor: Optional[torch.Tensor],
|
||||||
|
image_only_indicator: torch.Tensor,
|
||||||
|
apply_sigmoid: bool = True,
|
||||||
|
is_attn: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if merge_strategy == "fixed" or merge_strategy == "learned":
|
||||||
|
alpha = mix_factor
|
||||||
|
elif merge_strategy == "learned_with_images":
|
||||||
|
alpha = torch.where(
|
||||||
|
image_only_indicator.bool(),
|
||||||
|
torch.ones(1, 1, device=image_only_indicator.device),
|
||||||
|
rearrange(mix_factor, "... -> ... 1"),
|
||||||
|
)
|
||||||
|
if is_attn:
|
||||||
|
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
||||||
|
else:
|
||||||
|
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
||||||
|
elif merge_strategy == "fixed_with_images":
|
||||||
|
alpha = image_only_indicator
|
||||||
|
if is_attn:
|
||||||
|
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
||||||
|
else:
|
||||||
|
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return torch.sigmoid(alpha) if apply_sigmoid else alpha
|
||||||
|
|
||||||
|
|
||||||
def make_beta_schedule(
|
def make_beta_schedule(
|
||||||
schedule,
|
schedule,
|
||||||
n_timestep,
|
n_timestep,
|
||||||
|
|||||||
@@ -5,8 +5,13 @@ from einops import rearrange
|
|||||||
|
|
||||||
from ...modules.diffusionmodules.openaimodel import *
|
from ...modules.diffusionmodules.openaimodel import *
|
||||||
from ...modules.video_attention import SpatialVideoTransformer
|
from ...modules.video_attention import SpatialVideoTransformer
|
||||||
|
from ...modules.spacetime_attention import (
|
||||||
|
BasicTransformerTimeMixBlock,
|
||||||
|
PostHocSpatialTransformerWithTimeMixing,
|
||||||
|
PostHocSpatialTransformerWithTimeMixingAndMotion
|
||||||
|
)
|
||||||
from ...util import default
|
from ...util import default
|
||||||
from .util import AlphaBlender
|
from .util import AlphaBlender # , LegacyAlphaBlenderWithBug, get_alpha
|
||||||
|
|
||||||
|
|
||||||
class VideoResBlock(ResBlock):
|
class VideoResBlock(ResBlock):
|
||||||
@@ -491,3 +496,746 @@ class VideoUNet(nn.Module):
|
|||||||
)
|
)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
return self.out(h)
|
return self.out(h)
|
||||||
|
|
||||||
|
|
||||||
|
class PostHocAttentionBlockWithTimeMixing(AttentionBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
n_heads: int,
|
||||||
|
d_head: int,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_new_attention_order: bool = False,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
use_spatial_context: bool = False,
|
||||||
|
merge_strategy: bool = "fixed",
|
||||||
|
merge_factor: float = 0.5,
|
||||||
|
apply_sigmoid_to_merge: bool = True,
|
||||||
|
ff_in: bool = False,
|
||||||
|
attn_mode: str = "softmax",
|
||||||
|
disable_temporal_crossattention: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_new_attention_order=use_new_attention_order,
|
||||||
|
)
|
||||||
|
inner_dim = n_heads * d_head
|
||||||
|
|
||||||
|
self.time_mix_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerTimeMixBlock(
|
||||||
|
inner_dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
checkpoint=use_checkpoint,
|
||||||
|
ff_in=ff_in,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
time_embed_dim = self.in_channels * 4
|
||||||
|
self.time_mix_time_embed = nn.Sequential(
|
||||||
|
linear(self.in_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, self.in_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_spatial_context = use_spatial_context
|
||||||
|
|
||||||
|
if merge_strategy == "fixed":
|
||||||
|
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
|
||||||
|
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||||
|
self.register_parameter(
|
||||||
|
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
|
||||||
|
)
|
||||||
|
elif merge_strategy == "fixed_with_images":
|
||||||
|
self.mix_factor = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||||
|
|
||||||
|
self.get_alpha_fn = functools.partial(
|
||||||
|
get_alpha,
|
||||||
|
merge_strategy,
|
||||||
|
self.mix_factor,
|
||||||
|
apply_sigmoid=apply_sigmoid_to_merge,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: th.Tensor,
|
||||||
|
context: Optional[th.Tensor] = None,
|
||||||
|
# cam: Optional[th.Tensor] = None,
|
||||||
|
time_context: Optional[th.Tensor] = None,
|
||||||
|
timesteps: Optional[int] = None,
|
||||||
|
image_only_indicator: Optional[th.Tensor] = None,
|
||||||
|
conv_view: Optional[th.Tensor] = None,
|
||||||
|
conv_motion: Optional[th.Tensor] = None,
|
||||||
|
):
|
||||||
|
if time_context is not None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
_, _, h, w = x.shape
|
||||||
|
if exists(context):
|
||||||
|
context = rearrange(context, "b t ... -> (b t) ...")
|
||||||
|
if self.use_spatial_context:
|
||||||
|
time_context = repeat(context[:, 0], "b ... -> (b n) ...", n=h * w)
|
||||||
|
|
||||||
|
x = super().forward(
|
||||||
|
x,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = rearrange(x, "b c h w -> b (h w) c")
|
||||||
|
x_mix = x
|
||||||
|
|
||||||
|
num_frames = th.arange(timesteps, device=x.device)
|
||||||
|
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||||
|
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||||
|
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||||
|
emb = self.time_mix_time_embed(t_emb)
|
||||||
|
emb = emb[:, None, :]
|
||||||
|
x_mix = x_mix + emb
|
||||||
|
|
||||||
|
x_mix = self.time_mix_blocks[0](
|
||||||
|
x_mix, context=time_context, timesteps=timesteps
|
||||||
|
)
|
||||||
|
|
||||||
|
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||||
|
x = alpha * x + (1.0 - alpha) * x_mix
|
||||||
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PostHocResBlockWithTime(ResBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
emb_channels: int,
|
||||||
|
dropout: float,
|
||||||
|
time_kernel_size: Union[int, List[int]] = 3,
|
||||||
|
merge_strategy: bool = "fixed",
|
||||||
|
merge_factor: float = 0.5,
|
||||||
|
apply_sigmoid_to_merge: bool = True,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
use_conv: bool = False,
|
||||||
|
use_scale_shift_norm: bool = False,
|
||||||
|
dims: int = 2,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
up: bool = False,
|
||||||
|
down: bool = False,
|
||||||
|
time_mix_legacy: bool = True,
|
||||||
|
replicate_bug: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
channels,
|
||||||
|
emb_channels,
|
||||||
|
dropout,
|
||||||
|
out_channels=out_channels,
|
||||||
|
use_conv=use_conv,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
up=up,
|
||||||
|
down=down,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.time_mix_blocks = ResBlock(
|
||||||
|
default(out_channels, channels),
|
||||||
|
emb_channels,
|
||||||
|
dropout=dropout,
|
||||||
|
dims=3,
|
||||||
|
out_channels=default(out_channels, channels),
|
||||||
|
use_scale_shift_norm=False,
|
||||||
|
use_conv=False,
|
||||||
|
up=False,
|
||||||
|
down=False,
|
||||||
|
kernel_size=time_kernel_size,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
exchange_temb_dims=True,
|
||||||
|
)
|
||||||
|
self.time_mix_legacy = time_mix_legacy
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
if merge_strategy == "fixed":
|
||||||
|
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
|
||||||
|
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||||
|
self.register_parameter(
|
||||||
|
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
|
||||||
|
)
|
||||||
|
elif merge_strategy == "fixed_with_images":
|
||||||
|
self.mix_factor = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||||
|
|
||||||
|
self.get_alpha_fn = functools.partial(
|
||||||
|
get_alpha,
|
||||||
|
merge_strategy,
|
||||||
|
self.mix_factor,
|
||||||
|
apply_sigmoid=apply_sigmoid_to_merge,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if False: # replicate_bug:
|
||||||
|
logpy.warning(
|
||||||
|
"*****************************************************************************************\n"
|
||||||
|
"GRAVE WARNING: YOU'RE USING THE BUGGY LEGACY ALPHABLENDER!!! ARE YOU SURE YOU WANT THIS?!\n"
|
||||||
|
"*****************************************************************************************"
|
||||||
|
)
|
||||||
|
self.time_mixer = LegacyAlphaBlenderWithBug(
|
||||||
|
alpha=merge_factor,
|
||||||
|
merge_strategy=merge_strategy,
|
||||||
|
rearrange_pattern="b t -> b 1 t 1 1",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.time_mixer = AlphaBlender(
|
||||||
|
alpha=merge_factor,
|
||||||
|
merge_strategy=merge_strategy,
|
||||||
|
rearrange_pattern="b t -> b 1 t 1 1",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: th.Tensor,
|
||||||
|
emb: th.Tensor,
|
||||||
|
num_video_frames: int,
|
||||||
|
image_only_indicator: Optional[th.Tensor] = None,
|
||||||
|
cond_view: Optional[th.Tensor] = None,
|
||||||
|
cond_motion: Optional[th.Tensor] = None,
|
||||||
|
) -> th.Tensor:
|
||||||
|
x = super().forward(x, emb)
|
||||||
|
|
||||||
|
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||||
|
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||||
|
|
||||||
|
x = self.time_mix_blocks(
|
||||||
|
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||||
|
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||||
|
else:
|
||||||
|
x = self.time_mixer(
|
||||||
|
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||||
|
)
|
||||||
|
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialUNetModelWithTime(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
model_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
attention_resolutions: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
channel_mult: List[int] = (1, 2, 4, 8),
|
||||||
|
conv_resample: bool = True,
|
||||||
|
dims: int = 2,
|
||||||
|
num_classes: Optional[int] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
num_heads: int = -1,
|
||||||
|
num_head_channels: int = -1,
|
||||||
|
num_heads_upsample: int = -1,
|
||||||
|
use_scale_shift_norm: bool = False,
|
||||||
|
resblock_updown: bool = False,
|
||||||
|
use_new_attention_order: bool = False,
|
||||||
|
use_spatial_transformer: bool = False,
|
||||||
|
transformer_depth: Union[List[int], int] = 1,
|
||||||
|
transformer_depth_middle: Optional[int] = None,
|
||||||
|
context_dim: Optional[int] = None,
|
||||||
|
time_downup: bool = False,
|
||||||
|
time_context_dim: Optional[int] = None,
|
||||||
|
extra_ff_mix_layer: bool = False,
|
||||||
|
use_spatial_context: bool = False,
|
||||||
|
time_block_merge_strategy: str = "fixed",
|
||||||
|
time_block_merge_factor: float = 0.5,
|
||||||
|
spatial_transformer_attn_type: str = "softmax",
|
||||||
|
time_kernel_size: Union[int, List[int]] = 3,
|
||||||
|
use_linear_in_transformer: bool = False,
|
||||||
|
legacy: bool = True,
|
||||||
|
adm_in_channels: Optional[int] = None,
|
||||||
|
use_temporal_resblock: bool = True,
|
||||||
|
disable_temporal_crossattention: bool = False,
|
||||||
|
time_mix_legacy: bool = True,
|
||||||
|
max_ddpm_temb_period: int = 10000,
|
||||||
|
replicate_time_mix_bug: bool = False,
|
||||||
|
use_motion_attention: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if use_spatial_transformer:
|
||||||
|
assert context_dim is not None
|
||||||
|
|
||||||
|
if context_dim is not None:
|
||||||
|
assert use_spatial_transformer
|
||||||
|
|
||||||
|
if num_heads_upsample == -1:
|
||||||
|
num_heads_upsample = num_heads
|
||||||
|
|
||||||
|
if num_heads == -1:
|
||||||
|
assert num_head_channels != -1
|
||||||
|
|
||||||
|
if num_head_channels == -1:
|
||||||
|
assert num_heads != -1
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.model_channels = model_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
if isinstance(transformer_depth, int):
|
||||||
|
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||||
|
transformer_depth_middle = default(
|
||||||
|
transformer_depth_middle, transformer_depth[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.attention_resolutions = attention_resolutions
|
||||||
|
self.dropout = dropout
|
||||||
|
self.channel_mult = channel_mult
|
||||||
|
self.conv_resample = conv_resample
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_head_channels = num_head_channels
|
||||||
|
self.num_heads_upsample = num_heads_upsample
|
||||||
|
self.use_temporal_resblocks = use_temporal_resblock
|
||||||
|
|
||||||
|
time_embed_dim = model_channels * 4
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
linear(model_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, time_embed_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.num_classes is not None:
|
||||||
|
if isinstance(self.num_classes, int):
|
||||||
|
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||||
|
elif self.num_classes == "continuous":
|
||||||
|
print("setting up linear c_adm embedding layer")
|
||||||
|
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||||
|
elif self.num_classes == "timestep":
|
||||||
|
self.label_emb = nn.Sequential(
|
||||||
|
Timestep(model_channels),
|
||||||
|
nn.Sequential(
|
||||||
|
linear(model_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, time_embed_dim),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.num_classes == "sequential":
|
||||||
|
assert adm_in_channels is not None
|
||||||
|
self.label_emb = nn.Sequential(
|
||||||
|
nn.Sequential(
|
||||||
|
linear(adm_in_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, time_embed_dim),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
self.input_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._feature_size = model_channels
|
||||||
|
input_block_chans = [model_channels]
|
||||||
|
ch = model_channels
|
||||||
|
ds = 1
|
||||||
|
|
||||||
|
def get_attention_layer(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
depth=1,
|
||||||
|
context_dim=None,
|
||||||
|
use_checkpoint=False,
|
||||||
|
disabled_sa=False,
|
||||||
|
):
|
||||||
|
if not use_spatial_transformer:
|
||||||
|
return PostHocAttentionBlockWithTimeMixing(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_new_attention_order=use_new_attention_order,
|
||||||
|
dropout=dropout,
|
||||||
|
ff_in=extra_ff_mix_layer,
|
||||||
|
use_spatial_context=use_spatial_context,
|
||||||
|
merge_strategy=time_block_merge_strategy,
|
||||||
|
merge_factor=time_block_merge_factor,
|
||||||
|
attn_mode=spatial_transformer_attn_type,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif use_motion_attention:
|
||||||
|
return PostHocSpatialTransformerWithTimeMixingAndMotion(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
depth=depth,
|
||||||
|
context_dim=context_dim,
|
||||||
|
time_context_dim=time_context_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
ff_in=extra_ff_mix_layer,
|
||||||
|
use_spatial_context=use_spatial_context,
|
||||||
|
merge_strategy=time_block_merge_strategy,
|
||||||
|
merge_factor=time_block_merge_factor,
|
||||||
|
checkpoint=use_checkpoint,
|
||||||
|
use_linear=use_linear_in_transformer,
|
||||||
|
attn_mode=spatial_transformer_attn_type,
|
||||||
|
disable_self_attn=disabled_sa,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
time_mix_legacy=time_mix_legacy,
|
||||||
|
max_time_embed_period=max_ddpm_temb_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return PostHocSpatialTransformerWithTimeMixing(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
depth=depth,
|
||||||
|
context_dim=context_dim,
|
||||||
|
time_context_dim=time_context_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
ff_in=extra_ff_mix_layer,
|
||||||
|
use_spatial_context=use_spatial_context,
|
||||||
|
merge_strategy=time_block_merge_strategy,
|
||||||
|
merge_factor=time_block_merge_factor,
|
||||||
|
checkpoint=use_checkpoint,
|
||||||
|
use_linear=use_linear_in_transformer,
|
||||||
|
attn_mode=spatial_transformer_attn_type,
|
||||||
|
disable_self_attn=disabled_sa,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
time_mix_legacy=time_mix_legacy,
|
||||||
|
max_time_embed_period=max_ddpm_temb_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_resblock(
|
||||||
|
time_block_merge_factor,
|
||||||
|
time_block_merge_strategy,
|
||||||
|
time_kernel_size,
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
dropout,
|
||||||
|
out_ch,
|
||||||
|
dims,
|
||||||
|
use_checkpoint,
|
||||||
|
use_scale_shift_norm,
|
||||||
|
down=False,
|
||||||
|
up=False,
|
||||||
|
):
|
||||||
|
if self.use_temporal_resblocks:
|
||||||
|
return PostHocResBlockWithTime(
|
||||||
|
merge_factor=time_block_merge_factor,
|
||||||
|
merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
channels=ch,
|
||||||
|
emb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
out_channels=out_ch,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
down=down,
|
||||||
|
up=up,
|
||||||
|
time_mix_legacy=time_mix_legacy,
|
||||||
|
replicate_bug=replicate_time_mix_bug,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ResBlock(
|
||||||
|
channels=ch,
|
||||||
|
emb_channels=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
out_channels=out_ch,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
dims=dims,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
down=down,
|
||||||
|
up=up,
|
||||||
|
)
|
||||||
|
|
||||||
|
for level, mult in enumerate(channel_mult):
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
get_resblock(
|
||||||
|
time_block_merge_factor=time_block_merge_factor,
|
||||||
|
time_block_merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
ch=ch,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
out_ch=mult * model_channels,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
ch = mult * model_channels
|
||||||
|
if ds in attention_resolutions:
|
||||||
|
if num_head_channels == -1:
|
||||||
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
dim_head = (
|
||||||
|
ch // num_heads
|
||||||
|
if use_spatial_transformer
|
||||||
|
else num_head_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
get_attention_layer(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
depth=transformer_depth[level],
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
disabled_sa=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self._feature_size += ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
if level != len(channel_mult) - 1:
|
||||||
|
ds *= 2
|
||||||
|
out_ch = ch
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
get_resblock(
|
||||||
|
time_block_merge_factor=time_block_merge_factor,
|
||||||
|
time_block_merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
ch=ch,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
out_ch=out_ch,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
down=True,
|
||||||
|
)
|
||||||
|
if resblock_updown
|
||||||
|
else Downsample(
|
||||||
|
ch,
|
||||||
|
conv_resample,
|
||||||
|
dims=dims,
|
||||||
|
out_channels=out_ch,
|
||||||
|
third_down=time_downup,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ch = out_ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
if num_head_channels == -1:
|
||||||
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
# num_heads = 1
|
||||||
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
|
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
get_resblock(
|
||||||
|
time_block_merge_factor=time_block_merge_factor,
|
||||||
|
time_block_merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
ch=ch,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
out_ch=None,
|
||||||
|
dropout=dropout,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
),
|
||||||
|
get_attention_layer(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
depth=transformer_depth_middle,
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
),
|
||||||
|
get_resblock(
|
||||||
|
time_block_merge_factor=time_block_merge_factor,
|
||||||
|
time_block_merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
ch=ch,
|
||||||
|
out_ch=None,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||||
|
for i in range(num_res_blocks + 1):
|
||||||
|
ich = input_block_chans.pop()
|
||||||
|
layers = [
|
||||||
|
get_resblock(
|
||||||
|
time_block_merge_factor=time_block_merge_factor,
|
||||||
|
time_block_merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
ch=ch + ich,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
out_ch=model_channels * mult,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
ch = model_channels * mult
|
||||||
|
if ds in attention_resolutions:
|
||||||
|
if num_head_channels == -1:
|
||||||
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
dim_head = (
|
||||||
|
ch // num_heads
|
||||||
|
if use_spatial_transformer
|
||||||
|
else num_head_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
get_attention_layer(
|
||||||
|
ch,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
depth=transformer_depth[level],
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
disabled_sa=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if level and i == num_res_blocks:
|
||||||
|
out_ch = ch
|
||||||
|
ds //= 2
|
||||||
|
layers.append(
|
||||||
|
get_resblock(
|
||||||
|
time_block_merge_factor=time_block_merge_factor,
|
||||||
|
time_block_merge_strategy=time_block_merge_strategy,
|
||||||
|
time_kernel_size=time_kernel_size,
|
||||||
|
ch=ch,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
out_ch=out_ch,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
up=True,
|
||||||
|
)
|
||||||
|
if resblock_updown
|
||||||
|
else Upsample(
|
||||||
|
ch,
|
||||||
|
conv_resample,
|
||||||
|
dims=dims,
|
||||||
|
out_channels=out_ch,
|
||||||
|
third_up=time_downup,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
normalization(ch),
|
||||||
|
nn.SiLU(),
|
||||||
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: th.Tensor,
|
||||||
|
timesteps: th.Tensor,
|
||||||
|
context: Optional[th.Tensor] = None,
|
||||||
|
y: Optional[th.Tensor] = None,
|
||||||
|
# cam: Optional[th.Tensor] = None,
|
||||||
|
time_context: Optional[th.Tensor] = None,
|
||||||
|
num_video_frames: Optional[int] = None,
|
||||||
|
image_only_indicator: Optional[th.Tensor] = None,
|
||||||
|
cond_view: Optional[th.Tensor] = None,
|
||||||
|
cond_motion: Optional[th.Tensor] = None,
|
||||||
|
time_step: Optional[int] = None,
|
||||||
|
):
|
||||||
|
assert (y is not None) == (
|
||||||
|
self.num_classes is not None
|
||||||
|
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
||||||
|
hs = []
|
||||||
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # 21 x 320
|
||||||
|
emb = self.time_embed(t_emb) # 21 x 1280
|
||||||
|
time = str(timesteps[0].data.cpu().numpy())
|
||||||
|
|
||||||
|
if self.num_classes is not None:
|
||||||
|
assert y.shape[0] == x.shape[0]
|
||||||
|
emb = emb + self.label_emb(y) # 21 x 1280
|
||||||
|
|
||||||
|
h = x # 21 x 8 x 64 x 64
|
||||||
|
for i, module in enumerate(self.input_blocks):
|
||||||
|
h = module(
|
||||||
|
h,
|
||||||
|
emb,
|
||||||
|
context=context,
|
||||||
|
# cam=cam,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
cond_view=cond_view,
|
||||||
|
cond_motion=cond_motion,
|
||||||
|
time_context=time_context,
|
||||||
|
num_video_frames=num_video_frames,
|
||||||
|
time_step=time_step,
|
||||||
|
name='encoder_{}_{}'.format(time, i)
|
||||||
|
)
|
||||||
|
hs.append(h)
|
||||||
|
h = self.middle_block(
|
||||||
|
h,
|
||||||
|
emb,
|
||||||
|
context=context,
|
||||||
|
# cam=cam,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
cond_view=cond_view,
|
||||||
|
cond_motion=cond_motion,
|
||||||
|
time_context=time_context,
|
||||||
|
num_video_frames=num_video_frames,
|
||||||
|
time_step=time_step,
|
||||||
|
name='middle_{}_0'.format(time, i)
|
||||||
|
)
|
||||||
|
for i, module in enumerate(self.output_blocks):
|
||||||
|
h = th.cat([h, hs.pop()], dim=1)
|
||||||
|
h = module(
|
||||||
|
h,
|
||||||
|
emb,
|
||||||
|
context=context,
|
||||||
|
# cam=cam,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
cond_view=cond_view,
|
||||||
|
cond_motion=cond_motion,
|
||||||
|
time_context=time_context,
|
||||||
|
num_video_frames=num_video_frames,
|
||||||
|
time_step=time_step,
|
||||||
|
name='decoder_{}_{}'.format(time, i)
|
||||||
|
)
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
return self.out(h)
|
||||||
|
|||||||
@@ -25,10 +25,21 @@ class OpenAIWrapper(IdentityWrapper):
|
|||||||
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
||||||
return self.diffusion_model(
|
if "cond_view" in c:
|
||||||
x,
|
return self.diffusion_model(
|
||||||
timesteps=t,
|
x,
|
||||||
context=c.get("crossattn", None),
|
timesteps=t,
|
||||||
y=c.get("vector", None),
|
context=c.get("crossattn", None),
|
||||||
**kwargs,
|
y=c.get("vector", None),
|
||||||
)
|
cond_view=c.get("cond_view", None),
|
||||||
|
cond_motion=c.get("cond_motion", None),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.diffusion_model(
|
||||||
|
x,
|
||||||
|
timesteps=t,
|
||||||
|
context=c.get("crossattn", None),
|
||||||
|
y=c.get("vector", None),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|||||||
@@ -69,8 +69,8 @@ class AbstractEmbModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GeneralConditioner(nn.Module):
|
class GeneralConditioner(nn.Module):
|
||||||
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
|
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"} # , 5: "concat"}
|
||||||
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
|
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1}
|
||||||
|
|
||||||
def __init__(self, emb_models: Union[List, ListConfig]):
|
def __init__(self, emb_models: Union[List, ListConfig]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -138,7 +138,11 @@ class GeneralConditioner(nn.Module):
|
|||||||
if not isinstance(emb_out, (list, tuple)):
|
if not isinstance(emb_out, (list, tuple)):
|
||||||
emb_out = [emb_out]
|
emb_out = [emb_out]
|
||||||
for emb in emb_out:
|
for emb in emb_out:
|
||||||
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
|
if embedder.input_key in ["cond_view", "cond_motion"]:
|
||||||
|
out_key = embedder.input_key
|
||||||
|
else:
|
||||||
|
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
|
||||||
|
|
||||||
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
|
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
|
||||||
emb = (
|
emb = (
|
||||||
expand_dims_like(
|
expand_dims_like(
|
||||||
@@ -994,7 +998,10 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
|
|||||||
sigmas = self.sigma_sampler(b).to(vid.device)
|
sigmas = self.sigma_sampler(b).to(vid.device)
|
||||||
if self.sigma_cond is not None:
|
if self.sigma_cond is not None:
|
||||||
sigma_cond = self.sigma_cond(sigmas)
|
sigma_cond = self.sigma_cond(sigmas)
|
||||||
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
|
if self.n_cond_frames == 1:
|
||||||
|
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
|
||||||
|
else:
|
||||||
|
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_cond_frames) # For SV4D
|
||||||
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
|
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
|
||||||
noise = torch.randn_like(vid)
|
noise = torch.randn_like(vid)
|
||||||
vid = vid + noise * append_dims(sigmas, vid.ndim)
|
vid = vid + noise * append_dims(sigmas, vid.ndim)
|
||||||
@@ -1017,8 +1024,9 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
|
|||||||
vid = torch.cat(all_out, dim=0)
|
vid = torch.cat(all_out, dim=0)
|
||||||
vid *= self.scale_factor
|
vid *= self.scale_factor
|
||||||
|
|
||||||
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
|
if self.n_cond_frames == 1:
|
||||||
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
|
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
|
||||||
|
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
|
||||||
|
|
||||||
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
|
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
|
||||||
|
|
||||||
|
|||||||
596
sgm/modules/spacetime_attention.py
Normal file
596
sgm/modules/spacetime_attention.py
Normal file
@@ -0,0 +1,596 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..modules.attention import *
|
||||||
|
from ..modules.diffusionmodules.util import (
|
||||||
|
AlphaBlender,
|
||||||
|
get_alpha,
|
||||||
|
linear,
|
||||||
|
mixed_checkpoint,
|
||||||
|
timestep_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeMixSequential(nn.Sequential):
|
||||||
|
def forward(self, x, context=None, timesteps=None):
|
||||||
|
for layer in self:
|
||||||
|
x = layer(x, context, timesteps)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerTimeMixBlock(nn.Module):
|
||||||
|
ATTENTION_MODES = {
|
||||||
|
"softmax": CrossAttention,
|
||||||
|
"softmax-xformers": MemoryEfficientCrossAttention,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
dropout=0.0,
|
||||||
|
context_dim=None,
|
||||||
|
gated_ff=True,
|
||||||
|
checkpoint=True,
|
||||||
|
timesteps=None,
|
||||||
|
ff_in=False,
|
||||||
|
inner_dim=None,
|
||||||
|
attn_mode="softmax",
|
||||||
|
disable_self_attn=False,
|
||||||
|
disable_temporal_crossattention=False,
|
||||||
|
switch_temporal_ca_to_sa=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||||
|
|
||||||
|
self.ff_in = ff_in or inner_dim is not None
|
||||||
|
if inner_dim is None:
|
||||||
|
inner_dim = dim
|
||||||
|
|
||||||
|
assert int(n_heads * d_head) == inner_dim
|
||||||
|
|
||||||
|
self.is_res = inner_dim == dim
|
||||||
|
|
||||||
|
if self.ff_in:
|
||||||
|
self.norm_in = nn.LayerNorm(dim)
|
||||||
|
self.ff_in = FeedForward(
|
||||||
|
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
|
||||||
|
)
|
||||||
|
|
||||||
|
self.timesteps = timesteps
|
||||||
|
self.disable_self_attn = disable_self_attn
|
||||||
|
if self.disable_self_attn:
|
||||||
|
self.attn1 = attn_cls(
|
||||||
|
query_dim=inner_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=context_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
) # is a cross-attention
|
||||||
|
else:
|
||||||
|
self.attn1 = attn_cls(
|
||||||
|
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||||
|
) # is a self-attention
|
||||||
|
|
||||||
|
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
|
||||||
|
|
||||||
|
if disable_temporal_crossattention:
|
||||||
|
if switch_temporal_ca_to_sa:
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
self.attn2 = None
|
||||||
|
else:
|
||||||
|
self.norm2 = nn.LayerNorm(inner_dim)
|
||||||
|
if switch_temporal_ca_to_sa:
|
||||||
|
self.attn2 = attn_cls(
|
||||||
|
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||||
|
) # is a self-attention
|
||||||
|
else:
|
||||||
|
self.attn2 = attn_cls(
|
||||||
|
query_dim=inner_dim,
|
||||||
|
context_dim=context_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
) # is self-attn if context is none
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(inner_dim)
|
||||||
|
self.norm3 = nn.LayerNorm(inner_dim)
|
||||||
|
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||||
|
|
||||||
|
self.checkpoint = checkpoint
|
||||||
|
if self.checkpoint:
|
||||||
|
logpy.info(f"{self.__class__.__name__} is using checkpointing")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.checkpoint:
|
||||||
|
return checkpoint(self._forward, x, context, timesteps)
|
||||||
|
else:
|
||||||
|
return self._forward(x, context, timesteps=timesteps)
|
||||||
|
|
||||||
|
def _forward(self, x, context=None, timesteps=None):
|
||||||
|
assert self.timesteps or timesteps
|
||||||
|
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
|
||||||
|
timesteps = self.timesteps or timesteps
|
||||||
|
B, S, C = x.shape
|
||||||
|
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
|
||||||
|
|
||||||
|
if self.ff_in:
|
||||||
|
x_skip = x
|
||||||
|
x = self.ff_in(self.norm_in(x))
|
||||||
|
if self.is_res:
|
||||||
|
x += x_skip
|
||||||
|
|
||||||
|
if self.disable_self_attn:
|
||||||
|
x = self.attn1(self.norm1(x), context=context) + x
|
||||||
|
else:
|
||||||
|
x = self.attn1(self.norm1(x)) + x
|
||||||
|
|
||||||
|
if self.attn2 is not None:
|
||||||
|
if self.switch_temporal_ca_to_sa:
|
||||||
|
x = self.attn2(self.norm2(x)) + x
|
||||||
|
else:
|
||||||
|
x = self.attn2(self.norm2(x), context=context) + x
|
||||||
|
x_skip = x
|
||||||
|
x = self.ff(self.norm3(x))
|
||||||
|
if self.is_res:
|
||||||
|
x += x_skip
|
||||||
|
|
||||||
|
x = rearrange(
|
||||||
|
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.ff.net[-1].weight
|
||||||
|
|
||||||
|
|
||||||
|
class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
depth=1,
|
||||||
|
dropout=0.0,
|
||||||
|
use_linear=False,
|
||||||
|
context_dim=None,
|
||||||
|
use_spatial_context=False,
|
||||||
|
timesteps=None,
|
||||||
|
merge_strategy: str = "fixed",
|
||||||
|
merge_factor: float = 0.5,
|
||||||
|
apply_sigmoid_to_merge: bool = True,
|
||||||
|
time_context_dim=None,
|
||||||
|
ff_in=False,
|
||||||
|
checkpoint=False,
|
||||||
|
time_depth=1,
|
||||||
|
attn_mode="softmax",
|
||||||
|
disable_self_attn=False,
|
||||||
|
disable_temporal_crossattention=False,
|
||||||
|
time_mix_legacy: bool = True,
|
||||||
|
max_time_embed_period: int = 10000,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
depth=depth,
|
||||||
|
dropout=dropout,
|
||||||
|
attn_type=attn_mode,
|
||||||
|
use_checkpoint=checkpoint,
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_linear=use_linear,
|
||||||
|
disable_self_attn=disable_self_attn,
|
||||||
|
)
|
||||||
|
self.time_depth = time_depth
|
||||||
|
self.depth = depth
|
||||||
|
self.max_time_embed_period = max_time_embed_period
|
||||||
|
|
||||||
|
time_mix_d_head = d_head
|
||||||
|
n_time_mix_heads = n_heads
|
||||||
|
|
||||||
|
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||||
|
|
||||||
|
inner_dim = n_heads * d_head
|
||||||
|
if use_spatial_context:
|
||||||
|
time_context_dim = context_dim
|
||||||
|
|
||||||
|
self.time_mix_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerTimeMixBlock(
|
||||||
|
inner_dim,
|
||||||
|
n_time_mix_heads,
|
||||||
|
time_mix_d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
context_dim=time_context_dim,
|
||||||
|
timesteps=timesteps,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
ff_in=ff_in,
|
||||||
|
inner_dim=time_mix_inner_dim,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
disable_self_attn=disable_self_attn,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
)
|
||||||
|
for _ in range(self.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(self.time_mix_blocks) == len(self.transformer_blocks)
|
||||||
|
|
||||||
|
self.use_spatial_context = use_spatial_context
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
time_embed_dim = self.in_channels * 4
|
||||||
|
self.time_mix_time_embed = nn.Sequential(
|
||||||
|
linear(self.in_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, self.in_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.time_mix_legacy = time_mix_legacy
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
if merge_strategy == "fixed":
|
||||||
|
self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
|
||||||
|
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||||
|
self.register_parameter(
|
||||||
|
"mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
|
||||||
|
)
|
||||||
|
elif merge_strategy == "fixed_with_images":
|
||||||
|
self.mix_factor = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||||
|
|
||||||
|
self.get_alpha_fn = partial(
|
||||||
|
get_alpha,
|
||||||
|
merge_strategy,
|
||||||
|
self.mix_factor,
|
||||||
|
apply_sigmoid=apply_sigmoid_to_merge,
|
||||||
|
is_attn=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.time_mixer = AlphaBlender(
|
||||||
|
alpha=merge_factor, merge_strategy=merge_strategy
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
# cam: Optional[torch.Tensor] = None,
|
||||||
|
time_context: Optional[torch.Tensor] = None,
|
||||||
|
timesteps: Optional[int] = None,
|
||||||
|
image_only_indicator: Optional[torch.Tensor] = None,
|
||||||
|
cond_view: Optional[torch.Tensor] = None,
|
||||||
|
cond_motion: Optional[torch.Tensor] = None,
|
||||||
|
time_step: Optional[int] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
_, _, h, w = x.shape
|
||||||
|
x_in = x
|
||||||
|
spatial_context = None
|
||||||
|
if exists(context):
|
||||||
|
spatial_context = context
|
||||||
|
|
||||||
|
if self.use_spatial_context:
|
||||||
|
assert (
|
||||||
|
context.ndim == 3
|
||||||
|
), f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||||
|
|
||||||
|
time_context = context
|
||||||
|
time_context_first_timestep = time_context[::timesteps]
|
||||||
|
time_context = repeat(
|
||||||
|
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
|
||||||
|
)
|
||||||
|
elif time_context is not None and not self.use_spatial_context:
|
||||||
|
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
|
||||||
|
if time_context.ndim == 2:
|
||||||
|
time_context = rearrange(time_context, "b c -> b 1 c")
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
if not self.use_linear:
|
||||||
|
x = self.proj_in(x)
|
||||||
|
x = rearrange(x, "b c h w -> b (h w) c")
|
||||||
|
if self.use_linear:
|
||||||
|
x = self.proj_in(x)
|
||||||
|
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||||
|
|
||||||
|
num_frames = torch.arange(timesteps, device=x.device)
|
||||||
|
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||||
|
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||||
|
t_emb = timestep_embedding(
|
||||||
|
num_frames,
|
||||||
|
self.in_channels,
|
||||||
|
repeat_only=False,
|
||||||
|
max_period=self.max_time_embed_period,
|
||||||
|
)
|
||||||
|
emb = self.time_mix_time_embed(t_emb)
|
||||||
|
emb = emb[:, None, :]
|
||||||
|
|
||||||
|
for it_, (block, mix_block) in enumerate(
|
||||||
|
zip(self.transformer_blocks, self.time_mix_blocks)
|
||||||
|
):
|
||||||
|
# spatial attention
|
||||||
|
x = block(
|
||||||
|
x,
|
||||||
|
context=spatial_context,
|
||||||
|
time_step=time_step,
|
||||||
|
name=name + '_' + str(it_)
|
||||||
|
)
|
||||||
|
|
||||||
|
x_mix = x
|
||||||
|
x_mix = x_mix + emb
|
||||||
|
|
||||||
|
# temporal attention
|
||||||
|
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||||
|
else:
|
||||||
|
x = self.time_mixer(
|
||||||
|
x_spatial=x,
|
||||||
|
x_temporal=x_mix,
|
||||||
|
image_only_indicator=image_only_indicator,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_linear:
|
||||||
|
x = self.proj_out(x)
|
||||||
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||||
|
if not self.use_linear:
|
||||||
|
x = self.proj_out(x)
|
||||||
|
out = x + x_in
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
depth=1,
|
||||||
|
dropout=0.0,
|
||||||
|
use_linear=False,
|
||||||
|
context_dim=None,
|
||||||
|
use_spatial_context=False,
|
||||||
|
timesteps=None,
|
||||||
|
merge_strategy: str = "fixed",
|
||||||
|
merge_factor: float = 0.5,
|
||||||
|
apply_sigmoid_to_merge: bool = True,
|
||||||
|
time_context_dim=None,
|
||||||
|
ff_in=False,
|
||||||
|
checkpoint=False,
|
||||||
|
time_depth=1,
|
||||||
|
attn_mode="softmax",
|
||||||
|
disable_self_attn=False,
|
||||||
|
disable_temporal_crossattention=False,
|
||||||
|
time_mix_legacy: bool = True,
|
||||||
|
max_time_embed_period: int = 10000,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
depth=depth,
|
||||||
|
dropout=dropout,
|
||||||
|
attn_type=attn_mode,
|
||||||
|
use_checkpoint=checkpoint,
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_linear=use_linear,
|
||||||
|
disable_self_attn=disable_self_attn,
|
||||||
|
)
|
||||||
|
self.time_depth = time_depth
|
||||||
|
self.depth = depth
|
||||||
|
self.max_time_embed_period = max_time_embed_period
|
||||||
|
|
||||||
|
time_mix_d_head = d_head
|
||||||
|
n_time_mix_heads = n_heads
|
||||||
|
|
||||||
|
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||||
|
|
||||||
|
inner_dim = n_heads * d_head
|
||||||
|
if use_spatial_context:
|
||||||
|
time_context_dim = context_dim
|
||||||
|
|
||||||
|
camera_context_dim = time_context_dim
|
||||||
|
motion_context_dim = 4 # time_context_dim
|
||||||
|
|
||||||
|
# Camera attention layer
|
||||||
|
self.time_mix_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerTimeMixBlock(
|
||||||
|
inner_dim,
|
||||||
|
n_time_mix_heads,
|
||||||
|
time_mix_d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
context_dim=camera_context_dim,
|
||||||
|
timesteps=timesteps,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
ff_in=ff_in,
|
||||||
|
inner_dim=time_mix_inner_dim,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
disable_self_attn=disable_self_attn,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
)
|
||||||
|
for _ in range(self.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Motion attention layer
|
||||||
|
self.motion_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerTimeMixBlock(
|
||||||
|
inner_dim,
|
||||||
|
n_time_mix_heads,
|
||||||
|
time_mix_d_head,
|
||||||
|
dropout=dropout,
|
||||||
|
context_dim=motion_context_dim,
|
||||||
|
timesteps=timesteps,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
ff_in=ff_in,
|
||||||
|
inner_dim=time_mix_inner_dim,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
disable_self_attn=disable_self_attn,
|
||||||
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
)
|
||||||
|
for _ in range(self.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(self.time_mix_blocks) == len(self.transformer_blocks)
|
||||||
|
|
||||||
|
self.use_spatial_context = use_spatial_context
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
time_embed_dim = self.in_channels * 4
|
||||||
|
# Camera view embedding
|
||||||
|
self.time_mix_time_embed = nn.Sequential(
|
||||||
|
linear(self.in_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, self.in_channels),
|
||||||
|
)
|
||||||
|
# Motion time embedding
|
||||||
|
self.time_mix_motion_embed = nn.Sequential(
|
||||||
|
linear(self.in_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, self.in_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.time_mix_legacy = time_mix_legacy
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
if merge_strategy == "fixed":
|
||||||
|
self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
|
||||||
|
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
|
||||||
|
self.register_parameter(
|
||||||
|
"mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
|
||||||
|
)
|
||||||
|
elif merge_strategy == "fixed_with_images":
|
||||||
|
self.mix_factor = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown merge strategy {merge_strategy}")
|
||||||
|
|
||||||
|
self.get_alpha_fn = partial(
|
||||||
|
get_alpha,
|
||||||
|
merge_strategy,
|
||||||
|
self.mix_factor,
|
||||||
|
apply_sigmoid=apply_sigmoid_to_merge,
|
||||||
|
is_attn=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.time_mixer = AlphaBlender(
|
||||||
|
alpha=merge_factor, merge_strategy=merge_strategy
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
# cam: Optional[torch.Tensor] = None,
|
||||||
|
time_context: Optional[torch.Tensor] = None,
|
||||||
|
timesteps: Optional[int] = None,
|
||||||
|
image_only_indicator: Optional[torch.Tensor] = None,
|
||||||
|
cond_view: Optional[torch.Tensor] = None,
|
||||||
|
cond_motion: Optional[torch.Tensor] = None,
|
||||||
|
time_step: Optional[int] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
_, _, h, w = x.shape
|
||||||
|
x_in = x
|
||||||
|
spatial_context = None
|
||||||
|
if exists(context):
|
||||||
|
spatial_context = context
|
||||||
|
|
||||||
|
# cond_view: b v 4 h w
|
||||||
|
# cond_motion: b t 4 h w
|
||||||
|
b, t, d1 = context.shape # CLIP
|
||||||
|
v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE
|
||||||
|
cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w
|
||||||
|
spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1
|
||||||
|
camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1
|
||||||
|
motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
if not self.use_linear:
|
||||||
|
x = self.proj_in(x)
|
||||||
|
x = rearrange(x, "b c h w -> b (h w) c") # 21 x 4096 x 320
|
||||||
|
if self.use_linear:
|
||||||
|
x = self.proj_in(x)
|
||||||
|
c = x.shape[-1]
|
||||||
|
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||||
|
|
||||||
|
num_frames = torch.arange(t, device=x.device)
|
||||||
|
num_frames = repeat(num_frames, "t -> b t", b=b)
|
||||||
|
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||||
|
t_emb = timestep_embedding(
|
||||||
|
num_frames,
|
||||||
|
self.in_channels,
|
||||||
|
repeat_only=False,
|
||||||
|
max_period=self.max_time_embed_period,
|
||||||
|
)
|
||||||
|
emb_time = self.time_mix_motion_embed(t_emb)
|
||||||
|
emb_time = emb_time[:, None, :] # b*t x 1 x 320
|
||||||
|
|
||||||
|
num_views = torch.arange(v, device=x.device)
|
||||||
|
num_views = repeat(num_views, "t -> b t", b=b)
|
||||||
|
num_views = rearrange(num_views, "b t -> (b t)")
|
||||||
|
v_emb = timestep_embedding(
|
||||||
|
num_views,
|
||||||
|
self.in_channels,
|
||||||
|
repeat_only=False,
|
||||||
|
max_period=self.max_time_embed_period,
|
||||||
|
)
|
||||||
|
emb_view = self.time_mix_time_embed(v_emb)
|
||||||
|
emb_view = emb_view[:, None, :] # b*v x 1 x 320
|
||||||
|
|
||||||
|
for it_, (block, time_block, mot_block) in enumerate(
|
||||||
|
zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)
|
||||||
|
):
|
||||||
|
# Spatial attention
|
||||||
|
x = block(
|
||||||
|
x,
|
||||||
|
context=spatial_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Camera attention
|
||||||
|
x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c
|
||||||
|
x_mix = x + emb_view
|
||||||
|
x_mix = time_block(x_mix, context=camera_context, timesteps=v)
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||||
|
else:
|
||||||
|
x = self.time_mixer(
|
||||||
|
x_spatial=x,
|
||||||
|
x_temporal=x_mix,
|
||||||
|
image_only_indicator=image_only_indicator[:,:v],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Motion attention
|
||||||
|
x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c
|
||||||
|
x_mix = x + emb_time
|
||||||
|
x_mix = mot_block(x_mix, context=motion_context, timesteps=t)
|
||||||
|
if self.time_mix_legacy:
|
||||||
|
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||||
|
else:
|
||||||
|
x = self.time_mixer(
|
||||||
|
x_spatial=x,
|
||||||
|
x_temporal=x_mix,
|
||||||
|
image_only_indicator=image_only_indicator[:,:t],
|
||||||
|
)
|
||||||
|
|
||||||
|
x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c
|
||||||
|
|
||||||
|
if self.use_linear:
|
||||||
|
x = self.proj_out(x)
|
||||||
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||||
|
if not self.use_linear:
|
||||||
|
x = self.proj_out(x)
|
||||||
|
out = x + x_in
|
||||||
|
return out
|
||||||
Reference in New Issue
Block a user