# mirror of https://github.com/Stability-AI/generative-models.git
# synced 2025-12-23 00:04:21 +01:00
from functools import partial

import torch

from ..modules.attention import *
from ..modules.diffusionmodules.util import (
    AlphaBlender,
    get_alpha,
    linear,
    mixed_checkpoint,
    timestep_embedding,
)

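# Overview (annotation, not part of the original source; derived from the code below):
#   * ``TimeMixSequential`` -- ``nn.Sequential`` that threads (x, context, timesteps)
#     through each layer.
#   * ``BasicTransformerTimeMixBlock`` -- transformer block that attends along the
#     frame (or view) axis of flattened video/multi-view tokens.
#   * ``PostHocSpatialTransformerWithTimeMixing`` -- ``SpatialTransformer`` extended
#     post hoc with temporal attention, alpha-blended with the spatial path.
#   * ``PostHocSpatialTransformerWithTimeMixingAndMotion`` -- additionally runs a
#     camera (view) pass and a motion (frame) pass, conditioned on CLIP / VAE features.
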
class TimeMixSequential(nn.Sequential):
    def forward(self, x, context=None, timesteps=None):
        for layer in self:
            x = layer(x, context, timesteps)

        return x


class BasicTransformerTimeMixBlock(nn.Module):
    """Transformer block that attends along the frame/view axis.

    The input is expected as (b*t, s, c); ``_forward`` regroups it to (b*s, t, c),
    applies (optionally cross-) attention and a feed-forward over the ``t`` axis,
    and restores the original layout.
    """

    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        timesteps=None,
        ff_in=False,
        inner_dim=None,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        switch_temporal_ca_to_sa=False,
    ):
        super().__init__()

        attn_cls = self.ATTENTION_MODES[attn_mode]

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        assert int(n_heads * d_head) == inner_dim

        self.is_res = inner_dim == dim

        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim)
            self.ff_in = FeedForward(
                dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
            )

        self.timesteps = timesteps
        self.disable_self_attn = disable_self_attn
        if self.disable_self_attn:
            self.attn1 = attn_cls(
                query_dim=inner_dim,
                heads=n_heads,
                dim_head=d_head,
                context_dim=context_dim,
                dropout=dropout,
            )  # is a cross-attention
        else:
            self.attn1 = attn_cls(
                query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
            )  # is a self-attention

        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None
        else:
            self.norm2 = nn.LayerNorm(inner_dim)
            if switch_temporal_ca_to_sa:
                self.attn2 = attn_cls(
                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
                )  # is a self-attention
            else:
                self.attn2 = attn_cls(
                    query_dim=inner_dim,
                    context_dim=context_dim,
                    heads=n_heads,
                    dim_head=d_head,
                    dropout=dropout,
                )  # is self-attn if context is none

        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

        self.checkpoint = checkpoint
        if self.checkpoint:
            logpy.info(f"{self.__class__.__name__} is using checkpointing")

    def forward(
        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
    ) -> torch.Tensor:
        if self.checkpoint:
            return checkpoint(self._forward, x, context, timesteps)
        else:
            return self._forward(x, context, timesteps=timesteps)

    def _forward(self, x, context=None, timesteps=None):
        assert self.timesteps or timesteps
        assert not (self.timesteps and timesteps) or self.timesteps == timesteps
        timesteps = self.timesteps or timesteps
        B, S, C = x.shape
        x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        if self.disable_self_attn:
            x = self.attn1(self.norm1(x), context=context) + x
        else:
            x = self.attn1(self.norm1(x)) + x

        if self.attn2 is not None:
            if self.switch_temporal_ca_to_sa:
                x = self.attn2(self.norm2(x)) + x
            else:
                x = self.attn2(self.norm2(x), context=context) + x
        x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        x = rearrange(
            x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
        )
        return x

    def get_last_layer(self):
        return self.ff.net[-1].weight


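# Illustrative sketch (not from the original file; dimensions are assumptions):
# ``BasicTransformerTimeMixBlock`` consumes tokens flattened over batch and frames,
# i.e. (b*t, s, c), and attends across the ``t`` axis.
#
#     block = BasicTransformerTimeMixBlock(
#         dim=320, n_heads=5, d_head=64, timesteps=8, checkpoint=False
#     )
#     x = torch.randn(2 * 8, 32 * 32, 320)   # (b*t, h*w, c)
#     y = block(x)                            # -> (b*t, h*w, c), mixed along t

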
class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):
    """SpatialTransformer extended post hoc with temporal attention.

    Each spatial transformer block is paired with a ``BasicTransformerTimeMixBlock``;
    the temporal output is blended with the spatial one either via a legacy
    ``get_alpha`` factor or an ``AlphaBlender``.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        apply_sigmoid_to_merge: bool = True,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        time_mix_legacy: bool = True,
        max_time_embed_period: int = 10000,
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        self.time_mix_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_mix_blocks) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_mix_time_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )

        self.time_mix_legacy = time_mix_legacy
        if self.time_mix_legacy:
            if merge_strategy == "fixed":
                self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
            elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
                self.register_parameter(
                    "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
                )
            elif merge_strategy == "fixed_with_images":
                self.mix_factor = None
            else:
                raise ValueError(f"unknown merge strategy {merge_strategy}")

            self.get_alpha_fn = partial(
                get_alpha,
                merge_strategy,
                self.mix_factor,
                apply_sigmoid=apply_sigmoid_to_merge,
                is_attn=True,
            )
        else:
            self.time_mixer = AlphaBlender(
                alpha=merge_factor, merge_strategy=merge_strategy
            )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        # cam: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        cond_view: Optional[torch.Tensor] = None,
        cond_motion: Optional[torch.Tensor] = None,
        time_step: Optional[int] = None,
        name: Optional[str] = None,
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            time_context = context
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        if self.time_mix_legacy:
            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)

        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        )
        emb = self.time_mix_time_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_mix_blocks)
        ):
            # spatial attention (the spatial block is expected to accept the extra
            # ``time_step``/``name`` kwargs; ``name`` must be provided by the caller)
            x = block(
                x,
                context=spatial_context,
                time_step=time_step,
                name=name + "_" + str(it_),
            )

            x_mix = x
            x_mix = x_mix + emb

            # temporal attention
            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                x = self.time_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=image_only_indicator,
                )

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out


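# Illustrative sketch (not from the original file; argument values are assumptions):
# the module consumes UNet feature maps of shape (b*t, c, h, w) together with a
# context and a frame count. Depending on the merge strategy, ``image_only_indicator``
# may also be required by ``get_alpha`` / ``AlphaBlender``.
#
#     st = PostHocSpatialTransformerWithTimeMixing(
#         in_channels=320, n_heads=5, d_head=64,
#         context_dim=1024, use_spatial_context=True, use_linear=True,
#     )
#     feats = torch.randn(2 * 8, 320, 32, 32)                   # (b*t, c, h, w)
#     ctx = torch.randn(2 * 8, 1, 1024)                         # (b*t, seq, context_dim)
#     out = st(feats, context=ctx, timesteps=8, name="attn1")   # -> (b*t, c, h, w)

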
class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):
    """SpatialTransformer extended post hoc with camera (view) and motion (frame) attention.

    Tokens are expected to be laid out as (b*t*v, h*w, c). Each spatial block is
    followed by a camera-attention pass over the view axis (conditioned on the CLIP
    context) and a motion-attention pass over the frame axis (conditioned on VAE
    features of the conditioning views); each pass is alpha-blended with its input.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        apply_sigmoid_to_merge: bool = True,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        time_mix_legacy: bool = True,
        max_time_embed_period: int = 10000,
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        camera_context_dim = time_context_dim
        motion_context_dim = 4  # time_context_dim

        # Camera attention layer
        self.time_mix_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=camera_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        # Motion attention layer
        self.motion_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=motion_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_mix_blocks) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        # Camera view embedding
        self.time_mix_time_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )
        # Motion time embedding
        self.time_mix_motion_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )

        self.time_mix_legacy = time_mix_legacy
        if self.time_mix_legacy:
            if merge_strategy == "fixed":
                self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
            elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
                self.register_parameter(
                    "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
                )
            elif merge_strategy == "fixed_with_images":
                self.mix_factor = None
            else:
                raise ValueError(f"unknown merge strategy {merge_strategy}")

            self.get_alpha_fn = partial(
                get_alpha,
                merge_strategy,
                self.mix_factor,
                apply_sigmoid=apply_sigmoid_to_merge,
                is_attn=True,
            )
        else:
            self.time_mixer = AlphaBlender(
                alpha=merge_factor, merge_strategy=merge_strategy
            )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        # cam: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        cond_view: Optional[torch.Tensor] = None,
        cond_motion: Optional[torch.Tensor] = None,
        time_step: Optional[int] = None,
        name: Optional[str] = None,
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        # cond_view: (b v) d2 h w  (VAE latents of the conditioning views, flattened over batch and view)
        # cond_motion: b t 4 h w
        b, t, d1 = context.shape  # CLIP
        v, d2 = cond_view.shape[0] // b, cond_view.shape[1]  # VAE
        cond_view = torch.nn.functional.interpolate(
            cond_view, size=(h, w), mode="bilinear"
        )  # b*v d h w
        spatial_context = context[:, :, None].repeat(1, 1, v, 1).reshape(-1, 1, d1)  # (b*t*v) 1 d1
        camera_context = context[:, :, None].repeat(1, 1, h * w, 1).reshape(-1, 1, d1)  # (b*t*h*w) 1 d1
        motion_context = cond_view.permute(0, 2, 3, 1).reshape(-1, 1, d2)  # (b*v*h*w) 1 d2

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")  # e.g. 21 x 4096 x 320
        if self.use_linear:
            x = self.proj_in(x)
        c = x.shape[-1]

        if self.time_mix_legacy:
            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)

        num_frames = torch.arange(t, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=b)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        )
        emb_time = self.time_mix_motion_embed(t_emb)
        emb_time = emb_time[:, None, :]  # b*t x 1 x c

        num_views = torch.arange(v, device=x.device)
        num_views = repeat(num_views, "t -> b t", b=b)
        num_views = rearrange(num_views, "b t -> (b t)")
        v_emb = timestep_embedding(
            num_views,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        )
        emb_view = self.time_mix_time_embed(v_emb)
        emb_view = emb_view[:, None, :]  # b*v x 1 x c

        for it_, (block, time_block, mot_block) in enumerate(
            zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)
        ):
            # Spatial attention
            x = block(
                x,
                context=spatial_context,
            )

            # Camera attention
            x = x.view(b, t, v, h * w, c).permute(0, 2, 1, 3, 4).reshape(b * v, -1, c)  # b*v t*h*w c
            x_mix = x + emb_view
            x_mix = time_block(x_mix, context=camera_context, timesteps=v)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                x = self.time_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=image_only_indicator[:, :v],
                )

            # Motion attention
            x = x.view(b, v, t, h * w, c).permute(0, 2, 1, 3, 4).reshape(b * t, -1, c)  # b*t v*h*w c
            x_mix = x + emb_time
            x_mix = mot_block(x_mix, context=motion_context, timesteps=t)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                x = self.time_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=image_only_indicator[:, :t],
                )

            # restore the per-frame-per-view token layout before the next block
            x = x.view(b, t, v, h * w, c).reshape(-1, h * w, c)  # b*t*v h*w c

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
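

# Illustrative sketch (not from the original file; shapes inferred from the forward
# pass above and otherwise assumed): tokens are laid out as (b*t*v, c, h, w), the CLIP
# context as (b, t, d1), and ``cond_view`` as flattened VAE latents (b*v, 4, H, W).
#
#     st = PostHocSpatialTransformerWithTimeMixingAndMotion(
#         in_channels=320, n_heads=5, d_head=64,
#         context_dim=1024, use_spatial_context=True, use_linear=True,
#     )
#     b, t, v = 1, 5, 4
#     x = torch.randn(b * t * v, 320, 32, 32)     # (b*t*v, c, h, w)
#     clip_ctx = torch.randn(b, t, 1024)          # (b, t, d1)
#     views = torch.randn(b * v, 4, 64, 64)       # (b*v, 4, H, W) VAE latents
#     ioi = torch.zeros(b, max(t, v))             # image_only_indicator, sliced to [:, :v] / [:, :t]
#     out = st(x, context=clip_ctx, cond_view=views, image_only_indicator=ioi)
#     # -> (b*t*v, c, h, w)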