diff --git a/README.md b/README.md index 2a5ab42..523d46a 100644 --- a/README.md +++ b/README.md @@ -911,4 +911,4 @@ Once built, images will be saved to the same directory the command is invoked } ``` -*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper +*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 2ab0e7e..2cb9811 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -930,6 +930,72 @@ class SinusoidalPosEmb(nn.Module): emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') return torch.cat((emb.sin(), emb.cos()), dim = -1) +class Block(nn.Module): + def __init__( + self, + dim, + dim_out, + groups = 8 + ): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(dim, dim_out, 3, padding = 1), + nn.GroupNorm(groups, dim_out), + nn.SiLU() + ) + def forward(self, x): + return self.block(x) + +class ResnetBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + *, + cond_dim = None, + time_cond_dim = None, + groups = 8 + ): + super().__init__() + + self.time_mlp = None + + if exists(time_cond_dim): + self.time_mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_cond_dim, dim_out) + ) + + self.cross_attn = None + + if exists(cond_dim): + self.cross_attn = EinopsToAndFrom( + 'b c h w', + 'b (h w) c', + CrossAttention( + dim = dim_out, + context_dim = cond_dim + ) + ) + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, cond = None, time_emb = None): + h = self.block1(x) + + if exists(self.time_mlp) and exists(time_emb): + time_emb = self.time_mlp(time_emb) + h = rearrange(time_emb, 'b c -> b c 1 1') + h + + if exists(self.cross_attn): + assert exists(cond) + h = self.cross_attn(h, context = cond) + h + + h = self.block2(h) + return h + self.res_conv(x) + class ConvNextBlock(nn.Module): """ https://arxiv.org/abs/2201.03545 """ @@ -940,8 +1006,7 @@ class ConvNextBlock(nn.Module): *, cond_dim = None, time_cond_dim = None, - mult = 2, - norm = True + mult = 2 ): super().__init__() need_projection = dim != dim_out @@ -970,7 +1035,7 @@ class ConvNextBlock(nn.Module): inner_dim = int(dim_out * mult) self.net = nn.Sequential( - ChanLayerNorm(dim) if norm else nn.Identity(), + ChanLayerNorm(dim), nn.Conv2d(dim, inner_dim, 3, padding = 1), nn.GELU(), nn.Conv2d(inner_dim, dim_out, 3, padding = 1) @@ -1082,7 +1147,11 @@ class LinearAttention(nn.Module): self.nonlin = nn.GELU() self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False) + + self.to_out = nn.Sequential( + nn.Conv2d(inner_dim, dim, 1, bias = False), + ChanLayerNorm(dim) + ) def forward(self, fmap): h, x, y = self.heads, *fmap.shape[-2:] @@ -1125,7 +1194,9 @@ class Unet(nn.Module): max_text_len = 256, cond_on_image_embeds = False, init_dim = None, - init_conv_kernel_size = 7 + init_conv_kernel_size = 7, + block_type = 'resnet', + **kwargs ): super().__init__() # save locals to take care of some hyperparameters for cascading DDPM @@ -1200,6 +1271,15 @@ class Unet(nn.Module): attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) + # whether to use resnet or the (improved?) convnext blocks + + if block_type == 'resnet': + block_klass = ResnetBlock + elif block_type == 'convnext': + block_klass = ConvNextBlock + else: + raise ValueError(f'unimplemented block type {block_type}') + # layers self.downs = nn.ModuleList([]) @@ -1212,32 +1292,32 @@ class Unet(nn.Module): layer_cond_dim = cond_dim if not is_first else None self.downs.append(nn.ModuleList([ - ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0), + block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim), Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(), - ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), + block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), Downsample(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] - self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim) + self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim) self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None - self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim) + self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 2) layer_cond_dim = cond_dim if not is_last else None self.ups.append(nn.ModuleList([ - ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), + block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(), - ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), + block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), Upsample(dim_in) ])) out_dim = default(out_dim, channels) self.final_conv = nn.Sequential( - ConvNextBlock(dim, dim), + block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1) ) @@ -1368,10 +1448,10 @@ class Unet(nn.Module): hiddens = [] - for convnext, sparse_attn, convnext2, downsample in self.downs: - x = convnext(x, c, t) + for block1, sparse_attn, block2, downsample in self.downs: + x = block1(x, c, t) x = sparse_attn(x) - x = convnext2(x, c, t) + x = block2(x, c, t) hiddens.append(x) x = downsample(x) @@ -1382,11 +1462,11 @@ class Unet(nn.Module): x = self.mid_block2(x, mid_c, t) - for convnext, sparse_attn, convnext2, upsample in self.ups: + for block1, sparse_attn, block2, upsample in self.ups: x = torch.cat((x, hiddens.pop()), dim=1) - x = convnext(x, c, t) + x = block1(x, c, t) x = sparse_attn(x) - x = convnext2(x, c, t) + x = block2(x, c, t) x = upsample(x) return self.final_conv(x)