Adds SV4D code

Vikram Voleti
2024-07-23 20:17:16 +00:00
parent fbdc58cab9
commit abe9ed3d40
16 changed files with 3174 additions and 23 deletions


@@ -4,6 +4,30 @@
## News
**July 24, 2024**
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
- To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D: first sample 5 anchor frames, then densely sample the remaining frames while maintaining temporal consistency (a short index sketch follows the usage notes below).
- Please check our [project page](), [tech report]() and [video summary]() for more details.
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [SV4D](https://huggingface.co/stabilityai/sv4d) and [SV3D_u](https://huggingface.co/stabilityai/sv3d) from HuggingFace)
To run **SV4D** on a single input video of 21 frames:
- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
- `input_path` : The input video `<path/to/video>` can be
- a single video file in `gif` or `mp4` format, such as `assets/test_video1.mp4`, or
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
- a file name pattern matching images of video frames.
- `num_steps` : defaults to 20; increase to 50 for better quality at the cost of longer sampling time.
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- **Background removal** : For input videos with a plain background, optionally use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop the video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos (with noisy backgrounds), try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) before running SV4D.
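
The interleaved sampling referenced above can be illustrated with a small snippet (not part of the release); the indices mirror `scripts/sampling/simple_video_sample_4d.py` added in this commit:

```python
import numpy as np

# 21 input/output frames, T = 5 frames per SV4D pass (as in the sampling script)
n_frames, T = 21, 5
anchor_frames = np.arange(T - 1, n_frames, T - 1)  # [ 4  8 12 16 20]
dense_windows = [t0 + np.arange(T) for t0 in np.arange(0, n_frames - 1, T - 1)]
# dense windows [0..4], [4..8], [8..12], [12..16], [16..20] each share their
# endpoints with the anchor frames sampled in the first pass
```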
![tile](assets/sv4d.gif)
**March 18, 2024**
- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
- **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.

BIN
assets/hiphop_parrot.mp4 Normal file

Binary file not shown.

BIN
assets/sv4d.gif Normal file

Binary file not shown.


BIN
assets/test_video1.mp4 Normal file

Binary file not shown.

BIN
assets/test_video2.mp4 Normal file

Binary file not shown.

1207
scripts/demo/sv4d_helpers.py Normal file

File diff suppressed because it is too large


@@ -0,0 +1,208 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
en_and_decode_n_samples_a_time: 7
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv4d.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
params:
adm_in_channels: 1280
attention_resolutions: [4, 2, 1]
channel_mult: [1, 2, 4, 4]
context_dim: 1024
extra_ff_mix_layer: True
in_channels: 8
legacy: False
model_channels: 320
num_classes: sequential
num_head_channels: 64
num_res_blocks: 2
out_channels: 4
replicate_time_mix_bug: True
spatial_transformer_attn_type: softmax-xformers
time_block_merge_factor: 0.0
time_block_merge_strategy: learned_with_images
time_kernel_size: [3, 1, 1]
time_mix_legacy: False
transformer_depth: 1
use_checkpoint: False
use_linear_in_transformer: True
use_spatial_context: True
use_spatial_transformer: True
use_motion_attention: True
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
is_trainable: False
params:
n_cond_frames: ${N_TIME}
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: cond_frames
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
is_trainable: False
params:
is_ae: True
n_cond_frames: ${N_FRAMES}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
embed_dim: 4
lossconfig:
target: torch.nn.Identity
monitor: val/rec_loss
sigma_cond_config:
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
# - input_key: cond_aug
# is_trainable: False
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
# params:
# outdim: 256
- input_key: polar_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512
- input_key: azimuth_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512
- input_key: cond_view
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
lossconfig:
target: torch.nn.Identity
is_ae: True
n_cond_frames: ${N_VIEW}
n_copies: 1
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
- input_key: cond_motion
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
is_ae: True
n_cond_frames: ${N_TIME}
n_copies: 1
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
lossconfig:
target: torch.nn.Identity
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_resolutions: []
attn_type: vanilla-xformers
ch: 128
ch_mult: [1, 2, 4, 4]
double_z: True
dropout: 0.0
in_channels: 3
num_res_blocks: 2
out_ch: 3
resolution: 256
z_channels: 4
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 500.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 2.5
num_frames: ${N_FRAMES}
additional_cond_keys: [ cond_view, cond_motion ]
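
A minimal sketch (not part of this commit) of how a config like the one above is typically consumed in this repo, assuming `sgm.util.instantiate_from_config` and that `checkpoints/sv4d.safetensors` has been downloaded; the `${N_TIME}` / `${N_VIEW}` / `${N_FRAMES}` references resolve via OmegaConf interpolation against the top-level keys of the same file:

```python
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/sampling/configs/sv4d.yaml")
assert config.N_FRAMES == config.N_TIME * config.N_VIEW  # 40 = 5 frames x 8 views
# Builds the DiffusionEngine and loads ckpt_path (checkpoints/sv4d.safetensors).
model = instantiate_from_config(config.model)
```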


@@ -0,0 +1,236 @@
import os
import sys
from glob import glob
from typing import List, Optional, Union
from tqdm import tqdm
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import numpy as np
import torch
from fire import Fire
from scripts.demo.sv4d_helpers import (
decode_latents,
load_model,
read_video,
run_img2vid,
run_img2vid_per_step,
sample_sv3d,
save_video,
)
def sample(
input_path: str = "assets/test_video.mp4", # Can either be image file or folder with image files
output_folder: Optional[str] = "outputs/sv4d",
num_steps: Optional[int] = 20,
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 1e-5,
seed: int = 23,
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
elevations_deg: Optional[Union[float, List[float]]] = 10.0,
azimuths_deg: Optional[List[float]] = None,
image_frame_ratio: Optional[float] = None,
verbose: Optional[bool] = False,
remove_bg: bool = False,
):
"""
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
# Set model config
T = 5 # number of frames per sample
V = 8 # number of views per sample
F = 8 # vae factor to downsize image->latent
C = 4
H, W = 576, 576
n_frames = 21 # number of input and output video frames
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
n_views_sv3d = 21
subsampled_views = np.array(
[0, 2, 5, 7, 9, 12, 14, 16, 19]
) # subsample (V+1=)9 (uniform) views from 21 SV3D views
model_config = "scripts/sampling/configs/sv4d.yaml"
version_dict = {
"T": T * V,
"H": H,
"W": W,
"C": C,
"f": F,
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 5,
"num_steps": num_steps,
"force_uc_zero_embeddings": [
"cond_frames",
"cond_frames_without_noise",
"cond_view",
"cond_motion",
],
"additional_guider_kwargs": {
"additional_cond_keys": ["cond_view", "cond_motion"]
},
},
}
torch.manual_seed(seed)
os.makedirs(output_folder, exist_ok=True)
# Read input video frames i.e. images at view 0
print(f"Reading {input_path}")
images_v0 = read_video(
input_path,
n_frames=n_frames,
W=W,
H=H,
remove_bg=remove_bg,
image_frame_ratio=image_frame_ratio,
device=device,
)
# Get camera viewpoints
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
elevations_deg = [elevations_deg] * n_views_sv3d
assert (
len(elevations_deg) == n_views_sv3d
), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
if azimuths_deg is None:
azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
assert (
len(azimuths_deg) == n_views_sv3d
), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
azimuths_rad = np.array(
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
)
# Sample multi-view images of the first frame using SV3D i.e. images at time 0
images_t0 = sample_sv3d(
images_v0[0],
n_views_sv3d,
num_steps,
sv3d_version,
fps_id,
motion_bucket_id,
cond_aug,
decoding_t,
device,
polars_rad,
azimuths_rad,
verbose,
)
images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame
# Initialize image matrix
img_matrix = [[None] * n_views for _ in range(n_frames)]
for i, v in enumerate(subsampled_views):
img_matrix[0][i] = images_t0[v].unsqueeze(0)
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
save_video(
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
img_matrix[0],
)
save_video(
os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
[img_matrix[t][0] for t in range(n_frames)],
)
# Load SV4D model
model, filter = load_model(
model_config,
device,
version_dict["T"],
num_steps,
verbose,
)
# Interleaved sampling for anchor frames
t0, v0 = 0, 0
frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20]
view_indices = np.arange(V) + 1
print(f"Sampling anchor frames {frame_indices}")
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
samples = run_img2vid(
version_dict, model, image, seed, polars, azims, cond_motion, cond_view
)
samples = samples.view(T, V, 3, H, W)
for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
if img_matrix[t][v] is None:
img_matrix[t][v] = samples[i, j][None] * 2 - 1
# Dense sampling for the rest
print(f"Sampling dense frames:")
for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)): # [0, 4, 8, 12, 16]
frame_indices = t0 + np.arange(T)
print(f"Sampling dense frames {frame_indices}")
latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
for step in tqdm(range(num_steps)):
frame_indices = frame_indices[
::-1
].copy() # alternate between forward and backward conditioning
t0 = frame_indices[0]
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
samples = run_img2vid_per_step(
version_dict,
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
step,
noisy_latents,
)
samples = samples.view(T, V, C, H // F, W // F)
for i, t in enumerate(frame_indices):
for j, v in enumerate(view_indices):
latent_matrix[t, v] = samples[i, j]
for t in frame_indices:
for v in view_indices:
if t != 0 and v != 0:
img = decode_latents(model, latent_matrix[t, v][None], T)
img_matrix[t][v] = img * 2 - 1
# Save output videos
for v in view_indices:
vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
print(f"Saving {vid_file}")
save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)])
# Save diagonal video
diag_frames = [
img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames)
]
vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4")
print(f"Saving {vid_file}")
save_video(vid_file, diag_frames)
if __name__ == "__main__":
Fire(sample)
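
For orientation, an illustrative (non-authoritative) sketch of how the script above seeds the `n_frames x n_views` image matrix before the dense pass refines every non-input cell window by window:

```python
import numpy as np

n_frames, n_views, T = 21, 9, 5
grid = np.full((n_frames, n_views), ".", dtype=object)
grid[:, 0] = "I"                                   # column 0: input video frames
grid[0, 1:] = "R"                                  # row 0: SV3D reference views of frame 0
grid[np.arange(T - 1, n_frames, T - 1), 1:] = "A"  # anchor frames sampled first
print("\n".join("".join(row) for row in grid))     # "." cells are filled by the dense pass
```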


@@ -94,7 +94,7 @@ class LinearPredictionGuider(Guider):
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
# assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
@@ -105,7 +105,7 @@ class TrianglePredictionGuider(LinearPredictionGuider):
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
period: float | List[float] = 1.0,
period: Union[float, List[float]] = 1.0,
period_fusing: Literal["mean", "multiply", "max"] = "max",
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
@@ -129,3 +129,47 @@ class TrianglePredictionGuider(LinearPredictionGuider):
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
class TrapezoidPredictionGuider(LinearPredictionGuider):
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
edge_perc: float = 0.1,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc))
fall_steps = torch.flip(rise_steps, [0])
self.scale = torch.cat(
[
rise_steps,
torch.ones(num_frames - 2 * int(num_frames * edge_perc)),
fall_steps,
]
).unsqueeze(0)
class SpatiotemporalPredictionGuider(LinearPredictionGuider):
def __init__(
self,
max_scale: float,
num_frames: int,
num_views: int = 1,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
V = num_views
T = num_frames // V
scale = torch.zeros(num_frames).view(T, V)
scale += torch.linspace(0, 1, T)[:,None] * 0.5
scale += self.triangle_wave(torch.linspace(0, 1, V))[None,:] * 0.5
scale = scale.flatten()
self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
def triangle_wave(self, values: torch.Tensor, period=1) -> torch.Tensor:
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
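
For reference, a standalone sketch of the guidance-scale schedule the new `SpatiotemporalPredictionGuider` constructs, assuming SV4D-like settings (`num_frames=40`, `num_views=8`, `max_scale=2.5`, `min_scale=1.0`); it simply re-evaluates the arithmetic above outside the class:

```python
import torch

def triangle(x, period=1.0):
    return 2 * (x / period - torch.floor(x / period + 0.5)).abs()

num_frames, V, max_scale, min_scale = 40, 8, 2.5, 1.0
T = num_frames // V
scale = torch.zeros(num_frames).view(T, V)
scale += torch.linspace(0, 1, T)[:, None] * 0.5            # ramp over the T time steps
scale += triangle(torch.linspace(0, 1, V))[None, :] * 0.5  # triangle wave over the V views
scale = (scale.flatten() * (max_scale - min_scale) + min_scale).unsqueeze(0)
print(scale.shape)  # torch.Size([1, 40]); one guidance scale per (frame, view) cell
```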


@@ -75,20 +75,43 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
emb: th.Tensor,
context: Optional[th.Tensor] = None,
image_only_indicator: Optional[th.Tensor] = None,
cond_view: Optional[th.Tensor] = None,
cond_motion: Optional[th.Tensor] = None,
time_context: Optional[int] = None,
num_video_frames: Optional[int] = None,
time_step: Optional[int] = None,
name: Optional[str] = None,
):
from ...modules.diffusionmodules.video_model import VideoResBlock
from ...modules.diffusionmodules.video_model import VideoResBlock, PostHocResBlockWithTime
from ...modules.spacetime_attention import (
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
)
for layer in self:
module = layer
if isinstance(module, TimestepBlock) and not isinstance(
module, VideoResBlock
if isinstance(
module,
(
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
),
):
x = layer(x, emb)
elif isinstance(module, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
x = layer(
x,
context,
# cam,
time_context,
num_video_frames,
image_only_indicator,
cond_view,
cond_motion,
time_step,
name,
)
elif isinstance(module, SpatialVideoTransformer):
x = layer(
x,
@@ -96,7 +119,16 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
time_context,
num_video_frames,
image_only_indicator,
# time_step,
)
elif isinstance(module, PostHocResBlockWithTime):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(module, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(module, TimestepBlock) and not isinstance(
module, VideoResBlock
):
x = layer(x, emb)
elif isinstance(module, SpatialTransformer):
x = layer(x, context)
else:


@@ -1,5 +1,5 @@
import torch
from typing import Optional, Union
from ...util import default, instantiate_from_config
@@ -29,3 +29,10 @@ class DiscreteSampling:
torch.randint(0, self.num_idx, (n_samples,)),
)
return self.idx_to_sigma(idx)
class ZeroSampler:
def __call__(
self, n_samples: int, rand: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5
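
A quick check of the new `ZeroSampler` (referenced by the `sigma_sampler_config` entries in the SV4D config): it ignores any random input and returns a near-zero noise level per sample, so the conditioning latents are encoded essentially clean:

```python
from sgm.modules.diffusionmodules.sigma_sampling import ZeroSampler

sigmas = ZeroSampler()(4)
print(sigmas)  # tensor([1.0000e-05, 1.0000e-05, 1.0000e-05, 1.0000e-05])
```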


@@ -17,6 +17,36 @@ import torch.nn as nn
from einops import rearrange, repeat
def get_alpha(
merge_strategy: str,
mix_factor: Optional[torch.Tensor],
image_only_indicator: torch.Tensor,
apply_sigmoid: bool = True,
is_attn: bool = False,
) -> torch.Tensor:
if merge_strategy == "fixed" or merge_strategy == "learned":
alpha = mix_factor
elif merge_strategy == "learned_with_images":
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(mix_factor, "... -> ... 1"),
)
if is_attn:
alpha = rearrange(alpha, "b t -> (b t) 1 1")
else:
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
elif merge_strategy == "fixed_with_images":
alpha = image_only_indicator
if is_attn:
alpha = rearrange(alpha, "b t -> (b t) 1 1")
else:
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
else:
raise NotImplementedError
return torch.sigmoid(alpha) if apply_sigmoid else alpha
def make_beta_schedule(
schedule,
n_timestep,
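
A small shape check for the new `get_alpha` helper with the `learned_with_images` strategy used by SV4D (an all-video batch, so the image-only indicator is zeros); this assumes the file is `sgm/modules/diffusionmodules/util.py`, as imported elsewhere in this commit:

```python
import torch
from sgm.modules.diffusionmodules.util import get_alpha

mix_factor = torch.tensor([0.5])
image_only_indicator = torch.zeros(2, 8)  # (batch, frames), every entry is video
alpha = get_alpha("learned_with_images", mix_factor, image_only_indicator)
print(alpha.shape)  # torch.Size([2, 1, 8, 1, 1]); broadcasts over (b, c, t, h, w)
```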


@@ -5,8 +5,13 @@ from einops import rearrange
from ...modules.diffusionmodules.openaimodel import *
from ...modules.video_attention import SpatialVideoTransformer
from ...modules.spacetime_attention import (
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
)
from ...util import default
from .util import AlphaBlender
from .util import AlphaBlender # , LegacyAlphaBlenderWithBug, get_alpha
class VideoResBlock(ResBlock):
@@ -491,3 +496,746 @@ class VideoUNet(nn.Module):
)
h = h.type(x.dtype)
return self.out(h)
class PostHocAttentionBlockWithTimeMixing(AttentionBlock):
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
use_checkpoint: bool = False,
use_new_attention_order: bool = False,
dropout: float = 0.0,
use_spatial_context: bool = False,
merge_strategy: bool = "fixed",
merge_factor: float = 0.5,
apply_sigmoid_to_merge: bool = True,
ff_in: bool = False,
attn_mode: str = "softmax",
disable_temporal_crossattention: bool = False,
):
super().__init__(
in_channels,
n_heads,
d_head,
use_checkpoint=use_checkpoint,
use_new_attention_order=use_new_attention_order,
)
inner_dim = n_heads * d_head
self.time_mix_blocks = nn.ModuleList(
[
BasicTransformerTimeMixBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
checkpoint=use_checkpoint,
ff_in=ff_in,
attn_mode=attn_mode,
disable_temporal_crossattention=disable_temporal_crossattention,
)
]
)
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_mix_time_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
self.use_spatial_context = use_spatial_context
if merge_strategy == "fixed":
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
self.register_parameter(
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
)
elif merge_strategy == "fixed_with_images":
self.mix_factor = None
else:
raise ValueError(f"unknown merge strategy {merge_strategy}")
self.get_alpha_fn = functools.partial(
get_alpha,
merge_strategy,
self.mix_factor,
apply_sigmoid=apply_sigmoid_to_merge,
)
def forward(
self,
x: th.Tensor,
context: Optional[th.Tensor] = None,
# cam: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
conv_view: Optional[th.Tensor] = None,
conv_motion: Optional[th.Tensor] = None,
):
if time_context is not None:
raise NotImplementedError
_, _, h, w = x.shape
if exists(context):
context = rearrange(context, "b t ... -> (b t) ...")
if self.use_spatial_context:
time_context = repeat(context[:, 0], "b ... -> (b n) ...", n=h * w)
x = super().forward(
x,
)
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = th.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.time_mix_time_embed(t_emb)
emb = emb[:, None, :]
x_mix = x_mix + emb
x_mix = self.time_mix_blocks[0](
x_mix, context=time_context, timesteps=timesteps
)
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
return x
class PostHocResBlockWithTime(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
time_kernel_size: Union[int, List[int]] = 3,
merge_strategy: bool = "fixed",
merge_factor: float = 0.5,
apply_sigmoid_to_merge: bool = True,
out_channels: Optional[int] = None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
time_mix_legacy: bool = True,
replicate_bug: bool = False,
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
)
self.time_mix_blocks = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=time_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
)
self.time_mix_legacy = time_mix_legacy
if self.time_mix_legacy:
if merge_strategy == "fixed":
self.register_buffer("mix_factor", th.Tensor([merge_factor]))
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
self.register_parameter(
"mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
)
elif merge_strategy == "fixed_with_images":
self.mix_factor = None
else:
raise ValueError(f"unknown merge strategy {merge_strategy}")
self.get_alpha_fn = functools.partial(
get_alpha,
merge_strategy,
self.mix_factor,
apply_sigmoid=apply_sigmoid_to_merge,
)
else:
if False: # replicate_bug:
logpy.warning(
"*****************************************************************************************\n"
"GRAVE WARNING: YOU'RE USING THE BUGGY LEGACY ALPHABLENDER!!! ARE YOU SURE YOU WANT THIS?!\n"
"*****************************************************************************************"
)
self.time_mixer = LegacyAlphaBlenderWithBug(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
else:
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator: Optional[th.Tensor] = None,
cond_view: Optional[th.Tensor] = None,
cond_motion: Optional[th.Tensor] = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_mix_blocks(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
if self.time_mix_legacy:
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class SpatialUNetModelWithTime(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
num_res_blocks: int,
attention_resolutions: int,
dropout: float = 0.0,
channel_mult: List[int] = (1, 2, 4, 8),
conv_resample: bool = True,
dims: int = 2,
num_classes: Optional[int] = None,
use_checkpoint: bool = False,
num_heads: int = -1,
num_head_channels: int = -1,
num_heads_upsample: int = -1,
use_scale_shift_norm: bool = False,
resblock_updown: bool = False,
use_new_attention_order: bool = False,
use_spatial_transformer: bool = False,
transformer_depth: Union[List[int], int] = 1,
transformer_depth_middle: Optional[int] = None,
context_dim: Optional[int] = None,
time_downup: bool = False,
time_context_dim: Optional[int] = None,
extra_ff_mix_layer: bool = False,
use_spatial_context: bool = False,
time_block_merge_strategy: str = "fixed",
time_block_merge_factor: float = 0.5,
spatial_transformer_attn_type: str = "softmax",
time_kernel_size: Union[int, List[int]] = 3,
use_linear_in_transformer: bool = False,
legacy: bool = True,
adm_in_channels: Optional[int] = None,
use_temporal_resblock: bool = True,
disable_temporal_crossattention: bool = False,
time_mix_legacy: bool = True,
max_ddpm_temb_period: int = 10000,
replicate_time_mix_bug: bool = False,
use_motion_attention: bool = False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None
if context_dim is not None:
assert use_spatial_transformer
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1
if num_head_channels == -1:
assert num_heads != -1
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
transformer_depth_middle = default(
transformer_depth_middle, transformer_depth[-1]
)
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disabled_sa=False,
):
if not use_spatial_transformer:
return PostHocAttentionBlockWithTimeMixing(
ch,
num_heads,
dim_head,
use_checkpoint=use_checkpoint,
use_new_attention_order=use_new_attention_order,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=time_block_merge_strategy,
merge_factor=time_block_merge_factor,
attn_mode=spatial_transformer_attn_type,
disable_temporal_crossattention=disable_temporal_crossattention,
)
elif use_motion_attention:
return PostHocSpatialTransformerWithTimeMixingAndMotion(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=time_block_merge_strategy,
merge_factor=time_block_merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
attn_mode=spatial_transformer_attn_type,
disable_self_attn=disabled_sa,
disable_temporal_crossattention=disable_temporal_crossattention,
time_mix_legacy=time_mix_legacy,
max_time_embed_period=max_ddpm_temb_period,
)
else:
return PostHocSpatialTransformerWithTimeMixing(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=time_block_merge_strategy,
merge_factor=time_block_merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
attn_mode=spatial_transformer_attn_type,
disable_self_attn=disabled_sa,
disable_temporal_crossattention=disable_temporal_crossattention,
time_mix_legacy=time_mix_legacy,
max_time_embed_period=max_ddpm_temb_period,
)
def get_resblock(
time_block_merge_factor,
time_block_merge_strategy,
time_kernel_size,
ch,
time_embed_dim,
dropout,
out_ch,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
):
if self.use_temporal_resblocks:
return PostHocResBlockWithTime(
merge_factor=time_block_merge_factor,
merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
time_mix_legacy=time_mix_legacy,
replicate_bug=replicate_time_mix_bug,
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
)
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
get_resblock(
time_block_merge_factor=time_block_merge_factor,
time_block_merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
use_checkpoint=use_checkpoint,
disabled_sa=False,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
ds *= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
get_resblock(
time_block_merge_factor=time_block_merge_factor,
time_block_merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch,
conv_resample,
dims=dims,
out_channels=out_ch,
third_down=time_downup,
)
)
)
ch = out_ch
input_block_chans.append(ch)
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
get_resblock(
time_block_merge_factor=time_block_merge_factor,
time_block_merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
out_ch=None,
dropout=dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
context_dim=context_dim,
use_checkpoint=use_checkpoint,
),
get_resblock(
time_block_merge_factor=time_block_merge_factor,
time_block_merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
ch=ch,
out_ch=None,
time_embed_dim=time_embed_dim,
dropout=dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
get_resblock(
time_block_merge_factor=time_block_merge_factor,
time_block_merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
use_checkpoint=use_checkpoint,
disabled_sa=False,
)
)
if level and i == num_res_blocks:
out_ch = ch
ds //= 2
layers.append(
get_resblock(
time_block_merge_factor=time_block_merge_factor,
time_block_merge_strategy=time_block_merge_strategy,
time_kernel_size=time_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(
ch,
conv_resample,
dims=dims,
out_channels=out_ch,
third_up=time_downup,
)
)
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
def forward(
self,
x: th.Tensor,
timesteps: th.Tensor,
context: Optional[th.Tensor] = None,
y: Optional[th.Tensor] = None,
# cam: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
cond_view: Optional[th.Tensor] = None,
cond_motion: Optional[th.Tensor] = None,
time_step: Optional[int] = None,
):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # 21 x 320
emb = self.time_embed(t_emb) # 21 x 1280
time = str(timesteps[0].data.cpu().numpy())
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) # 21 x 1280
h = x # 21 x 8 x 64 x 64
for i, module in enumerate(self.input_blocks):
h = module(
h,
emb,
context=context,
# cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
time_context=time_context,
num_video_frames=num_video_frames,
time_step=time_step,
name='encoder_{}_{}'.format(time, i)
)
hs.append(h)
h = self.middle_block(
h,
emb,
context=context,
# cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
time_context=time_context,
num_video_frames=num_video_frames,
time_step=time_step,
name='middle_{}_0'.format(time, i)
)
for i, module in enumerate(self.output_blocks):
h = th.cat([h, hs.pop()], dim=1)
h = module(
h,
emb,
context=context,
# cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
time_context=time_context,
num_video_frames=num_video_frames,
time_step=time_step,
name='decoder_{}_{}'.format(time, i)
)
h = h.type(x.dtype)
return self.out(h)


@@ -25,6 +25,17 @@ class OpenAIWrapper(IdentityWrapper):
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
) -> torch.Tensor:
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
if "cond_view" in c:
return self.diffusion_model(
x,
timesteps=t,
context=c.get("crossattn", None),
y=c.get("vector", None),
cond_view=c.get("cond_view", None),
cond_motion=c.get("cond_motion", None),
**kwargs,
)
else:
return self.diffusion_model(
x,
timesteps=t,


@@ -69,8 +69,8 @@ class AbstractEmbModel(nn.Module):
class GeneralConditioner(nn.Module):
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"} # , 5: "concat"}
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1}
def __init__(self, emb_models: Union[List, ListConfig]):
super().__init__()
@@ -138,7 +138,11 @@ class GeneralConditioner(nn.Module):
if not isinstance(emb_out, (list, tuple)):
emb_out = [emb_out]
for emb in emb_out:
if embedder.input_key in ["cond_view", "cond_motion"]:
out_key = embedder.input_key
else:
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
emb = (
expand_dims_like(
@@ -994,7 +998,10 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
sigmas = self.sigma_sampler(b).to(vid.device)
if self.sigma_cond is not None:
sigma_cond = self.sigma_cond(sigmas)
if self.n_cond_frames == 1:
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
else:
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_cond_frames) # For SV4D
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
noise = torch.randn_like(vid)
vid = vid + noise * append_dims(sigmas, vid.ndim)
@@ -1017,6 +1024,7 @@ class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
vid = torch.cat(all_out, dim=0)
vid *= self.scale_factor
if self.n_cond_frames == 1:
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
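
To make the routing change at the top of this file concrete, a tiny standalone sketch of the new output-key logic (illustrative only): `cond_view` and `cond_motion` keep their own keys, while everything else is routed by tensor rank as before:

```python
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"}

def out_key(input_key: str, emb_ndim: int) -> str:
    if input_key in ["cond_view", "cond_motion"]:
        return input_key
    return OUTPUT_DIM2KEYS[emb_ndim]

print(out_key("polar_rad", 2), out_key("cond_frames", 4), out_key("cond_view", 4))
# -> vector concat cond_view
```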


@@ -0,0 +1,596 @@
from functools import partial
import torch
from ..modules.attention import *
from ..modules.diffusionmodules.util import (
AlphaBlender,
get_alpha,
linear,
mixed_checkpoint,
timestep_embedding,
)
class TimeMixSequential(nn.Sequential):
def forward(self, x, context=None, timesteps=None):
for layer in self:
x = layer(x, context, timesteps)
return x
class BasicTransformerTimeMixBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention,
"softmax-xformers": MemoryEfficientCrossAttention,
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
timesteps=None,
ff_in=False,
inner_dim=None,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
switch_temporal_ca_to_sa=False,
):
super().__init__()
attn_cls = self.ATTENTION_MODES[attn_mode]
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
assert int(n_heads * d_head) == inner_dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
)
self.timesteps = timesteps
self.disable_self_attn = disable_self_attn
if self.disable_self_attn:
self.attn1 = attn_cls(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim,
dropout=dropout,
) # is a cross-attention
else:
self.attn1 = attn_cls(
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
self.norm2 = nn.LayerNorm(inner_dim)
if switch_temporal_ca_to_sa:
self.attn2 = attn_cls(
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
else:
self.attn2 = attn_cls(
query_dim=inner_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(inner_dim)
self.norm3 = nn.LayerNorm(inner_dim)
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
self.checkpoint = checkpoint
if self.checkpoint:
logpy.info(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
) -> torch.Tensor:
if self.checkpoint:
return checkpoint(self._forward, x, context, timesteps)
else:
return self._forward(x, context, timesteps=timesteps)
def _forward(self, x, context=None, timesteps=None):
assert self.timesteps or timesteps
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
timesteps = self.timesteps or timesteps
B, S, C = x.shape
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
if self.disable_self_attn:
x = self.attn1(self.norm1(x), context=context) + x
else:
x = self.attn1(self.norm1(x)) + x
if self.attn2 is not None:
if self.switch_temporal_ca_to_sa:
x = self.attn2(self.norm2(x)) + x
else:
x = self.attn2(self.norm2(x), context=context) + x
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = rearrange(
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
return x
def get_last_layer(self):
return self.ff.net[-1].weight
class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
apply_sigmoid_to_merge: bool = True,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
time_mix_legacy: bool = True,
max_time_embed_period: int = 10000,
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
attn_type=attn_mode,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_mix_blocks = nn.ModuleList(
[
BasicTransformerTimeMixBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
attn_mode=attn_mode,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
)
for _ in range(self.depth)
]
)
assert len(self.time_mix_blocks) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_mix_time_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
self.time_mix_legacy = time_mix_legacy
if self.time_mix_legacy:
if merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
)
elif merge_strategy == "fixed_with_images":
self.mix_factor = None
else:
raise ValueError(f"unknown merge strategy {merge_strategy}")
self.get_alpha_fn = partial(
get_alpha,
merge_strategy,
self.mix_factor,
apply_sigmoid=apply_sigmoid_to_merge,
is_attn=True,
)
else:
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
# cam: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
cond_view: Optional[torch.Tensor] = None,
cond_motion: Optional[torch.Tensor] = None,
time_step: Optional[int] = None,
name: Optional[str] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
if self.time_mix_legacy:
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(
num_frames,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb = self.time_mix_time_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_mix_blocks)
):
# spatial attention
x = block(
x,
context=spatial_context,
time_step=time_step,
name=name + '_' + str(it_)
)
x_mix = x
x_mix = x_mix + emb
# temporal attention
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
if self.time_mix_legacy:
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator,
)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
apply_sigmoid_to_merge: bool = True,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
time_mix_legacy: bool = True,
max_time_embed_period: int = 10000,
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
attn_type=attn_mode,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
camera_context_dim = time_context_dim
motion_context_dim = 4 # time_context_dim
# Camera attention layer
self.time_mix_blocks = nn.ModuleList(
[
BasicTransformerTimeMixBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=camera_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
attn_mode=attn_mode,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
)
for _ in range(self.depth)
]
)
# Motion attention layer
self.motion_blocks = nn.ModuleList(
[
BasicTransformerTimeMixBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=motion_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
attn_mode=attn_mode,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
)
for _ in range(self.depth)
]
)
assert len(self.time_mix_blocks) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
# Camera view embedding
self.time_mix_time_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
# Motion time embedding
self.time_mix_motion_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
self.time_mix_legacy = time_mix_legacy
if self.time_mix_legacy:
if merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
)
elif merge_strategy == "fixed_with_images":
self.mix_factor = None
else:
raise ValueError(f"unknown merge strategy {merge_strategy}")
self.get_alpha_fn = partial(
get_alpha,
merge_strategy,
self.mix_factor,
apply_sigmoid=apply_sigmoid_to_merge,
is_attn=True,
)
else:
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
# cam: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
cond_view: Optional[torch.Tensor] = None,
cond_motion: Optional[torch.Tensor] = None,
time_step: Optional[int] = None,
name: Optional[str] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
# cond_view: b v 4 h w
# cond_motion: b t 4 h w
b, t, d1 = context.shape # CLIP
v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE
cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w
spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1
camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1
motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c") # 21 x 4096 x 320
if self.use_linear:
x = self.proj_in(x)
c = x.shape[-1]
if self.time_mix_legacy:
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
num_frames = torch.arange(t, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=b)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(
num_frames,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb_time = self.time_mix_motion_embed(t_emb)
emb_time = emb_time[:, None, :] # b*t x 1 x 320
num_views = torch.arange(v, device=x.device)
num_views = repeat(num_views, "t -> b t", b=b)
num_views = rearrange(num_views, "b t -> (b t)")
v_emb = timestep_embedding(
num_views,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb_view = self.time_mix_time_embed(v_emb)
emb_view = emb_view[:, None, :] # b*v x 1 x 320
for it_, (block, time_block, mot_block) in enumerate(
zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)
):
# Spatial attention
x = block(
x,
context=spatial_context,
)
# Camera attention
x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c
x_mix = x + emb_view
x_mix = time_block(x_mix, context=camera_context, timesteps=v)
if self.time_mix_legacy:
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator[:,:v],
)
# Motion attention
x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c
x_mix = x + emb_time
x_mix = mot_block(x_mix, context=motion_context, timesteps=t)
if self.time_mix_legacy:
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator[:,:t],
)
x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out