Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
add sparse attention layers between convnext blocks in unet (grid-like attention, as used in MobileViT and MaxViT [Google Research], as well as a growing number of attention-based GANs)
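For context (an illustration, not part of the commit): the einops pattern below is the heart of the new `GridAttention` layer. Under assumed toy shapes, it shows how a feature map is regrouped so that each attention window sees positions strided across the whole map (a grid), rather than one contiguous local block.

```python
# Illustrative only: the token regrouping used by the new GridAttention layer.
# Shapes here are made up for the example; einops and torch are existing repo dependencies.
import torch
from einops import rearrange

wsz = 8                               # window_size
fmap = torch.randn(2, 64, 32, 32)     # (batch, channels, height, width)

# (w1 h) splits the height into w1 = 8 blocks of h = 4 rows (likewise for width);
# grouping (w1 w2) into the sequence axis means each window attends across
# positions spaced h apart vertically and w apart horizontally, i.e. a sparse
# "grid" spanning the whole feature map
tokens = rearrange(fmap, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
print(tokens.shape)                   # torch.Size([32, 64, 64]) = (2 * 4 * 4, 8 * 8, 64)
```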
--- a/README.md
+++ b/README.md
@@ -409,9 +409,10 @@ Offer training wrappers
 - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
 - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
+- [x] add efficient attention in unet
 - [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
-- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
+- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
 
 ## Citations
@@ -461,4 +462,12 @@ Offer training wrappers
 }
 ```
 
+```bibtex
+@inproceedings{Tu2022MaxViTMV,
+    title   = {MaxViT: Multi-Axis Vision Transformer},
+    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    year    = {2022}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -798,6 +798,20 @@ class CrossAttention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
+class GridAttention(nn.Module):
+    def __init__(self, *args, window_size = 8, **kwargs):
+        super().__init__()
+        self.window_size = window_size
+        self.attn = Attention(*args, **kwargs)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        wsz = self.window_size
+        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
+        out = self.attn(x)
+        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
+        return out
+
 class Unet(nn.Module):
     def __init__(
         self,
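A quick shape check of the new layer (hypothetical, not from the commit; it assumes the repo's `Attention` class takes the feature dimension as its first positional argument, which is what `GridAttention` forwards via `*args`):

```python
# Hypothetical sanity check: GridAttention is shape-preserving on (b, c, h, w)
# feature maps whose spatial sides are divisible by window_size
import torch
from dalle2_pytorch.dalle2_pytorch import GridAttention  # assumed import path

attn = GridAttention(64, window_size = 8)  # 64 = feature dim, passed through to Attention
fmap = torch.randn(2, 64, 32, 32)          # 32 is divisible by the window size of 8
assert attn(fmap).shape == fmap.shape
```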
@@ -813,6 +827,8 @@ class Unet(nn.Module):
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
         blur_kernel_size = 3,
+        sparse_attn = False,
+        sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
     ):
         super().__init__()
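Enabling the feature from the constructor might look like this (a sketch: every argument besides the two new flags follows the repository README and may not match this commit exactly):

```python
# Hypothetical usage; only sparse_attn / sparse_attn_window are from this commit
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    sparse_attn = True,      # insert GridAttention between each pair of ConvNext blocks
    sparse_attn_window = 8   # feature-map sides must be divisible by this
)
```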
@@ -875,6 +891,7 @@ class Unet(nn.Module):
 
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -891,6 +908,7 @@ class Unet(nn.Module):
 
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                 Upsample(dim_in)
             ]))
@@ -995,8 +1013,9 @@ class Unet(nn.Module):
 
         hiddens = []
 
-        for convnext, convnext2, downsample in self.downs:
+        for convnext, sparse_attn, convnext2, downsample in self.downs:
             x = convnext(x, c)
+            x = sparse_attn(x)
             x = convnext2(x, c)
             hiddens.append(x)
             x = downsample(x)
@@ -1008,9 +1027,10 @@ class Unet(nn.Module):
 
         x = self.mid_block2(x, mid_c)
 
-        for convnext, convnext2, upsample in self.ups:
+        for convnext, sparse_attn, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
             x = convnext(x, c)
+            x = sparse_attn(x)
             x = convnext2(x, c)
             x = upsample(x)
 