Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
57f1ddf9d2 fix missing resisidual for highest resolution of the unet 2022-06-15 19:11:58 -07:00
2 changed files with 5 additions and 9 deletions

View File

@@ -1476,23 +1476,19 @@ class Unet(nn.Module):
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
up_in_out_slice = slice(1 if not memory_efficient else None, None)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[up_in_out_slice]), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (num_resolutions - 2)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in)
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
]))
final_dim_in = dim * (1 if memory_efficient else 2)
self.final_conv = nn.Sequential(
ResnetBlock(final_dim_in, dim, groups = resnet_groups[0]),
ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)

View File

@@ -1 +1 @@
__version__ = '0.9.0'
__version__ = '0.9.1'