mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
offer old resnet blocks, from the original DDPM paper, just in case convnexts are unsuitable for generative work
This commit is contained in:
@@ -930,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
|
||||
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
*,
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_mlp = None
|
||||
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
|
||||
if exists(cond_dim):
|
||||
self.cross_attn = EinopsToAndFrom(
|
||||
'b c h w',
|
||||
'b (h w) c',
|
||||
CrossAttention(
|
||||
dim = dim_out,
|
||||
context_dim = cond_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups = groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
|
||||
h = self.block2(h)
|
||||
return h + self.res_conv(x)
|
||||
|
||||
class ConvNextBlock(nn.Module):
|
||||
""" https://arxiv.org/abs/2201.03545 """
|
||||
|
||||
@@ -940,8 +1006,7 @@ class ConvNextBlock(nn.Module):
|
||||
*,
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
mult = 2,
|
||||
norm = True
|
||||
mult = 2
|
||||
):
|
||||
super().__init__()
|
||||
need_projection = dim != dim_out
|
||||
@@ -970,7 +1035,7 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanLayerNorm(dim) if norm else nn.Identity(),
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, 3, padding = 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
|
||||
@@ -1082,7 +1147,11 @@ class LinearAttention(nn.Module):
|
||||
|
||||
self.nonlin = nn.GELU()
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1, bias = False),
|
||||
ChanLayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
h, x, y = self.heads, *fmap.shape[-2:]
|
||||
@@ -1125,7 +1194,9 @@ class Unet(nn.Module):
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7
|
||||
init_conv_kernel_size = 7,
|
||||
block_type = 'resnet',
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
@@ -1200,6 +1271,15 @@ class Unet(nn.Module):
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
# whether to use resnet or the (improved?) convnext blocks
|
||||
|
||||
if block_type == 'resnet':
|
||||
block_klass = ResnetBlock
|
||||
elif block_type == 'convnext':
|
||||
block_klass = ConvNextBlock
|
||||
else:
|
||||
raise ValueError(f'unimplemented block type {block_type}')
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -1212,32 +1292,32 @@ class Unet(nn.Module):
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
|
||||
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
self.final_conv = nn.Sequential(
|
||||
ConvNextBlock(dim, dim),
|
||||
block_klass(dim, dim),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
@@ -1368,10 +1448,10 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
||||
x = convnext(x, c, t)
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c, t)
|
||||
x = block2(x, c, t)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
@@ -1382,11 +1462,11 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, c, t)
|
||||
x = block1(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c, t)
|
||||
x = block2(x, c, t)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
Reference in New Issue
Block a user