mirror of https://github.com/lucidrains/DALLE2-pytorch.git
switch to using linear attention for the sparse attention layers within unet, given success in GAN projects
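The LinearAttention module added below replaces the windowed GridAttention in the sparse-attention slots of the Unet's down and up blocks, and the now-unneeded sparse_attn_window hyperparameter is removed. Instead of forming the full N x N attention matrix softmax(Q K^T) V over the N = x * y spatial positions, it normalizes Q along the feature dimension and K along the sequence dimension, then contracts K with V first:

    out = (softmax_feat(Q) * scale) @ (softmax_seq(K)^T @ V)

which costs O(N * d^2) per head rather than O(N^2 * d).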
@@ -1050,6 +1050,42 @@ class GridAttention(nn.Module):
         out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
         return out
 
+class LinearAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_head = 32,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+        self.norm = ChanLayerNorm(dim)
+
+        self.nonlin = nn.GELU()
+        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
+        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, fmap):
+        h, x, y = self.heads, *fmap.shape[-2:]
+
+        fmap = self.norm(fmap)
+        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
+        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+
+        q = q.softmax(dim = -1)
+        k = k.softmax(dim = -2)
+
+        q = q * self.scale
+
+        context = einsum('b n d, b n e -> b d e', k, v)
+        out = einsum('b n d, b d e -> b n e', q, context)
+        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+        out = self.nonlin(out)
+        return self.to_out(out)
+
 class Unet(nn.Module):
     def __init__(
         self,
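A minimal standalone sketch of just the attention math from the forward pass above, with assumed toy sizes, to sanity-check the shapes and the linear-in-N contraction order (the real module also applies ChanLayerNorm, the 1x1 conv projections, and a GELU before to_out):

# sketch: linear attention contraction with assumed toy sizes
import torch
from einops import rearrange

b, heads, dim_head, x, y = 2, 8, 32, 16, 16
n = x * y

q = torch.randn(b * heads, n, dim_head).softmax(dim = -1) * dim_head ** -0.5
k = torch.randn(b * heads, n, dim_head).softmax(dim = -2)
v = torch.randn(b * heads, n, dim_head)

context = torch.einsum('b n d, b n e -> b d e', k, v)   # per-head (d, e) summary, summed over all n positions
out = torch.einsum('b n d, b d e -> b n e', q, context) # attend via the summary - still linear in n
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = heads, x = x, y = y)

print(out.shape)  # torch.Size([2, 256, 16, 16]) - this is what the module feeds through nonlin and to_out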
@@ -1067,7 +1103,6 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         sparse_attn = False,
-        sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1161,7 +1196,7 @@ class Unet(nn.Module):
 
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
-                Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -1178,7 +1213,7 @@ class Unet(nn.Module):
 
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
-                Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
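In both hunks the attention layer keeps the same Residual(...) if sparse_attn else nn.Identity() wiring, so the block's feature-map shape is unchanged whether the attention is enabled or not. Residual itself is not part of this diff; a sketch of the thin wrapper it is assumed to be in this codebase:

# sketch (assumption, not shown in this commit): residual wrapper used around the attention layers
import torch.nn as nn

class Residual(nn.Module):
    # adds the wrapped module's output back onto its input, i.e. x + fn(x),
    # so Residual(LinearAttention(dim)) preserves the (b, dim, h, w) feature map shape
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x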