mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-15 10:44:24 +01:00
SP4D updates
This commit is contained in:
@@ -746,3 +746,170 @@ class Decoder(nn.Module):
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class DecoderDual(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logpy.info(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
make_resblock_cls = self._make_resblock()
|
||||
make_conv_cls = self._make_conv()
|
||||
|
||||
# z to block_in (处理单个 latent)
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn_cls(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = make_conv_cls(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def _make_attn(self) -> Callable:
|
||||
return make_attn
|
||||
|
||||
def _make_resblock(self) -> Callable:
|
||||
return ResnetBlock
|
||||
|
||||
def _make_conv(self) -> Callable:
|
||||
return torch.nn.Conv2d
|
||||
|
||||
def get_last_layer(self, **kwargs):
|
||||
return self.conv_out.weight
|
||||
|
||||
def forward(self, z, **kwargs):
|
||||
"""
|
||||
输入 z 的形状应为 (B, 2 * z_channels, H, W)
|
||||
- 其中前一半通道为第一个 latent,后一半通道为第二个 latent
|
||||
- 分离后分别解码,最终在 W 维度拼接
|
||||
"""
|
||||
# 断言检查,确保输入的通道数是 2 倍的 z_channels
|
||||
assert (
|
||||
z.shape[1] == 2 * self.z_shape[1]
|
||||
), f"Expected {2 * self.z_shape[1]} channels, got {z.shape[1]}"
|
||||
|
||||
# 分割 latent 为两个部分
|
||||
z1, z2 = torch.chunk(z, 2, dim=1) # 按照通道维度 (C) 切分
|
||||
|
||||
# 分别解码
|
||||
img1 = self.decode_single(z1, **kwargs)
|
||||
img2 = self.decode_single(z2, **kwargs)
|
||||
|
||||
# 沿着 W 维度拼接
|
||||
output = torch.cat([img1, img2], dim=-1) # 在 width 维度拼接
|
||||
|
||||
return output
|
||||
|
||||
def decode_single(self, z, **kwargs):
|
||||
"""解码单个 latent 到一张图像"""
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, None, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, None, **kwargs)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, None, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
|
||||
return h
|
||||
|
||||
Reference in New Issue
Block a user