mirror of https://github.com/lucidrains/DALLE2-pytorch.git
switch to using linear attention for the sparse attention layers within unet, given success in GAN projects
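
For anyone reviewing the switch who has not seen the efficient attention formulation (Shen et al., cited in the README change below): with softmax applied separately to the queries and to the keys, the key-value product can be aggregated into a small (d × d) context before the queries are applied, so attention over a feature map never materialises an (n × n) pixel-to-pixel matrix and its cost grows linearly in the number of pixels. The snippet below is only an illustration of that reordering in plain torch; the sizes are arbitrary and none of it comes from this commit's diff.

```python
import torch
from torch import einsum

n, d = 4096, 32                            # e.g. a 64 x 64 feature map flattened to n pixels, head dim d
q = torch.randn(n, d).softmax(dim = -1)    # queries: softmax over the feature dimension
k = torch.randn(n, d).softmax(dim = 0)     # keys: softmax over the n positions
v = torch.randn(n, d)

# quadratic route: the full (n x n) attention map, O(n^2 d) time and O(n^2) memory
out_quadratic = (q @ k.t()) @ v

# linear route (efficient attention): aggregate a (d x d) context first, O(n d^2) time
context = einsum('n d, n e -> d e', k, v)
out_linear = q @ context

assert torch.allclose(out_quadratic, out_linear, atol = 1e-4)   # same result, up to float error
```

The LinearAttention module added in the diff below computes exactly this per head over the flattened (x y) positions of the feature map, which is why the windowed bookkeeping of GridAttention, and its sparse_attn_window argument, are no longer needed for the sparse attention layers.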
README.md (12 lines changed)
````diff
@@ -861,12 +861,22 @@ Once built, images will be saved to the same directory the command is invoked
 
 ```bibtex
 @inproceedings{Liu2022ACF,
-    title   = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
+    title   = {A ConvNet for the 2020s},
     author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
     year    = {2022}
 }
 ```
 
+```bibtex
+@article{shen2019efficient,
+    author  = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
+    title   = {Efficient Attention: Attention with Linear Complexities},
+    journal = {CoRR},
+    year    = {2018},
+    url     = {http://arxiv.org/abs/1812.01243},
+}
+```
+
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title   = {MaxViT: Multi-Axis Vision Transformer},
````
dalle2_pytorch/dalle2_pytorch.py
```diff
@@ -1050,6 +1050,42 @@ class GridAttention(nn.Module):
         out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
         return out
 
+class LinearAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_head = 32,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+        self.norm = ChanLayerNorm(dim)
+
+        self.nonlin = nn.GELU()
+        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
+        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, fmap):
+        h, x, y = self.heads, *fmap.shape[-2:]
+
+        fmap = self.norm(fmap)
+        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
+        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+
+        q = q.softmax(dim = -1)
+        k = k.softmax(dim = -2)
+
+        q = q * self.scale
+
+        context = einsum('b n d, b n e -> b d e', k, v)
+        out = einsum('b n d, b d e -> b n e', q, context)
+        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        out = self.nonlin(out)
+        return self.to_out(out)
+
 class Unet(nn.Module):
     def __init__(
         self,
```
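
As a shape sanity check, here is a hedged standalone sketch of the forward pass added above. ChanLayerNorm and rearrange_many come from the repository and einops_exts respectively, so to keep the snippet self-contained the norm and the GELU are omitted and plain einops.rearrange is used; it mirrors only the attention arithmetic, not the full module.

```python
import torch
from torch import nn, einsum
from einops import rearrange

dim, heads, dim_head = 64, 8, 32
inner_dim = heads * dim_head

to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

fmap = torch.randn(2, dim, 32, 32)                         # (batch, channels, height, width)
x, y = fmap.shape[-2:]

q, k, v = to_qkv(fmap).chunk(3, dim = 1)
q, k, v = (rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads) for t in (q, k, v))

q = q.softmax(dim = -1) * dim_head ** -0.5                 # scale folded in after the softmax, as in the diff
k = k.softmax(dim = -2)                                    # normalise keys over the (x y) positions

context = einsum('b n d, b n e -> b d e', k, v)            # (batch * heads, dim_head, dim_head)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = heads, x = x, y = y)

print(to_out(out).shape)                                   # torch.Size([2, 64, 32, 32]), same shape as the input
```

Whatever the spatial resolution, the intermediate context stays at dim_head by dim_head per head, which is what keeps this cheap at resolutions where full attention would be prohibitive.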
```diff
@@ -1067,7 +1103,6 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         sparse_attn = False,
-        sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1161,7 +1196,7 @@ class Unet(nn.Module):
 
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
-                Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -1178,7 +1213,7 @@ class Unet(nn.Module):
 
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
-                Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
```
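
From the caller's side the only visible change is that sparse_attn_window is gone from the constructor; sparse_attn = True now inserts Residual(LinearAttention(...)) blocks in each down and up stage instead of windowed GridAttention. A hedged usage sketch follows; the other hyperparameters are just the illustrative values from the repository README, not requirements of this change.

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    sparse_attn = True          # each down / up stage now wraps a Residual(LinearAttention(...))
)
```

Any existing config that still passes sparse_attn_window will need to drop that argument, since the parameter was removed from the signature above.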