mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
remove convnext blocks, they are illsuited for generative work, validated by early experimental results at https://github.com/lucidrains/video-diffusion-pytorch
This commit is contained in:
@@ -331,112 +331,6 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# convnext enc dec
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
class ConvNext(nn.Module):
|
||||
def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class ConvNextEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
num_blocks = 1,
|
||||
first_conv_kernel_size = 5,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
attn_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layers = layers
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_blocks, tuple):
|
||||
num_blocks = (*((0,) * (layers - 1)), num_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_blocks):
|
||||
append(self.encoders, ConvNext(dim_out))
|
||||
prepend(self.decoders, ConvNext(dim_out))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
for dec in self.decoders:
|
||||
x = dec(x)
|
||||
return x
|
||||
|
||||
# vqgan attention layer
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
@@ -682,8 +576,6 @@ class VQGanVAE(nn.Module):
|
||||
enc_dec_klass = ResnetEncDec
|
||||
elif vae_type == 'vit':
|
||||
enc_dec_klass = ViTEncDec
|
||||
elif vae_type == 'convnext':
|
||||
enc_dec_klass = ConvNextEncDec
|
||||
else:
|
||||
raise ValueError(f'{vae_type} not valid')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user