From a35c309b5f0e70b5facff2733c886049fddcb0d8 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 19 Apr 2022 09:48:27 -0700
Subject: [PATCH] add sparse attention layers in between convnext blocks in
 unet (grid-like attention, used in mobilevit, maxvit [google research], as
 well as a growing number of attention-based GANs)

---
 README.md                        | 11 ++++++++++-
 dalle2_pytorch/dalle2_pytorch.py | 24 ++++++++++++++++++++++--
 setup.py                         |  2 +-
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 8db0333..90d5c50 100644
--- a/README.md
+++ b/README.md
@@ -409,9 +409,10 @@ Offer training wrappers
 - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
 - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
+- [x] add efficient attention in unet
 - [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
-- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
+- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
 
 ## Citations
@@ -461,4 +462,12 @@ Offer training wrappers
 }
 ```
 
+```bibtex
+@inproceedings{Tu2022MaxViTMV,
+    title   = {MaxViT: Multi-Axis Vision Transformer},
+    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    year    = {2022}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 81f508b..d1a4538 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -798,6 +798,20 @@ class CrossAttention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
+class GridAttention(nn.Module):
+    def __init__(self, *args, window_size = 8, **kwargs):
+        super().__init__()
+        self.window_size = window_size
+        self.attn = Attention(*args, **kwargs)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        wsz = self.window_size
+        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
+        out = self.attn(x)
+        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
+        return out
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -813,6 +827,8 @@ class Unet(nn.Module):
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
         blur_kernel_size = 3,
+        sparse_attn = False,
+        sparse_attn_window = 8,   # window size for sparse attention
         attend_at_middle = True,  # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
     ):
         super().__init__()
@@ -875,6 +891,7 @@ class Unet(nn.Module):
 
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -891,6 +908,7 @@ class Unet(nn.Module):
 
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                 Upsample(dim_in)
             ]))
@@ -995,8 +1013,9 @@ class Unet(nn.Module):
 
         hiddens = []
 
-        for convnext, convnext2, downsample in self.downs:
+        for convnext, sparse_attn, convnext2, downsample in self.downs:
             x = convnext(x, c)
+            x = sparse_attn(x)
             x = convnext2(x, c)
             hiddens.append(x)
             x = downsample(x)
@@ -1008,9 +1027,10 @@ class Unet(nn.Module):
 
         x = self.mid_block2(x, mid_c)
 
-        for convnext, convnext2, upsample in self.ups:
+        for convnext, sparse_attn, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
             x = convnext(x, c)
+            x = sparse_attn(x)
             x = convnext2(x, c)
             x = upsample(x)
 
diff --git a/setup.py b/setup.py
index 681eef5..28513c8 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.24',
+  version = '0.0.25',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
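
Note on what the new `GridAttention` layer is doing: instead of full attention over all `h * w` pixels, it regroups the feature map so that each window of `window_size * window_size` tokens samples the image at a fixed stride (the "grid" axis of MaxViT's multi-axis attention), keeping cost linear in image area. Below is a minimal, self-contained sketch of just that windowing round-trip; the repo's `Attention` module is stubbed out with `nn.Identity` here, so this demonstrates the shapes rather than the attention weights, and `grid_attend` is a hypothetical helper, not part of the patch:

```python
# Sketch only (not part of the patch): the grid windowing used by GridAttention.
# Assumes einops is installed; attention is stubbed with nn.Identity.
import torch
from torch import nn
from einops import rearrange

def grid_attend(x, attn, window_size = 8):
    # x: (batch, channels, height, width); height and width must be divisible by window_size
    wsz = window_size
    h, w = x.shape[-2:]
    # tokens spaced (height // wsz) and (width // wsz) pixels apart land in the same
    # window, so every window spans the whole image in a sparse grid pattern
    x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
    x = attn(x)  # any sequence attention over (batch * windows, wsz * wsz, channels)
    return rearrange(x, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)

x = torch.randn(2, 32, 64, 64)
out = grid_attend(x, attn = nn.Identity(), window_size = 8)
assert out.shape == x.shape and torch.equal(out, x)  # windowing is a lossless permutation
```

With a real attention module in place of the identity stub, wrapping the layer in `Residual` (as the patch does in `self.downs` and `self.ups`) preserves the input signal and lets `sparse_attn = False` degrade cleanly to the attention-free unet.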