From cf22affcbba384c37839cf5f2ce19c561864c407 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 12 Apr 2022 10:58:44 -0700
Subject: [PATCH] bring in modified unet using convnext blocks
 https://arxiv.org/abs/2201.03545

---
 README.md                        |   8 ++
 dalle2_pytorch/dalle2_pytorch.py | 142 ++++++++++++++++++++++++++++++-
 2 files changed, 149 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 1cd64be..837860c 100644
--- a/README.md
+++ b/README.md
@@ -63,3 +63,11 @@ Todo
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@inproceedings{Liu2022ACF,
+    title   = {A ConvNet for the 2020s},
+    author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
+    year    = {2022}
+}
+```
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 1e06b55..42900d8 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -143,7 +143,7 @@ class Transformer(nn.Module):
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, mask = mask) + x
             x = ff(x) + x
 
         return self.norm(x)
@@ -168,6 +168,146 @@ class DiffusionPrior(nn.Module):
 
 # decoder
 
+def Upsample(dim):
+    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+def Downsample(dim):
+    return nn.Conv2d(dim, dim, 4, 2, 1)
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
+        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
+        return torch.cat((emb.sin(), emb.cos()), dim = -1)
+
+class ConvNextBlock(nn.Module):
+    """ https://arxiv.org/abs/2201.03545 """
+
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        *,
+        time_emb_dim = None,
+        mult = 2,
+        norm = True
+    ):
+        super().__init__()
+        need_projection = dim != dim_out
+
+        self.mlp = nn.Sequential(
+            nn.GELU(),
+            nn.Linear(time_emb_dim, dim)
+        ) if exists(time_emb_dim) else None
+
+        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
+
+        inner_dim = int(dim_out * mult)
+        self.net = nn.Sequential(
+            LayerNorm(dim) if norm else nn.Identity(),
+            nn.Conv2d(dim, inner_dim, 3, padding = 1),
+            nn.GELU(),
+            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
+        )
+
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
+
+    def forward(self, x, time_emb = None):
+        h = self.ds_conv(x)
+
+        if exists(self.mlp):
+            assert exists(time_emb)
+            condition = self.mlp(time_emb)
+            h = h + rearrange(condition, 'b c -> b c 1 1')
+
+        h = self.net(h)
+        return h + self.res_conv(x)
+
+class Unet(nn.Module):
+    def __init__(
+        self,
+        dim,
+        out_dim = None,
+        dim_mults=(1, 2, 4, 8),
+        channels = 3,
+        with_time_emb = True
+    ):
+        super().__init__()
+        self.channels = channels
+
+        dims = [channels, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+
+        if with_time_emb:
+            time_dim = dim
+            self.time_mlp = nn.Sequential(
+                SinusoidalPosEmb(dim),
+                nn.Linear(dim, dim * 4),
+                nn.GELU(),
+                nn.Linear(dim * 4, dim)
+            )
+        else:
+            time_dim = None
+            self.time_mlp = None
+
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        num_resolutions = len(in_out)
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(nn.ModuleList([
+                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
+                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
+                Downsample(dim_out) if not is_last else nn.Identity()
+            ]))
+
+        mid_dim = dims[-1]
+        self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.ups.append(nn.ModuleList([
+                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
+                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
+                Upsample(dim_in) if not is_last else nn.Identity()
+            ]))
+
+        out_dim = default(out_dim, channels)
+        self.final_conv = nn.Sequential(
+            ConvNextBlock(dim, dim),
+            nn.Conv2d(dim, out_dim, 1)
+        )
+
+    def forward(self, x, time):
+        t = self.time_mlp(time) if exists(self.time_mlp) else None
+
+        hiddens = []
+
+        for convnext, convnext2, downsample in self.downs:
+            x = convnext(x, t)
+            x = convnext2(x, t)
+            hiddens.append(x)
+            x = downsample(x)
+
+        x = self.mid_block(x, t)
+
+        for convnext, convnext2, upsample in self.ups:
+            x = torch.cat((x, hiddens.pop()), dim=1)
+            x = convnext(x, t)
+            x = convnext2(x, t)
+            x = upsample(x)
+
+        return self.final_conv(x)
+
 class Decoder(nn.Module):
     def __init__(
         self,
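
A quick smoke test of the new `Unet` (a minimal sketch, assuming `torch` and `einops` are installed and importing the class straight from `dalle2_pytorch/dalle2_pytorch.py` as laid out in the diff above; the batch size, resolution, and `dim` / `dim_mults` values below are arbitrary choices, not values taken from the patch):

```python
import torch
from dalle2_pytorch.dalle2_pytorch import Unet

# arbitrary hyperparameters for the sanity check
unet = Unet(
    dim = 64,
    dim_mults = (1, 2, 4),
    channels = 3
)

images = torch.randn(2, 3, 64, 64)              # (batch, channels, height, width)
times = torch.randint(0, 1000, (2,)).float()    # one diffusion timestep per image

pred = unet(images, times)                      # output keeps the input's shape
assert pred.shape == images.shape
```

With `dim_mults = (1, 2, 4)` the feature maps are downsampled twice and upsampled twice, so the input height and width only need to be divisible by 4 for the skip-connection concatenations to line up.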