mirror of https://github.com/lucidrains/DALLE2-pytorch.git
add one full attention block at the middle of the unet, in preparation for efficient attention employing every trick i know from the vision transformer literature
@@ -91,6 +91,15 @@ class RMSNorm(nn.Module):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * self.gamma * self.scale
 
+class PreNormResidual(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = RMSNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs) + x
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
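For context: only the last two lines of RMSNorm are visible in this hunk. Below is a minimal sketch that fills in the rest of the module consistently with those two lines (the eps default and scale = sqrt(dim) are assumptions following common RMSNorm implementations, not part of this diff), followed by a standalone check of the new PreNormResidual wrapper.

import torch
from torch import nn

# sketch only -- everything except the last two lines of forward is assumed
class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.scale = dim ** 0.5                          # turns x / ||x|| into x / rms(x)
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        squared_sum = x.pow(2).sum(dim = -1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)   # as in the hunk above
        return x * inv_norm * self.gamma * self.scale    # as in the hunk above

# PreNormResidual (added above): normalize, apply fn, then add the input back
block = PreNormResidual(64, nn.Linear(64, 64))
x = torch.randn(2, 10, 64)
assert block(x).shape == x.shape                         # pre-norm residual is shape-preserving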
@@ -509,6 +518,21 @@ class ConvNextBlock(nn.Module):
         h = self.net(h)
         return h + self.res_conv(x)
 
+class EinopsToAndFrom(nn.Module):
+    def __init__(self, from_einops, to_einops, fn):
+        super().__init__()
+        self.from_einops = from_einops
+        self.to_einops = to_einops
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        shape = x.shape
+        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
+        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
+        x = self.fn(x, **kwargs)
+        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
+        return x
+
 class Unet(nn.Module):
     def __init__(
         self,
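EinopsToAndFrom records the size of every axis named in from_einops, rearranges the tensor into the pattern the wrapped fn expects, runs fn, then rearranges back using the recorded sizes. A standalone sketch of that round trip (shapes are illustrative; needs einops installed):

import torch
from einops import rearrange

x = torch.randn(2, 64, 16, 16)                          # (b, c, h, w) feature map
reconstitute_kwargs = dict(zip('b c h w'.split(' '), x.shape))
tokens = rearrange(x, 'b c h w -> b (h w) c')           # flatten space into a 256-token sequence
# ... the wrapped fn (e.g. attention) would run on tokens here ...
restored = rearrange(tokens, 'b (h w) c -> b c h w', **reconstitute_kwargs)
assert torch.equal(x, restored)                         # lossless round trip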
@@ -553,7 +577,9 @@ class Unet(nn.Module):
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
+        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)
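The Attention class passed to PreNormResidual here is defined elsewhere in the file and is not part of this diff. A minimal sketch of the full (quadratic) multi-head self-attention that the Attention(mid_dim) call site implies; the heads and dim_head defaults are illustrative assumptions:

import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):                                # x: (b, n, dim) token sequence
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)                     # full n x n attention matrix
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))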
@@ -616,7 +642,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block(x, t)
+        x = self.mid_block1(x, t)
+        x = self.mid_attn(x)
+        x = self.mid_block2(x, t)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
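Full attention is quadratic in the number of tokens, which is why the commit places it only at the bottleneck, where the (h w) sequence is shortest: a 64x64 input downsampled to 8x8 gives 64 tokens rather than 4096. Composing the pieces this commit adds (EinopsToAndFrom, PreNormResidual) with the Attention sketch above, the new middle stage can be exercised standalone; mid_dim and the spatial size are illustrative:

import torch

mid_dim = 512    # illustrative; in the Unet this is dims[-1]
mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))

x = torch.randn(1, mid_dim, 8, 8)        # bottleneck feature map, 8 * 8 = 64 tokens
assert mid_attn(x).shape == x.shape      # attention runs on tokens; output restored to 'b c h w'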