bring in modified unet using convnext blocks https://arxiv.org/abs/2201.03545

Phil Wang
2022-04-12 10:58:44 -07:00
parent 522f42f582
commit cf22affcbb
2 changed files with 149 additions and 1 deletion


@@ -63,3 +63,11 @@ Todo
    primaryClass = {cs.CV}
}
```

```bibtex
@inproceedings{Liu2022ACF,
    title   = {A ConvNet for the 2020s},
    author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
    year    = {2022}
}
```


@@ -143,7 +143,7 @@ class Transformer(nn.Module):
        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
    ):
        for attn, ff in self.layers:
-           x = attn(x) + x
+           x = attn(x, mask = mask) + x
            x = ff(x) + x
        return self.norm(x)
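The inline comment above anticipates masking for variable-length text encodings; the change itself only threads `mask` through to the attention layers. As a rough sketch of what such a mask could look like (not part of this commit; the names and shapes below are assumptions), a boolean padding mask keyed on per-sample token counts might be built like this:

```python
import torch

# hypothetical per-sample token counts for a batch of 3 text encodings
text_lengths = torch.tensor([5, 3, 7])
max_len = int(text_lengths.max())

# (batch, seq_len) boolean mask, True where a position holds a real token
mask = torch.arange(max_len)[None, :] < text_lengths[:, None]

# an attention layer consuming this mask would typically fill the logits at
# False positions with a large negative value before the softmax, e.g.
# logits = logits.masked_fill(~mask[:, None, None, :], -torch.finfo(logits.dtype).max)
```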
@@ -168,6 +168,146 @@ class DiffusionPrior(nn.Module):
# decoder
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
class SinusoidalPosEmb(nn.Module):
    """ sinusoidal embeddings for the diffusion timestep """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
class ConvNextBlock(nn.Module):
    """ https://arxiv.org/abs/2201.03545 """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        time_emb_dim = None,
        mult = 2,
        norm = True
    ):
        super().__init__()
        need_projection = dim != dim_out

        # projects the time embedding to a per-channel bias
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim)
        ) if exists(time_emb_dim) else None

        # 7x7 depthwise convolution, as in ConvNeXt
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        # inverted bottleneck - note the 3x3 convolutions here, where the paper uses pointwise (1x1)
        inner_dim = int(dim_out * mult)
        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, inner_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb)
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, 'b c -> b c 1 1')

        h = self.net(h)
        return h + self.res_conv(x)
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        channels = 3,
        with_time_emb = True
    ):
        super().__init__()
        self.channels = channels

        # channel dimensions at each resolution, e.g. dim = 64 -> [3, 64, 128, 256, 512]
        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # downsampling path
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

        # upsampling path - the first block of each stage takes the concatenated skip connection
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        hiddens = []

        for convnext, convnext2, downsample in self.downs:
            x = convnext(x, t)
            x = convnext2(x, t)
            hiddens.append(x)
            x = downsample(x)

        x = self.mid_block(x, t)

        for convnext, convnext2, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim = 1)
            x = convnext(x, t)
            x = convnext2(x, t)
            x = upsample(x)

        return self.final_conv(x)
class Decoder(nn.Module):
    def __init__(
        self,
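For orientation (this is not part of the commit), here is a minimal sketch of how the new `Unet` might be exercised, assuming the rest of the file provides the `exists`, `default`, and `LayerNorm` helpers it references:

```python
import torch

# instantiate the ConvNext-based Unet added above
unet = Unet(dim = 64, dim_mults = (1, 2, 4, 8), channels = 3)

# a batch of noisy images and one diffusion timestep per sample;
# the spatial size must be divisible by 2 ** (len(dim_mults) - 1) = 8
images = torch.randn(2, 3, 64, 64)
times = torch.randint(0, 1000, (2,)).float()

pred = unet(images, times)
assert pred.shape == images.shape  # the output keeps the input image shape
```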