diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c5ce08a..1ef9b3a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1084,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion): def Upsample(dim): return nn.ConvTranspose2d(dim, dim, 4, 2, 1) -def Downsample(dim): - return nn.Conv2d(dim, dim, 4, 2, 1) +def Downsample(dim, *, dim_out = None): + dim_out = default(dim_out, dim) + return nn.Conv2d(dim, dim_out, 4, 2, 1) class SinusoidalPosEmb(nn.Module): def __init__(self, dim): @@ -1370,7 +1371,7 @@ class Unet(nn.Module): self.channels_out = default(channels_out, channels) init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis - init_dim = default(init_dim, dim // 3 * 2) + init_dim = default(init_dim, dim) self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) @@ -1461,10 +1462,10 @@ class Unet(nn.Module): layer_cond_dim = cond_dim if not is_first else None self.downs.append(nn.ModuleList([ - ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups), + downsample_klass(dim_in, dim_out = dim_out), + ResnetBlock(dim_out, dim_out, time_cond_dim = time_cond_dim, groups = groups), Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(), nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]), - downsample_klass(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] @@ -1473,7 +1474,7 @@ class Unet(nn.Module): self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) - for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))): + for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))): is_last = ind >= (num_resolutions - 2) layer_cond_dim = cond_dim if not is_last else None @@ -1654,7 +1655,8 @@ class Unet(nn.Module): hiddens = [] - for init_block, sparse_attn, resnet_blocks, downsample in self.downs: + for downsample, init_block, sparse_attn, resnet_blocks in self.downs: + x = downsample(x) x = init_block(x, c, t) x = sparse_attn(x) @@ -1662,7 +1664,6 @@ class Unet(nn.Module): x = resnet_block(x, c, t) hiddens.append(x) - x = downsample(x) x = self.mid_block1(x, mid_c, t) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index f0788a8..32a90a3 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.7.1' +__version__ = '0.8.0'