|
|
|
|
@@ -1352,6 +1352,7 @@ class Unet(nn.Module):
|
|
|
|
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
|
|
|
|
cross_embed_downsample = False,
|
|
|
|
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
|
|
|
|
memory_efficient = False,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
@@ -1462,10 +1463,11 @@ class Unet(nn.Module):
|
|
|
|
|
layer_cond_dim = cond_dim if not is_first else None
|
|
|
|
|
|
|
|
|
|
self.downs.append(nn.ModuleList([
|
|
|
|
|
downsample_klass(dim_in, dim_out = dim_out),
|
|
|
|
|
ResnetBlock(dim_out, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
|
|
|
|
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
mid_dim = dims[-1]
|
|
|
|
|
@@ -1475,18 +1477,18 @@ class Unet(nn.Module):
|
|
|
|
|
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
|
|
|
|
|
|
|
|
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
|
|
|
|
is_last = ind >= (num_resolutions - 2)
|
|
|
|
|
is_last = ind >= (len(in_out) - 1)
|
|
|
|
|
layer_cond_dim = cond_dim if not is_last else None
|
|
|
|
|
|
|
|
|
|
self.ups.append(nn.ModuleList([
|
|
|
|
|
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
Upsample(dim_in)
|
|
|
|
|
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
self.final_conv = nn.Sequential(
|
|
|
|
|
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
|
|
|
|
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
|
|
|
|
nn.Conv2d(dim, self.channels_out, 1)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -1558,6 +1560,7 @@ class Unet(nn.Module):
|
|
|
|
|
# initial convolution
|
|
|
|
|
|
|
|
|
|
x = self.init_conv(x)
|
|
|
|
|
r = x.clone() # final residual
|
|
|
|
|
|
|
|
|
|
# time conditioning
|
|
|
|
|
|
|
|
|
|
@@ -1655,8 +1658,10 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
hiddens = []
|
|
|
|
|
|
|
|
|
|
for downsample, init_block, sparse_attn, resnet_blocks in self.downs:
|
|
|
|
|
x = downsample(x)
|
|
|
|
|
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
|
|
|
|
if exists(pre_downsample):
|
|
|
|
|
x = pre_downsample(x)
|
|
|
|
|
|
|
|
|
|
x = init_block(x, c, t)
|
|
|
|
|
x = sparse_attn(x)
|
|
|
|
|
|
|
|
|
|
@@ -1665,6 +1670,9 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
|
|
|
|
|
if exists(post_downsample):
|
|
|
|
|
x = post_downsample(x)
|
|
|
|
|
|
|
|
|
|
x = self.mid_block1(x, mid_c, t)
|
|
|
|
|
|
|
|
|
|
if exists(self.mid_attn):
|
|
|
|
|
@@ -1673,7 +1681,7 @@ class Unet(nn.Module):
|
|
|
|
|
x = self.mid_block2(x, mid_c, t)
|
|
|
|
|
|
|
|
|
|
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
|
|
|
|
x = torch.cat((x, hiddens.pop()), dim=1)
|
|
|
|
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
|
|
|
|
x = init_block(x, c, t)
|
|
|
|
|
x = sparse_attn(x)
|
|
|
|
|
|
|
|
|
|
@@ -1682,6 +1690,7 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
x = upsample(x)
|
|
|
|
|
|
|
|
|
|
x = torch.cat((x, r), dim = 1)
|
|
|
|
|
return self.final_conv(x)
|
|
|
|
|
|
|
|
|
|
class LowresConditioner(nn.Module):
|
|
|
|
|
|