From d573c82f8c190209bbf4e4daffaedf5ed521603e Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 13 Apr 2022 10:39:06 -0700
Subject: [PATCH] add one full attention at the middle of the unet, prepare to
 do efficient attention employing every trick i know from vision transformer
 literature

NOTE(review): two fixes folded into this patch relative to the original
submission — (1) PreNormResidual.__init__ now stores `self.fn = fn`, which
`forward` requires; (2) Unet.forward calls `self.mid_attn`, matching the
attribute name registered in __init__ (was `self.attn`). Hunk headers and
the diffstat are recounted for the extra inserted line; the `index` blob
hashes predate these fixes and will not match for 3-way application.
---
 dalle2_pytorch/dalle2_pytorch.py | 32 ++++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 53eb9db..928cf10 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -91,6 +91,15 @@ class RMSNorm(nn.Module):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * self.gamma * self.scale
 
+class PreNormResidual(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = RMSNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs) + x
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
@@ -509,6 +518,21 @@ class ConvNextBlock(nn.Module):
         h = self.net(h)
         return h + self.res_conv(x)
 
+class EinopsToAndFrom(nn.Module):
+    def __init__(self, from_einops, to_einops, fn):
+        super().__init__()
+        self.from_einops = from_einops
+        self.to_einops = to_einops
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        shape = x.shape
+        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
+        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
+        x = self.fn(x, **kwargs)
+        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
+        return x
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -553,7 +577,9 @@ class Unet(nn.Module):
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
+        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)
@@ -616,7 +642,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block(x, t)
+        x = self.mid_block1(x, t)
+        x = self.mid_attn(x)
+        x = self.mid_block2(x, t)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)