add one full attention at the middle of the unet, prepare to do efficient attention employing every trick i know from vision transformer literature

Phil Wang
2022-04-13 10:39:06 -07:00
parent 3aa6f91e7a
commit d573c82f8c


@@ -91,6 +91,14 @@ class RMSNorm(nn.Module):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * self.gamma * self.scale
 
+class PreNormResidual(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = RMSNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs) + x
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
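For readers outside the diff context: PreNormResidual is the standard pre-norm transformer wrapper (normalize, apply the wrapped function, add the input back), and FeedForward, whose body is cut off by the hunk above, is the usual MLP expansion. Below is a minimal, self-contained sketch of that pattern; it uses nn.LayerNorm as a stand-in for the repository's RMSNorm, and the Linear-GELU-Dropout-Linear body and sizes are illustrative assumptions, not the repository's exact code.

import torch
from torch import nn

class PreNormResidual(nn.Module):
    # normalize -> wrapped fn -> residual add (pre-norm transformer wrapper)
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # stand-in for the repo's RMSNorm
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs) + x

def FeedForward(dim, mult = 4, dropout = 0.):
    # typical transformer MLP; the hunk above truncates the real body
    inner_dim = int(mult * dim)
    return nn.Sequential(
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim)
    )

block = PreNormResidual(64, FeedForward(64))
tokens = torch.randn(2, 128, 64)           # (batch, sequence, dim)
assert block(tokens).shape == tokens.shape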
@@ -509,6 +517,21 @@ class ConvNextBlock(nn.Module):
         h = self.net(h)
         return h + self.res_conv(x)
 
+class EinopsToAndFrom(nn.Module):
+    def __init__(self, from_einops, to_einops, fn):
+        super().__init__()
+        self.from_einops = from_einops
+        self.to_einops = to_einops
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        shape = x.shape
+        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
+        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
+        x = self.fn(x, **kwargs)
+        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
+        return x
+
 class Unet(nn.Module):
     def __init__(
         self,
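EinopsToAndFrom is what lets a sequence-style attention module run on a 2d feature map: it records each named axis length from the input shape, rearranges 'b c h w' into a token sequence 'b (h w) c', applies the wrapped module, then inverts the rearrangement using the recorded sizes. A self-contained sketch of the same logic, exercised with nn.Identity standing in for attention; the tensor shapes are arbitrary.

import torch
from torch import nn
from einops import rearrange

class EinopsToAndFrom(nn.Module):
    # rearrange in, apply the wrapped module, rearrange back
    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        # remember every named axis size so the inverse rearrange can restore it
        reconstitute_kwargs = dict(zip(self.from_einops.split(' '), x.shape))
        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        x = self.fn(x, **kwargs)
        return rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)

wrapper = EinopsToAndFrom('b c h w', 'b (h w) c', nn.Identity())
feats = torch.randn(1, 32, 16, 16)
assert wrapper(feats).shape == feats.shape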
@@ -553,7 +576,9 @@ class Unet(nn.Module):
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
+        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)
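The Attention(mid_dim) wrapped above is defined elsewhere in the file and is not shown in this diff. For orientation, it is consumed as a module mapping a (batch, tokens, dim) sequence to the same shape; a generic multi-head self-attention of roughly the following form fits that contract. The head count, head dimension, and bias-free projections here are assumptions for illustration, not the repository's exact hyperparameters.

import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    # plain multi-head self-attention over a (batch, tokens, dim) sequence;
    # an illustrative stand-in for the Attention class defined elsewhere in the file
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

attn = Attention(128)
x = torch.randn(2, 64, 128)        # e.g. an 8x8 feature map flattened to 64 tokens
assert attn(x).shape == x.shape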
@@ -616,7 +641,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block(x, t)
+        x = self.mid_block1(x, t)
+        x = self.mid_attn(x)
+        x = self.mid_block2(x, t)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
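Put together, the bottleneck of the Unet now runs ConvNext, then full self-attention over every spatial position, then another ConvNext. A short usage sketch of just the attention stage, reusing the illustrative EinopsToAndFrom, PreNormResidual, and Attention classes from the sketches above; the channel count and feature-map size are made up, and the ConvNext blocks and conditioning are omitted.

mid_dim = 128                                    # illustrative bottleneck channel count
mid_attn = EinopsToAndFrom(
    'b c h w', 'b (h w) c',
    PreNormResidual(mid_dim, Attention(mid_dim))
)

x = torch.randn(2, mid_dim, 8, 8)                # downsampled feature map
x = mid_attn(x)                                  # attention across all 8 * 8 = 64 positions
assert x.shape == (2, mid_dim, 8, 8)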