SP4D updates

This commit is contained in:
Chun-Han Yao
2025-11-03 21:23:17 +00:00
parent 8f41cbc50b
commit fd9d14e02f
7 changed files with 749 additions and 4 deletions

View File

@@ -746,3 +746,170 @@ class Decoder(nn.Module):
if self.tanh_out:
h = torch.tanh(h)
return h
class DecoderDual(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logpy.info(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
make_attn_cls = self._make_attn()
make_resblock_cls = self._make_resblock()
make_conv_cls = self._make_conv()
# z to block_in (处理单个 latent)
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
self.mid.block_2 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
make_resblock_cls(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn_cls(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = make_conv_cls(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def _make_attn(self) -> Callable:
return make_attn
def _make_resblock(self) -> Callable:
return ResnetBlock
def _make_conv(self) -> Callable:
return torch.nn.Conv2d
def get_last_layer(self, **kwargs):
return self.conv_out.weight
def forward(self, z, **kwargs):
"""
输入 z 的形状应为 (B, 2 * z_channels, H, W)
- 其中前一半通道为第一个 latent后一半通道为第二个 latent
- 分离后分别解码,最终在 W 维度拼接
"""
# 断言检查,确保输入的通道数是 2 倍的 z_channels
assert (
z.shape[1] == 2 * self.z_shape[1]
), f"Expected {2 * self.z_shape[1]} channels, got {z.shape[1]}"
# 分割 latent 为两个部分
z1, z2 = torch.chunk(z, 2, dim=1) # 按照通道维度 (C) 切分
# 分别解码
img1 = self.decode_single(z1, **kwargs)
img2 = self.decode_single(z2, **kwargs)
# 沿着 W 维度拼接
output = torch.cat([img1, img2], dim=-1) # 在 width 维度拼接
return output
def decode_single(self, z, **kwargs):
"""解码单个 latent 到一张图像"""
self.last_z_shape = z.shape
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, None, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, None, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, None, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h

View File

@@ -13,6 +13,7 @@ from ...modules.spacetime_attention import (
from ...util import default
from .util import AlphaBlender, get_alpha
import torch
class VideoResBlock(ResBlock):
def __init__(
@@ -1252,3 +1253,157 @@ class SpatialUNetModelWithTime(nn.Module):
)
h = h.type(x.dtype)
return self.out(h)
class CrossNetworkLayer(nn.Module):
def __init__(self, feature_dim: int):
super().__init__()
self.fusion_conv = nn.Sequential(
nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
)
def forward(self, h1: torch.Tensor, h2: torch.Tensor):
"""
h1, h2: (B, C, H, W)
return: (out1, out2), (B, C, H, W)
"""
fused_input = torch.cat([h1, h2], dim=1) # (B, 2C, H, W)
fused_output = self.fusion_conv(fused_input) # (B, C, H, W)
out1 = fused_output + h1
out2 = fused_output + h2
return out1, out2
class DualSpatialUNetWithCrossComm(nn.Module):
def __init__(self, unet_config):
super().__init__()
self.num_classes = unet_config["num_classes"]
self.model_channels = unet_config["model_channels"]
self.net1 = SpatialUNetModelWithTime(**unet_config)
self.net2 = SpatialUNetModelWithTime(**unet_config)
self.input_cross_layers = nn.ModuleList()
for block in self.net1.input_blocks:
out_ch = self._get_block_out_channels(block)
self.input_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
middle_out_ch = self._get_block_out_channels(self.net1.middle_block)
self.middle_cross = CrossNetworkLayer(feature_dim=middle_out_ch)
self.output_cross_layers = nn.ModuleList()
for block in self.net1.output_blocks:
out_ch = self._get_block_out_channels(block)
self.output_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))
def _get_block_out_channels(self, block: nn.Module) -> int:
mod_list = list(block.children())
for m in reversed(mod_list):
if hasattr(m, "out_channels"):
return m.out_channels
if isinstance(
m,
(SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion),
):
return m.in_channels
if isinstance(m, nn.Conv2d):
return m.out_channels
raise ValueError(f"Cannot determine out_channels from block: {block}")
def forward(
self,
x: th.Tensor,
timesteps: th.Tensor,
context: Optional[th.Tensor] = None,
y: Optional[th.Tensor] = None,
cam: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
cond_view: Optional[th.Tensor] = None,
cond_motion: Optional[th.Tensor] = None,
time_step: Optional[int] = None,
):
# ============ encoder ============
h1, h2 = x[:, : x.shape[1] // 2], x[:, x.shape[1] // 2 :]
encoder_feats1 = []
encoder_feats2 = []
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
) # 21 x 320
emb = self.net1.time_embed(t_emb) # 21 x 1280
time = str(timesteps[0].data.cpu().numpy())
if self.num_classes is not None:
assert y.shape[0] == h1.shape[0]
emb = emb + self.net1.label_emb(y) # 21 x 1280
filtered_args = {
"emb": emb,
"context": context,
"cam": cam,
"cond_view": cond_view,
"cond_motion": cond_motion,
"time_context": time_context,
"num_video_frames": num_video_frames,
"image_only_indicator": image_only_indicator,
"time_step": time_step,
}
for i, (block1, block2) in enumerate(
zip(self.net1.input_blocks, self.net2.input_blocks)
):
h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args)
h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args)
# cross
h1, h2 = self.input_cross_layers[i](h1, h2)
encoder_feats1.append(h1)
encoder_feats2.append(h2)
# ============ middle block ============
h1 = self.net1.middle_block(
h1, name="middle_{}_0".format(time, i), **filtered_args
)
h2 = self.net2.middle_block(
h2, name="middle_{}_0".format(time, i), **filtered_args
)
# cross
h1, h2 = self.middle_cross(h1, h2)
# ============ decoder ============
for i, (block1, block2) in enumerate(
zip(self.net1.output_blocks, self.net2.output_blocks)
):
skip1 = encoder_feats1.pop()
skip2 = encoder_feats2.pop()
h1 = torch.cat([h1, skip1], dim=1)
h2 = torch.cat([h2, skip2], dim=1)
h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args)
h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args)
# cross
h1, h2 = self.output_cross_layers[i](h1, h2)
# ============ output ============
out1 = self.net1.out(h1) # shape: (B, out_channels, H, W)
out2 = self.net2.out(h2) # same shape
out = torch.cat([out1, out2], dim=1)
return out

View File

@@ -6,7 +6,7 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
class IdentityWrapper(nn.Module):
def __init__(self, diffusion_model, compile_model: bool = False):
def __init__(self, diffusion_model, compile_model: bool = False, dual_concat: bool = False):
super().__init__()
compile = (
torch.compile
@@ -15,6 +15,7 @@ class IdentityWrapper(nn.Module):
else lambda x: x
)
self.diffusion_model = compile(diffusion_model)
self.dual_concat = dual_concat
def forward(self, *args, **kwargs):
return self.diffusion_model(*args, **kwargs)
@@ -24,7 +25,14 @@ class OpenAIWrapper(IdentityWrapper):
def forward(
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
) -> torch.Tensor:
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
if self.dual_concat:
x_1 = x[:, : x.shape[1] // 2]
x_2 = x[:, x.shape[1] // 2 :]
x_1 = torch.cat((x_1, c.get("concat", torch.Tensor([]).type_as(x_1))), dim=1)
x_2 = torch.cat((x_2, c.get("concat", torch.Tensor([]).type_as(x_2))), dim=1)
x = torch.cat((x_1, x_2), dim=1)
else:
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
if "cond_view" in c:
return self.diffusion_model(
x,