mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add convnext backbone for vqgan-vae, still need to fix groupnorms in resnet encdec
@@ -327,6 +327,108 @@ class ResBlock(nn.Module):
    def forward(self, x):
        return self.net(x) + x

# convnext enc dec

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g

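ChanLayerNorm normalizes each spatial position across the channel dimension only, with a learned per-channel gain g and no bias. Up to that gain (which is all ones at init), this matches torch.nn.functional.layer_norm applied on a channels-last view. A minimal sanity check, illustrative only and not part of the commit:

import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 8, 8)
ln = ChanLayerNorm(64)

# move channels last, layer-norm over them, move channels back
expected = F.layer_norm(x.permute(0, 2, 3, 1), (64,), eps = 1e-5).permute(0, 3, 1, 2)
assert torch.allclose(ln(x), expected, atol = 1e-5)
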
class ConvNext(nn.Module):
    def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
            ChanLayerNorm(dim),
            nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
            nn.GELU(),
            nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
        )

    def forward(self, x):
        return self.net(x) + x

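The block follows the ConvNeXt recipe: a depthwise convolution (groups = dim), channel layer norm, then an inverted bottleneck (dim -> mult * dim -> dim) with a GELU in between, all under a residual connection. One departure from the paper: the bottleneck convolutions here are 3x3 (kernel_size = 3) rather than 1x1 pointwise. Since every layer is shape-preserving, blocks stack freely; a quick illustrative check:

import torch

block = ConvNext(dim = 64)
x = torch.randn(1, 64, 32, 32)
assert block(x).shape == x.shape  # residual block preserves (B, C, H, W)
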
class ConvNextEncDec(nn.Module):
    def __init__(
        self,
        dim,
        *,
        channels = 3,
        layers = 4,
        layer_mults = None,
        num_blocks = 1,
        first_conv_kernel_size = 5,
        use_attn = True,
        attn_dim_head = 64,
        attn_heads = 8,
        attn_dropout = 0.,
    ):
        super().__init__()
        self.layers = layers

        self.encoders = MList([])
        self.decoders = MList([])

        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'

        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.encoded_dim = dims[-1]

        dim_pairs = zip(dims[:-1], dims[1:])

        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        if not isinstance(num_blocks, tuple):
            num_blocks = (*((0,) * (layers - 1)), num_blocks)

        if not isinstance(use_attn, tuple):
            use_attn = (*((False,) * (layers - 1)), use_attn)

        assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
        assert len(use_attn) == layers

        for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))

            if layer_use_attn:
                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

            for _ in range(layer_num_blocks):
                append(self.encoders, ConvNext(dim_out))
                prepend(self.decoders, ConvNext(dim_out))

            if layer_use_attn:
                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        append(self.decoders, nn.Conv2d(dim, channels, 1))

    def get_encoded_fmap_size(self, image_size):
        return image_size // (2 ** self.layers)

    def encode(self, x):
        for enc in self.encoders:
            x = enc(x)
        return x

    def decode(self, x):
        for dec in self.decoders:
            x = dec(x)
        return x

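An illustrative round trip through the new encoder / decoder, assuming the file's existing helpers (MList, default, leaky_relu) and VQGanAttention are in scope as they are elsewhere in vqgan_vae.py:

import torch

enc_dec = ConvNextEncDec(dim = 32)      # defaults: layers = 4, mults (1, 2, 4, 8)
images = torch.randn(1, 3, 256, 256)

fmap = enc_dec.encode(images)
# channels grow to encoded_dim = 32 * 8 = 256; spatial shrinks by 2 ** 4
assert fmap.shape == (1, 256, 16, 16)
assert fmap.shape[-1] == enc_dec.get_encoded_fmap_size(256)

recon = enc_dec.decode(fmap)            # mirrored decoders restore (1, 3, 256, 256)
assert recon.shape == images.shape
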
# vqgan attention layer

class VQGanAttention(nn.Module):
@@ -568,6 +670,8 @@ class VQGanVAE(nn.Module):
            enc_dec_klass = ResnetEncDec
        elif vae_type == 'vit':
            enc_dec_klass = ViTEncDec
        elif vae_type == 'convnext':
            enc_dec_klass = ConvNextEncDec
        else:
            raise ValueError(f'{vae_type} not valid')
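With the dispatch above, the new backbone should be selectable the same way as the existing resnet and vit variants. A hypothetical sketch; the other constructor arguments shown (dim, image_size) are assumed to match VQGanVAE's existing signature and are not part of this commit:

# 'convnext' routes enc_dec_klass to ConvNextEncDec via the branch above
vae = VQGanVAE(dim = 32, image_size = 256, vae_type = 'convnext')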