From ad87bfe28f610dae0db6dc6d418f2d6bce5a144b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 1 May 2022 17:59:03 -0700
Subject: [PATCH] switch to using linear attention for the sparse attention
 layers within unet, given success in GAN projects

---
 README.md                        | 12 +++++++++-
 dalle2_pytorch/dalle2_pytorch.py | 41 +++++++++++++++++++++++++++++---
 setup.py                         |  2 +-
 3 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 2ea3dc2..b251ce2 100644
--- a/README.md
+++ b/README.md
@@ -861,12 +861,22 @@ Once built, images will be saved to the same directory the command is invoked
 
 ```bibtex
 @inproceedings{Liu2022ACF,
-    title   = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
+    title   = {A ConvNet for the 2020s},
     author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
     year    = {2022}
 }
 ```
 
+```bibtex
+@article{shen2019efficient,
+    author  = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
+    title   = {Efficient Attention: Attention with Linear Complexities},
+    journal = {CoRR},
+    year    = {2018},
+    url     = {http://arxiv.org/abs/1812.01243},
+}
+```
+
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title   = {MaxViT: Multi-Axis Vision Transformer},
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index ac23d4c..67bf52d 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1050,6 +1050,42 @@ class GridAttention(nn.Module):
         out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
         return out
 
+class LinearAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_head = 32,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+        self.norm = ChanLayerNorm(dim)
+
+        self.nonlin = nn.GELU()
+        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
+        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, fmap):
+        h, x, y = self.heads, *fmap.shape[-2:]
+
+        fmap = self.norm(fmap)
+        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
+        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+
+        q = q.softmax(dim = -1)
+        k = k.softmax(dim = -2)
+
+        q = q * self.scale
+
+        context = einsum('b n d, b n e -> b d e', k, v)
+        out = einsum('b n d, b d e -> b n e', q, context)
+        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        out = self.nonlin(out)
+        return self.to_out(out)
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -1067,7 +1103,6 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         sparse_attn = False,
-        sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1161,7 +1196,7 @@ class Unet(nn.Module):
 
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
-                Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -1178,7 +1213,7 @@ class Unet(nn.Module):
 
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
-                Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
diff --git a/setup.py b/setup.py
index eb52bf3..178a666 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.87',
+  version = '0.0.88',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
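
For context on what the patch adds: the new `LinearAttention` block follows the efficient attention of Shen et al. (https://arxiv.org/abs/1812.01243), where `q` is softmaxed over the feature dimension and `k` over the sequence dimension, so the `kᵀv` summary can be formed first and then queried per position, keeping the cost linear in the number of spatial positions rather than quadratic. The sketch below mirrors that computation in plain `torch` + `einops`; the `LinearAttentionSketch` name, the `nn.GroupNorm(1, dim)` stand-in for the repo's `ChanLayerNorm`, and the `rearrange`-based replacement for `rearrange_many` are illustrative assumptions, not the patched module itself.

```python
# Minimal standalone sketch of the linear-attention computation added by this
# patch (Shen et al., "Efficient Attention: Attention with Linear Complexities").
# Assumptions: nn.GroupNorm(1, dim) stands in for the repo's ChanLayerNorm, and
# plain einops.rearrange replaces rearrange_many; names here are illustrative.

import torch
from torch import nn, einsum
from einops import rearrange

class LinearAttentionSketch(nn.Module):
    def __init__(self, dim, dim_head = 32, heads = 8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.GroupNorm(1, dim)   # stand-in for ChanLayerNorm
        self.nonlin = nn.GELU()
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        fmap = self.norm(fmap)

        # project to q, k, v and fold the head dimension into the batch
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        # normalizing q over features and k over positions lets (k^T v) be
        # computed first, so cost is linear in the number of spatial positions
        q = q.softmax(dim = -1) * self.scale
        k = k.softmax(dim = -2)

        context = einsum('b n d, b n e -> b d e', k, v)    # d x e summary of all positions
        out = einsum('b n d, b d e -> b n e', q, context)  # each position queries the summary
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
        return self.to_out(self.nonlin(out))

# usage: output shape matches the input feature map, so the block can sit
# inside a Residual wrapper at any unet resolution
if __name__ == '__main__':
    attn = LinearAttentionSketch(dim = 64)
    fmap = torch.randn(1, 64, 32, 32)
    assert attn(fmap).shape == fmap.shape
```

Because the block is shape-preserving, it drops into the same `Residual(...)` slots in the down and up paths that `GridAttention` occupied, which is why the only other change needed in the unet was removing the now-unused `sparse_attn_window` argument.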