Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
bring in modified unet using convnext blocks https://arxiv.org/abs/2201.03545
@@ -63,3 +63,11 @@ Todo
    primaryClass = {cs.CV}
}
```

```bibtex
@inproceedings{Liu2022ACF,
    title   = {A ConvNet for the 2020s},
    author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
    year    = {2022}
}
```
@@ -143,7 +143,7 @@ class Transformer(nn.Module):
        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
    ):
        for attn, ff in self.layers:
-           x = attn(x) + x
+           x = attn(x, mask = mask) + x
            x = ff(x) + x

        return self.norm(x)
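The comment on `mask` above flags work this commit does not do. As a rough illustration only, a padding mask for variable-length text encodings could be derived from per-sample lengths like this; the helper name and shapes are assumptions, not part of the repository:

```python
import torch

# hypothetical helper (not in this commit): build a boolean padding mask from
# per-sample text lengths; True marks real tokens, False marks padding
def build_text_mask(text_lengths, max_len):
    positions = torch.arange(max_len, device = text_lengths.device)
    return positions[None, :] < text_lengths[:, None]

mask = build_text_mask(torch.tensor([5, 3]), max_len = 8)  # shape (2, 8), dtype bool
```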
@@ -168,6 +168,146 @@ class DiffusionPrior(nn.Module):

# decoder

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
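Both helpers keep the channel count and only change spatial resolution: with kernel 4, stride 2, padding 1, the transposed convolution doubles height and width while the strided convolution halves them. A quick sanity check (not part of the diff; the two functions are restated so the snippet runs on its own):

```python
import torch
import torch.nn as nn

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

x = torch.randn(1, 64, 32, 32)
print(Upsample(64)(x).shape)    # torch.Size([1, 64, 64, 64])
print(Downsample(64)(x).shape)  # torch.Size([1, 64, 16, 16])
```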
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
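This is the standard transformer-style sinusoidal embedding: each scalar timestep is multiplied by `dim // 2` geometrically spaced frequencies (from 1 down to 1/10000), and the sines and cosines are concatenated. A small shape check, assuming the class above and its imports (`math`, `torch`, `einops.rearrange`) are in scope:

```python
pos_emb = SinusoidalPosEmb(dim = 128)
t = torch.randint(0, 1000, (4,)).float()  # a batch of 4 diffusion timesteps
print(pos_emb(t).shape)                   # torch.Size([4, 128])
```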
class ConvNextBlock(nn.Module):
    """ https://arxiv.org/abs/2201.03545 """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        time_emb_dim = None,
        mult = 2,
        norm = True
    ):
        super().__init__()
        need_projection = dim != dim_out

        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim)
        ) if exists(time_emb_dim) else None

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        inner_dim = int(dim_out * mult)
        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, inner_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb)
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, 'b c -> b c 1 1')

        h = self.net(h)
        return h + self.res_conv(x)
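The block keeps spatial resolution, maps `dim` channels to `dim_out`, and, when a time embedding is supplied, projects it down to `dim` channels and adds it after the depthwise convolution. A usage sketch, assuming the class above together with the helpers it references elsewhere in the file (`exists` and the channel-wise `LayerNorm`):

```python
block = ConvNextBlock(64, 128, time_emb_dim = 32)
x = torch.randn(2, 64, 16, 16)   # (batch, dim, height, width)
t = torch.randn(2, 32)           # (batch, time_emb_dim)
print(block(x, t).shape)         # torch.Size([2, 128, 16, 16])
```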
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        channels = 3,
        with_time_emb = True
    ):
        super().__init__()
        self.channels = channels

        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )
    def forward(self, x, time):
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        hiddens = []

        for convnext, convnext2, downsample in self.downs:
            x = convnext(x, t)
            x = convnext2(x, t)
            hiddens.append(x)
            x = downsample(x)

        x = self.mid_block(x, t)

        for convnext, convnext2, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim = 1)
            x = convnext(x, t)
            x = convnext2(x, t)
            x = upsample(x)

        return self.final_conv(x)
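Taken together, the Unet maps a noisy image and a batch of timesteps to an output with the same shape as the input, which is what a DDPM-style decoder needs for noise prediction. A usage sketch, assuming the classes above and the `default` / `exists` helpers defined elsewhere in the file:

```python
unet = Unet(dim = 64, dim_mults = (1, 2, 4, 8), channels = 3)

images = torch.randn(2, 3, 64, 64)            # noisy input images
times = torch.randint(0, 1000, (2,)).float()  # one diffusion timestep per image

pred = unet(images, times)
print(pred.shape)  # torch.Size([2, 3, 64, 64])
```

Note that the up path iterates over `in_out[1:]`, so the first down stage's skip connection is stored in `hiddens` but never concatenated.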
class Decoder(nn.Module):
    def __init__(
        self,