diff --git a/README.md b/README.md index 8e83657..3a4d480 100644 --- a/README.md +++ b/README.md @@ -1112,7 +1112,7 @@ For detailed information on training the diffusion prior, please refer to the [d - [x] allow for unet to be able to condition non-cross attention style as well - [x] speed up inference, read up on papers (ddim) - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 -- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow +- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 ## Citations diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index cf4c197..71c92e1 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1538,6 +1538,38 @@ class CrossEmbedLayer(nn.Module): fmaps = tuple(map(lambda conv: conv(x), self.convs)) return torch.cat(fmaps, dim = 1) +class UpsampleCombiner(nn.Module): + def __init__( + self, + dim, + *, + enabled = False, + dim_ins = tuple(), + dim_outs = tuple() + ): + super().__init__() + assert len(dim_ins) == len(dim_outs) + self.enabled = enabled + + if not self.enabled: + self.dim_out = dim + return + + self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]) + self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0) + + def forward(self, x, fmaps = None): + target_size = x.shape[-1] + + fmaps = default(fmaps, tuple()) + + if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0: + return x + + fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps] + outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)] + return torch.cat((x, *outs), dim = 1) + class Unet(nn.Module): def __init__( self, @@ -1575,6 +1607,7 @@ class Unet(nn.Module): scale_skip_connection = False, pixel_shuffle_upsample = True, final_conv_kernel_size = 1, + combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper **kwargs ): super().__init__() @@ -1710,7 +1743,8 @@ class Unet(nn.Module): self.ups = nn.ModuleList([]) num_resolutions = len(in_out) - skip_connect_dims = [] # keeping track of skip connection dimensions + skip_connect_dims = [] # keeping track of skip connection dimensions + upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)): is_first = ind == 0 @@ -1752,6 +1786,8 @@ class Unet(nn.Module): elif sparse_attn: attention = Residual(LinearAttention(dim_out, **attn_kwargs)) + upsample_combiner_dims.append(dim_out) + self.ups.append(nn.ModuleList([ ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]), @@ -1759,7 +1795,18 @@ class Unet(nn.Module): upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity() ])) - self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) + # whether to combine outputs from all upsample blocks for final resnet block + + self.upsample_combiner = UpsampleCombiner( + dim = dim, + enabled = combine_upsample_fmaps, + dim_ins = upsample_combiner_dims, + dim_outs = (dim,) * len(upsample_combiner_dims) + ) + + # a final resnet block + + self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) out_dim_in = dim + (channels if lowres_cond else 0) @@ -1953,7 +2000,8 @@ class Unet(nn.Module): # go through the layers of the unet, down and up - hiddens = [] + down_hiddens = [] + up_hiddens = [] for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs: if exists(pre_downsample): @@ -1963,10 +2011,10 @@ class Unet(nn.Module): for resnet_block in resnet_blocks: x = resnet_block(x, t, c) - hiddens.append(x) + down_hiddens.append(x.contiguous()) x = attn(x) - hiddens.append(x.contiguous()) + down_hiddens.append(x.contiguous()) if exists(post_downsample): x = post_downsample(x) @@ -1978,7 +2026,7 @@ class Unet(nn.Module): x = self.mid_block2(x, t, mid_c) - connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1) + connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1) for init_block, resnet_blocks, attn, upsample in self.ups: x = connect_skip(x) @@ -1989,8 +2037,12 @@ class Unet(nn.Module): x = resnet_block(x, t, c) x = attn(x) + + up_hiddens.append(x.contiguous()) x = upsample(x) + x = self.upsample_combiner(x, up_hiddens) + x = torch.cat((x, r), dim = 1) x = self.final_resnet_block(x, t) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index da2182f..1a72d32 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.0.6' +__version__ = '1.1.0'