mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
remove convnext blocks; they are ill-suited for generative work, as validated by early experimental results at https://github.com/lucidrains/video-diffusion-pytorch
@@ -866,14 +866,6 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
-```bibtex
-@inproceedings{Liu2022ACF,
-    title   = {A ConvNet for the 2020s},
-    author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
-    year    = {2022}
-}
-```
-
 ```bibtex
 @article{shen2019efficient,
     author  = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
@@ -999,68 +999,6 @@ class ResnetBlock(nn.Module):
         h = self.block2(h)
         return h + self.res_conv(x)
 
-class ConvNextBlock(nn.Module):
-    """ https://arxiv.org/abs/2201.03545 """
-
-    def __init__(
-        self,
-        dim,
-        dim_out,
-        *,
-        cond_dim = None,
-        time_cond_dim = None,
-        mult = 2
-    ):
-        super().__init__()
-        need_projection = dim != dim_out
-
-        self.cross_attn = None
-
-        if exists(cond_dim):
-            self.cross_attn = EinopsToAndFrom(
-                'b c h w',
-                'b (h w) c',
-                CrossAttention(
-                    dim = dim,
-                    context_dim = cond_dim
-                )
-            )
-
-        self.time_mlp = None
-
-        if exists(time_cond_dim):
-            self.time_mlp = nn.Sequential(
-                nn.GELU(),
-                nn.Linear(time_cond_dim, dim)
-            )
-
-        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
-
-        inner_dim = int(dim_out * mult)
-        self.net = nn.Sequential(
-            ChanLayerNorm(dim),
-            nn.Conv2d(dim, inner_dim, 3, padding = 1),
-            nn.GELU(),
-            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
-        )
-
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
-
-    def forward(self, x, cond = None, time = None):
-        h = self.ds_conv(x)
-
-        if exists(time) and exists(self.time_mlp):
-            t = self.time_mlp(time)
-            h = rearrange(t, 'b c -> b c 1 1') + h
-
-        if exists(self.cross_attn):
-            assert exists(cond)
-            h = self.cross_attn(h, context = cond) + h
-
-        h = self.net(h)
-
-        return h + self.res_conv(x)
-
 class CrossAttention(nn.Module):
     def __init__(
         self,
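Note: the removed `ConvNextBlock` injects the time embedding by running it through a small `GELU` + `Linear` head and broadcast-adding the result over the spatial feature map (`rearrange(t, 'b c -> b c 1 1') + h`). A minimal, self-contained sketch of just that broadcast pattern; the dimensions are illustrative and not taken from the repository:

```python
# Sketch of the time-conditioning broadcast used in the removed block:
# project the time embedding to the channel dimension, then add it at
# every spatial position of the feature map.
import torch
from torch import nn
from einops import rearrange

batch, dim, time_cond_dim = 2, 32, 128    # illustrative sizes

time_mlp = nn.Sequential(
    nn.GELU(),
    nn.Linear(time_cond_dim, dim)
)

h = torch.randn(batch, dim, 16, 16)       # feature map, layout 'b c h w'
time_emb = torch.randn(batch, time_cond_dim)

t = time_mlp(time_emb)                    # (b, c)
h = rearrange(t, 'b c -> b c 1 1') + h    # broadcast over h and w

print(h.shape)  # torch.Size([2, 32, 16, 16])
```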
@@ -1200,7 +1138,6 @@ class Unet(nn.Module):
         init_conv_kernel_size = 7,
         block_type = 'resnet',
         block_resnet_groups = 8,
-        block_convnext_mult = 2,
         **kwargs
     ):
         super().__init__()
@@ -1276,14 +1213,9 @@ class Unet(nn.Module):
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
 
-        # whether to use resnet or the (improved?) convnext blocks
+        # resnet block klass
 
-        if block_type == 'resnet':
-            block_klass = partial(ResnetBlock, groups = block_resnet_groups)
-        elif block_type == 'convnext':
-            block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
-        else:
-            raise ValueError(f'unimplemented block type {block_type}')
+        block_klass = partial(ResnetBlock, groups = block_resnet_groups)
 
         # layers
 
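Note: with the convnext branch gone, block construction collapses to a single `functools.partial` factory; the shared keyword arguments are pinned once and the factory is then called with the per-resolution dimensions. A sketch of that pattern using a hypothetical `ToyBlock` stand-in, not the repository's `ResnetBlock`:

```python
# Block-factory pattern: `partial` fixes the shared kwargs, the U-Net then
# instantiates blocks at each resolution by supplying dim / dim_out.
from functools import partial
from torch import nn

class ToyBlock(nn.Module):                 # hypothetical stand-in block
    def __init__(self, dim, dim_out, *, groups = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.GroupNorm(groups, dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim_out, 3, padding = 1)
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        return self.net(x) + self.res_conv(x)

block_klass = partial(ToyBlock, groups = 8)

# reused at every resolution, mirroring how block_klass is consumed downstream
blocks = nn.ModuleList([block_klass(64, 128), block_klass(128, 256)])
```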
@@ -331,112 +331,6 @@ class ResBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x
 
-# convnext enc dec
-
-class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g
-
-class ConvNext(nn.Module):
-    def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
-        super().__init__()
-        inner_dim = int(dim * mult)
-        self.net = nn.Sequential(
-            nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
-            ChanLayerNorm(dim),
-            nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
-            nn.GELU(),
-            nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
-        )
-
-    def forward(self, x):
-        return self.net(x) + x
-
-class ConvNextEncDec(nn.Module):
-    def __init__(
-        self,
-        dim,
-        *,
-        channels = 3,
-        layers = 4,
-        layer_mults = None,
-        num_blocks = 1,
-        first_conv_kernel_size = 5,
-        use_attn = True,
-        attn_dim_head = 64,
-        attn_heads = 8,
-        attn_dropout = 0.,
-    ):
-        super().__init__()
-
-        self.layers = layers
-
-        self.encoders = MList([])
-        self.decoders = MList([])
-
-        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
-        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
-
-        layer_dims = [dim * mult for mult in layer_mults]
-        dims = (dim, *layer_dims)
-
-        self.encoded_dim = dims[-1]
-
-        dim_pairs = zip(dims[:-1], dims[1:])
-
-        append = lambda arr, t: arr.append(t)
-        prepend = lambda arr, t: arr.insert(0, t)
-
-        if not isinstance(num_blocks, tuple):
-            num_blocks = (*((0,) * (layers - 1)), num_blocks)
-
-        if not isinstance(use_attn, tuple):
-            use_attn = (*((False,) * (layers - 1)), use_attn)
-
-        assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
-        assert len(use_attn) == layers
-
-        for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
-            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
-            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
-
-            if layer_use_attn:
-                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
-
-            for _ in range(layer_num_blocks):
-                append(self.encoders, ConvNext(dim_out))
-                prepend(self.decoders, ConvNext(dim_out))
-
-            if layer_use_attn:
-                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
-
-        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
-        append(self.decoders, nn.Conv2d(dim, channels, 1))
-
-    def get_encoded_fmap_size(self, image_size):
-        return image_size // (2 ** self.layers)
-
-    @property
-    def last_dec_layer(self):
-        return self.decoders[-1].weight
-
-    def encode(self, x):
-        for enc in self.encoders:
-            x = enc(x)
-        return x
-
-    def decode(self, x):
-        for dec in self.decoders:
-            x = dec(x)
-        return x
-
 # vqgan attention layer
 
 class VQGanAttention(nn.Module):
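Note on the removed encoder/decoder: each of its `layers` stride-2 convolutions halves the spatial resolution, which is why `get_encoded_fmap_size` returns `image_size // (2 ** self.layers)`. The `ChanLayerNorm` it relied on normalizes an NCHW tensor across the channel dimension at every spatial position, standing in for the channels-last LayerNorm of the ConvNeXt paper. A quick self-contained check of that behaviour, reusing the class exactly as shown above (the learned gain `g` starts at ones, so it is a no-op here):

```python
# Channel-wise layer norm over NCHW tensors, copied from the removed code.
import torch
from torch import nn

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g

x = torch.randn(2, 32, 8, 8)
y = ChanLayerNorm(32)(x)

# per-sample, per-position channel statistics come out ~zero mean, ~unit variance
print(y.mean(dim = 1).abs().max())              # close to 0
print(y.var(dim = 1, unbiased = False).mean())  # close to 1
```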
@@ -682,8 +576,6 @@ class VQGanVAE(nn.Module):
             enc_dec_klass = ResnetEncDec
         elif vae_type == 'vit':
             enc_dec_klass = ViTEncDec
-        elif vae_type == 'convnext':
-            enc_dec_klass = ConvNextEncDec
         else:
             raise ValueError(f'{vae_type} not valid')
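Note: with this hunk, only `'resnet'` and `'vit'` remain valid values for `vae_type`; `'convnext'` now falls through to the `ValueError`. A sketch of the string-dispatch-plus-guard pattern in isolation, with placeholder classes and a hypothetical helper rather than the repository's `VQGanVAE` API:

```python
# Dispatch a config string to a class, rejecting anything unrecognised.
class ResnetEncDec: ...   # placeholder, not the real implementation
class ViTEncDec: ...      # placeholder, not the real implementation

def get_enc_dec_klass(vae_type: str):
    if vae_type == 'resnet':
        return ResnetEncDec
    elif vae_type == 'vit':
        return ViTEncDec
    else:
        raise ValueError(f'{vae_type} not valid')

get_enc_dec_klass('resnet')       # ok
# get_enc_dec_klass('convnext')   # raises ValueError after this change
```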