allow full customization of the number of resnet blocks per downsampling or upsampling layer of the unet, as in imagen

Phil Wang
2022-05-26 08:33:31 -07:00
parent 645e207441
commit f4fe6c570d
2 changed files with 19 additions and 11 deletions


@@ -1346,6 +1346,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        num_resnet_blocks = 1,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1431,6 +1432,7 @@ class Unet(nn.Module):
         # resnet block klass
         resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
         assert len(resnet_groups) == len(in_out)
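The line above relies on cast_tuple, so num_resnet_blocks can be given either as a single int or as a per-resolution tuple. A minimal sketch of the assumed helper behaviour (the repository's actual cast_tuple may differ in its details):

def cast_tuple(val, length = 1):
    # assumed behaviour: pass a tuple through, otherwise repeat the scalar `length` times
    if isinstance(val, tuple):
        assert len(val) == length
        return val
    return (val,) * length

# with four resolutions, both forms end up as a 4-tuple
assert cast_tuple(1, 4) == (1, 1, 1, 1)
assert cast_tuple((2, 4, 8, 8), 4) == (2, 4, 8, 8)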
@@ -1446,7 +1448,7 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
@@ -1454,7 +1456,7 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
@@ -1464,14 +1466,14 @@ class Unet(nn.Module):
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 Upsample(dim_in)
             ]))
@@ -1628,10 +1630,13 @@ class Unet(nn.Module):
         hiddens = []
-        for block1, sparse_attn, block2, downsample in self.downs:
-            x = block1(x, c, t)
+        for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
             hiddens.append(x)
             x = downsample(x)
@@ -1642,11 +1647,14 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c, t)
-        for block1, sparse_attn, block2, upsample in self.ups:
+        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = block1(x, c, t)
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
             x = upsample(x)
         return self.final_conv(x)
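Taken together, the extra resnet blocks at each resolution of the unet are now configurable per layer, as in imagen. A hypothetical usage sketch, assuming this is the Unet exported by dalle2_pytorch and that dim and dim_mults behave as in the rest of the file (neither is part of this diff):

from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8),          # four resolutions
    num_resnet_blocks = (2, 4, 8, 8),  # per-resolution resnet block counts, imagen-style
    resnet_groups = 8                  # a plain int is broadcast to every resolution
)

# a single int preserves the previous behaviour of one extra resnet block per resolution
unet_default = Unet(dim = 128, dim_mults = (1, 2, 4, 8), num_resnet_blocks = 1)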