mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 14:44:31 +01:00
SP4D updates
This commit is contained in:
@@ -13,6 +13,7 @@ from ...modules.spacetime_attention import (
|
||||
from ...util import default
|
||||
from .util import AlphaBlender, get_alpha
|
||||
|
||||
import torch
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
def __init__(
|
||||
@@ -1252,3 +1253,157 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class CrossNetworkLayer(nn.Module):
|
||||
def __init__(self, feature_dim: int):
|
||||
super().__init__()
|
||||
self.fusion_conv = nn.Sequential(
|
||||
nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, h1: torch.Tensor, h2: torch.Tensor):
|
||||
"""
|
||||
h1, h2: (B, C, H, W)
|
||||
return: (out1, out2), (B, C, H, W)
|
||||
"""
|
||||
fused_input = torch.cat([h1, h2], dim=1) # (B, 2C, H, W)
|
||||
fused_output = self.fusion_conv(fused_input) # (B, C, H, W)
|
||||
out1 = fused_output + h1
|
||||
out2 = fused_output + h2
|
||||
return out1, out2
|
||||
|
||||
|
||||
class DualSpatialUNetWithCrossComm(nn.Module):
|
||||
def __init__(self, unet_config):
|
||||
super().__init__()
|
||||
self.num_classes = unet_config["num_classes"]
|
||||
self.model_channels = unet_config["model_channels"]
|
||||
|
||||
self.net1 = SpatialUNetModelWithTime(**unet_config)
|
||||
self.net2 = SpatialUNetModelWithTime(**unet_config)
|
||||
|
||||
self.input_cross_layers = nn.ModuleList()
|
||||
for block in self.net1.input_blocks:
|
||||
out_ch = self._get_block_out_channels(block)
|
||||
self.input_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
|
||||
|
||||
middle_out_ch = self._get_block_out_channels(self.net1.middle_block)
|
||||
self.middle_cross = CrossNetworkLayer(feature_dim=middle_out_ch)
|
||||
|
||||
self.output_cross_layers = nn.ModuleList()
|
||||
for block in self.net1.output_blocks:
|
||||
out_ch = self._get_block_out_channels(block)
|
||||
self.output_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
|
||||
|
||||
def _get_block_out_channels(self, block: nn.Module) -> int:
|
||||
mod_list = list(block.children())
|
||||
for m in reversed(mod_list):
|
||||
if hasattr(m, "out_channels"):
|
||||
return m.out_channels
|
||||
|
||||
if isinstance(
|
||||
m,
|
||||
(SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion),
|
||||
):
|
||||
return m.in_channels
|
||||
|
||||
if isinstance(m, nn.Conv2d):
|
||||
return m.out_channels
|
||||
|
||||
raise ValueError(f"Cannot determine out_channels from block: {block}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
timesteps: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
y: Optional[th.Tensor] = None,
|
||||
cam: Optional[th.Tensor] = None,
|
||||
time_context: Optional[th.Tensor] = None,
|
||||
num_video_frames: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
cond_view: Optional[th.Tensor] = None,
|
||||
cond_motion: Optional[th.Tensor] = None,
|
||||
time_step: Optional[int] = None,
|
||||
):
|
||||
|
||||
# ============ encoder ============
|
||||
h1, h2 = x[:, : x.shape[1] // 2], x[:, x.shape[1] // 2 :]
|
||||
|
||||
encoder_feats1 = []
|
||||
encoder_feats2 = []
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
||||
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False
|
||||
) # 21 x 320
|
||||
|
||||
emb = self.net1.time_embed(t_emb) # 21 x 1280
|
||||
time = str(timesteps[0].data.cpu().numpy())
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == h1.shape[0]
|
||||
emb = emb + self.net1.label_emb(y) # 21 x 1280
|
||||
|
||||
filtered_args = {
|
||||
"emb": emb,
|
||||
"context": context,
|
||||
"cam": cam,
|
||||
"cond_view": cond_view,
|
||||
"cond_motion": cond_motion,
|
||||
"time_context": time_context,
|
||||
"num_video_frames": num_video_frames,
|
||||
"image_only_indicator": image_only_indicator,
|
||||
"time_step": time_step,
|
||||
}
|
||||
|
||||
for i, (block1, block2) in enumerate(
|
||||
zip(self.net1.input_blocks, self.net2.input_blocks)
|
||||
):
|
||||
h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args)
|
||||
h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args)
|
||||
|
||||
# cross
|
||||
h1, h2 = self.input_cross_layers[i](h1, h2)
|
||||
|
||||
encoder_feats1.append(h1)
|
||||
encoder_feats2.append(h2)
|
||||
|
||||
# ============ middle block ============
|
||||
h1 = self.net1.middle_block(
|
||||
h1, name="middle_{}_0".format(time, i), **filtered_args
|
||||
)
|
||||
h2 = self.net2.middle_block(
|
||||
h2, name="middle_{}_0".format(time, i), **filtered_args
|
||||
)
|
||||
|
||||
# cross
|
||||
h1, h2 = self.middle_cross(h1, h2)
|
||||
|
||||
# ============ decoder ============
|
||||
for i, (block1, block2) in enumerate(
|
||||
zip(self.net1.output_blocks, self.net2.output_blocks)
|
||||
):
|
||||
skip1 = encoder_feats1.pop()
|
||||
skip2 = encoder_feats2.pop()
|
||||
h1 = torch.cat([h1, skip1], dim=1)
|
||||
h2 = torch.cat([h2, skip2], dim=1)
|
||||
|
||||
h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args)
|
||||
h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args)
|
||||
|
||||
# cross
|
||||
h1, h2 = self.output_cross_layers[i](h1, h2)
|
||||
|
||||
# ============ output ============
|
||||
out1 = self.net1.out(h1) # shape: (B, out_channels, H, W)
|
||||
out2 = self.net2.out(h2) # same shape
|
||||
out = torch.cat([out1, out2], dim=1)
|
||||
|
||||
return out
|
||||
Reference in New Issue
Block a user