switch to using linear attention for the sparse attention layers within unet, given success in GAN projects

This commit is contained in:
Phil Wang
2022-05-01 17:59:03 -07:00
parent 76c767b1ce
commit ad87bfe28f
3 changed files with 50 additions and 5 deletions

View File

@@ -1050,6 +1050,42 @@ class GridAttention(nn.Module):
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class LinearAttention(nn.Module):
def __init__(
self,
dim,
dim_head = 32,
heads = 8
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
q = q * self.scale
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
out = self.nonlin(out)
return self.to_out(out)
class Unet(nn.Module):
def __init__(
self,
@@ -1067,7 +1103,6 @@ class Unet(nn.Module):
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
@@ -1161,7 +1196,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -1178,7 +1213,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))