mirror of https://github.com/lucidrains/DALLE2-pytorch.git
add one full attention at the middle of the unet, prepare to do efficient attention employing every trick i know from vision transformer literature
@@ -91,6 +91,14 @@ class RMSNorm(nn.Module):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * self.gamma * self.scale
 
+class PreNormResidual(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = RMSNorm(dim)
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs) + x
 
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
@@ -509,6 +517,21 @@ class ConvNextBlock(nn.Module):
         h = self.net(h)
         return h + self.res_conv(x)
 
+class EinopsToAndFrom(nn.Module):
+    def __init__(self, from_einops, to_einops, fn):
+        super().__init__()
+        self.from_einops = from_einops
+        self.to_einops = to_einops
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        shape = x.shape
+        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
+        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
+        x = self.fn(x, **kwargs)
+        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
+        return x
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -553,7 +576,9 @@ class Unet(nn.Module):
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
+        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)
@@ -616,7 +641,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block(x, t)
+        x = self.mid_block1(x, t)
+        x = self.mid_attn(x)
+        x = self.mid_block2(x, t)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
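Taken together, the two Unet hunks give the bottleneck the familiar resnet → full attention → resnet layout: quadratic self-attention is only affordable at the coarsest resolution, where h * w is smallest. A rough sketch of the equivalent forward fragment, using plain convolutions and torch's nn.MultiheadAttention as stand-ins for ConvNextBlock and the repo's Attention (all names and dimensions here are illustrative):

import torch
from torch import nn

mid_dim = 128
mid_block1 = nn.Conv2d(mid_dim, mid_dim, 3, padding = 1)
mid_block2 = nn.Conv2d(mid_dim, mid_dim, 3, padding = 1)
mha = nn.MultiheadAttention(mid_dim, num_heads = 4, batch_first = True)

x = torch.randn(2, mid_dim, 8, 8)     # coarsest feature map: 8 * 8 = 64 tokens
x = mid_block1(x)

# flatten to tokens, attend, restore -- the step EinopsToAndFrom automates
b, c, h, w = x.shape
tokens = x.flatten(2).transpose(1, 2)             # (b, h*w, c)
tokens = mha(tokens, tokens, tokens)[0] + tokens  # self-attention + residual
x = tokens.transpose(1, 2).reshape(b, c, h, w)

x = mid_block2(x)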