Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
bring in modified unet using convnext blocks https://arxiv.org/abs/2201.03545
@@ -63,3 +63,11 @@ Todo
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@inproceedings{Liu2022ACF,
+    title  = {A ConvNet for the 2020s},
+    author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
+    year   = {2022}
+}
+```
@@ -143,7 +143,7 @@ class Transformer(nn.Module):
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, mask = mask) + x
             x = ff(x) + x
 
         return self.norm(x)
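Note: the mask referred to in the comment above is just a boolean padding mask derived from the true lengths of the variable-length text encodings. A minimal sketch of how such a mask is typically built (the names `text_lengths` and `max_text_len` are illustrative and not part of this commit):

```python
import torch

# hypothetical batch of 3 text encodings, padded to a common length of 8 tokens
text_lengths = torch.tensor([5, 3, 8])   # true (unpadded) token counts per sequence
max_text_len = 8                         # padded sequence length of the text encodings

# True where a position holds a real token, False where it is padding
mask = torch.arange(max_text_len)[None, :] < text_lengths[:, None]   # shape: (3, 8)
```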
@@ -168,6 +168,146 @@ class DiffusionPrior(nn.Module):
 
 # decoder
 
+def Upsample(dim):
+    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+def Downsample(dim):
+    return nn.Conv2d(dim, dim, 4, 2, 1)
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
+        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
+        return torch.cat((emb.sin(), emb.cos()), dim = -1)
+
+class ConvNextBlock(nn.Module):
+    """ https://arxiv.org/abs/2201.03545 """
+
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        *,
+        time_emb_dim = None,
+        mult = 2,
+        norm = True
+    ):
+        super().__init__()
+        need_projection = dim != dim_out
+
+        self.mlp = nn.Sequential(
+            nn.GELU(),
+            nn.Linear(time_emb_dim, dim)
+        ) if exists(time_emb_dim) else None
+
+        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
+
+        inner_dim = int(dim_out * mult)
+        self.net = nn.Sequential(
+            LayerNorm(dim) if norm else nn.Identity(),
+            nn.Conv2d(dim, inner_dim, 3, padding = 1),
+            nn.GELU(),
+            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
+        )
+
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
+
+    def forward(self, x, time_emb = None):
+        h = self.ds_conv(x)
+
+        if exists(self.mlp):
+            assert exists(time_emb)
+            condition = self.mlp(time_emb)
+            h = h + rearrange(condition, 'b c -> b c 1 1')
+
+        h = self.net(h)
+        return h + self.res_conv(x)
+
+class Unet(nn.Module):
+    def __init__(
+        self,
+        dim,
+        out_dim = None,
+        dim_mults = (1, 2, 4, 8),
+        channels = 3,
+        with_time_emb = True
+    ):
+        super().__init__()
+        self.channels = channels
+
+        dims = [channels, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+
+        if with_time_emb:
+            time_dim = dim
+            self.time_mlp = nn.Sequential(
+                SinusoidalPosEmb(dim),
+                nn.Linear(dim, dim * 4),
+                nn.GELU(),
+                nn.Linear(dim * 4, dim)
+            )
+        else:
+            time_dim = None
+            self.time_mlp = None
+
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        num_resolutions = len(in_out)
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(nn.ModuleList([
+                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
+                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
+                Downsample(dim_out) if not is_last else nn.Identity()
+            ]))
+
+        mid_dim = dims[-1]
+        self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.ups.append(nn.ModuleList([
+                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
+                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
+                Upsample(dim_in) if not is_last else nn.Identity()
+            ]))
+
+        out_dim = default(out_dim, channels)
+        self.final_conv = nn.Sequential(
+            ConvNextBlock(dim, dim),
+            nn.Conv2d(dim, out_dim, 1)
+        )
+
+    def forward(self, x, time):
+        t = self.time_mlp(time) if exists(self.time_mlp) else None
+
+        hiddens = []
+
+        for convnext, convnext2, downsample in self.downs:
+            x = convnext(x, t)
+            x = convnext2(x, t)
+            hiddens.append(x)
+            x = downsample(x)
+
+        x = self.mid_block(x, t)
+
+        for convnext, convnext2, upsample in self.ups:
+            x = torch.cat((x, hiddens.pop()), dim = 1)
+            x = convnext(x, t)
+            x = convnext2(x, t)
+            x = upsample(x)
+
+        return self.final_conv(x)
+
 class Decoder(nn.Module):
     def __init__(
         self,
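For orientation, here is a minimal sketch of how the ConvNeXt-based Unet added in this commit could be exercised on its own. The sizes are illustrative assumptions (the spatial size only needs to survive three halvings), the import path is hypothetical, and the helpers the code references (`exists`, `default`, `LayerNorm`) are assumed to be defined elsewhere in the same module, as they are in this repository:

```python
import torch
# illustrative import; adjust to wherever the Unet defined above lives
# from dalle2_pytorch.dalle2_pytorch import Unet

unet = Unet(dim = 64, dim_mults = (1, 2, 4, 8), channels = 3)

images = torch.randn(2, 3, 64, 64)            # (batch, channels, height, width)
times = torch.randint(0, 1000, (2,)).float()  # one diffusion timestep per image

pred = unet(images, times)                    # output has the same shape as the input: (2, 3, 64, 64)
```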