SP4D updates

This commit is contained in:
Chun-Han Yao
2025-11-03 21:23:17 +00:00
parent 8f41cbc50b
commit fd9d14e02f
7 changed files with 749 additions and 4 deletions

View File

@@ -13,6 +13,7 @@ from ...modules.spacetime_attention import (
from ...util import default
from .util import AlphaBlender, get_alpha
import torch
class VideoResBlock(ResBlock):
def __init__(
@@ -1252,3 +1253,157 @@ class SpatialUNetModelWithTime(nn.Module):
)
h = h.type(x.dtype)
return self.out(h)
class CrossNetworkLayer(nn.Module):
    """Fuse the features of two parallel networks and feed the shared
    fused signal back into each branch as a residual."""

    def __init__(self, feature_dim: int):
        super().__init__()
        # 1x1 convolutions: squeeze the channel-concatenated pair back down
        # to `feature_dim` channels, with a ReLU in between.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
        )

    def forward(self, h1: torch.Tensor, h2: torch.Tensor):
        """h1, h2: (B, C, H, W) -> (out1, out2), each (B, C, H, W).

        The same fused tensor is added to both inputs, so
        out1 - h1 == out2 - h2 by construction.
        """
        fused = self.fusion_conv(torch.cat((h1, h2), dim=1))  # (B, C, H, W)
        return fused + h1, fused + h2
class DualSpatialUNetWithCrossComm(nn.Module):
    """Two SpatialUNetModelWithTime UNets run side by side, exchanging
    information through a CrossNetworkLayer after every encoder block,
    after the middle block, and after every decoder block.

    The input ``x`` is split in half along the channel dimension; each half
    is processed by its own UNet and the two outputs are concatenated back
    along channels, so the module's output has twice the per-net channels.
    """

    def __init__(self, unet_config):
        super().__init__()
        self.num_classes = unet_config["num_classes"]
        self.model_channels = unet_config["model_channels"]
        # Two independent UNets built from the same config.
        self.net1 = SpatialUNetModelWithTime(**unet_config)
        self.net2 = SpatialUNetModelWithTime(**unet_config)
        # One cross-communication layer per encoder block (sized from net1;
        # net2 is structurally identical since it shares the config).
        self.input_cross_layers = nn.ModuleList(
            CrossNetworkLayer(feature_dim=self._get_block_out_channels(block))
            for block in self.net1.input_blocks
        )
        self.middle_cross = CrossNetworkLayer(
            feature_dim=self._get_block_out_channels(self.net1.middle_block)
        )
        self.output_cross_layers = nn.ModuleList(
            CrossNetworkLayer(feature_dim=self._get_block_out_channels(block))
            for block in self.net1.output_blocks
        )

    def _get_block_out_channels(self, block: nn.Module) -> int:
        """Infer the channel count a block produces by scanning its direct
        children from last to first.

        Returns:
            The ``out_channels`` of the last child that exposes one, or the
            ``in_channels`` of a trailing spatial transformer (transformers
            here preserve channel count).

        Raises:
            ValueError: if no child matches any known pattern.
        """
        for m in reversed(list(block.children())):
            if hasattr(m, "out_channels"):
                return m.out_channels
            if isinstance(
                m,
                (SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion),
            ):
                return m.in_channels
            # Kept for safety, though Conv2d already exposes out_channels
            # and is caught by the hasattr branch above.
            if isinstance(m, nn.Conv2d):
                return m.out_channels
        raise ValueError(f"Cannot determine out_channels from block: {block}")

    def forward(
        self,
        x: th.Tensor,
        timesteps: th.Tensor,
        context: Optional[th.Tensor] = None,
        y: Optional[th.Tensor] = None,
        cam: Optional[th.Tensor] = None,
        time_context: Optional[th.Tensor] = None,
        num_video_frames: Optional[int] = None,
        image_only_indicator: Optional[th.Tensor] = None,
        cond_view: Optional[th.Tensor] = None,
        cond_motion: Optional[th.Tensor] = None,
        time_step: Optional[int] = None,
    ):
        """Run both UNets with cross-communication and return the
        channel-concatenated result.

        ``x`` is split in two along dim 1; each half drives one network.
        ``y`` is the class-conditioning label, required iff the model is
        class-conditional. All other kwargs are forwarded unchanged to
        every block of both UNets.
        """
        # ============ encoder ============
        half = x.shape[1] // 2
        h1, h2 = x[:, :half], x[:, half:]
        encoder_feats1 = []
        encoder_feats2 = []
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False
        )  # 21 x 320
        # NOTE(review): the timestep/label embeddings come from net1 only and
        # are shared by both networks — confirm this is intentional.
        emb = self.net1.time_embed(t_emb)  # 21 x 1280
        # Stringified first timestep; used purely to build per-call block names.
        time = str(timesteps[0].data.cpu().numpy())
        if self.num_classes is not None:
            assert y.shape[0] == h1.shape[0]
            emb = emb + self.net1.label_emb(y)  # 21 x 1280
        filtered_args = {
            "emb": emb,
            "context": context,
            "cam": cam,
            "cond_view": cond_view,
            "cond_motion": cond_motion,
            "time_context": time_context,
            "num_video_frames": num_video_frames,
            "image_only_indicator": image_only_indicator,
            "time_step": time_step,
        }
        for i, (block1, block2) in enumerate(
            zip(self.net1.input_blocks, self.net2.input_blocks)
        ):
            h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args)
            h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args)
            # Cross-communicate, then store the fused features as skips.
            h1, h2 = self.input_cross_layers[i](h1, h2)
            encoder_feats1.append(h1)
            encoder_feats2.append(h2)
        # ============ middle block ============
        # Fix: the original passed the leaked loop index `i` as a stray
        # .format() argument, which raised NameError when input_blocks was empty.
        h1 = self.net1.middle_block(
            h1, name="middle_{}_0".format(time), **filtered_args
        )
        h2 = self.net2.middle_block(
            h2, name="middle_{}_0".format(time), **filtered_args
        )
        # cross
        h1, h2 = self.middle_cross(h1, h2)
        # ============ decoder ============
        for i, (block1, block2) in enumerate(
            zip(self.net1.output_blocks, self.net2.output_blocks)
        ):
            # Pop skips in reverse encoder order (standard UNet wiring).
            skip1 = encoder_feats1.pop()
            skip2 = encoder_feats2.pop()
            h1 = torch.cat([h1, skip1], dim=1)
            h2 = torch.cat([h2, skip2], dim=1)
            h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args)
            h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args)
            # cross
            h1, h2 = self.output_cross_layers[i](h1, h2)
        # ============ output ============
        out1 = self.net1.out(h1)  # shape: (B, out_channels, H, W)
        out2 = self.net2.out(h2)  # same shape
        return torch.cat([out1, out2], dim=1)