From 1bb9fc982911f2e8f45b871277962a3cb15a6df0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 1 May 2022 09:32:24 -0700
Subject: [PATCH] add convnext backbone for vqgan-vae, still need to fix groupnorms in resnet encdec

---
Notes (ignored by `git am`; not part of the commit message):

* Adds three classes to dalle2_pytorch/vqgan_vae.py:
  - ChanLayerNorm: normalizes over dim 1 with a learned gain of shape
    (1, dim, 1, 1); there is no bias term.
  - ConvNext: residual block built from a depthwise conv (groups = dim),
    ChanLayerNorm, then conv -> GELU -> conv with an inner width of
    int(dim * mult).
  - ConvNextEncDec: stride-2 conv encoder / transposed-conv decoder. A scalar
    `num_blocks` or `use_attn` is applied to the last layer only (earlier
    layers get 0 / False); pass a tuple of length `layers` for per-layer
    control. Decoder modules are inserted at index 0 so they run in the
    reverse order of the encoder.
* Registers the new backbone in VQGanVAE under vae_type == 'convnext'.
* Bumps the package version from 0.0.85 to 0.0.86.
* NOTE(review): `layer_index` is bound in the ConvNextEncDec layer loop but
  never used.
* NOTE(review): line breaks, indentation and blank lines were reconstructed
  from a whitespace-collapsed copy of this patch. Line counts match the hunk
  headers, but verify the context-line whitespace against the target tree
  before applying.

 dalle2_pytorch/vqgan_vae.py | 104 ++++++++++++++++++++++++++++++++++++
 setup.py                    |   2 +-
 2 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 76f395f..bc203b5 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -327,6 +327,108 @@ class ResBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x
 
+# convnext enc dec
+
+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (var + self.eps).sqrt() * self.g
+
+class ConvNext(nn.Module):
+    def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
+            ChanLayerNorm(dim),
+            nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
+            nn.GELU(),
+            nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
+        )
+
+    def forward(self, x):
+        return self.net(x) + x
+
+class ConvNextEncDec(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        channels = 3,
+        layers = 4,
+        layer_mults = None,
+        num_blocks = 1,
+        first_conv_kernel_size = 5,
+        use_attn = True,
+        attn_dim_head = 64,
+        attn_heads = 8,
+        attn_dropout = 0.,
+    ):
+        super().__init__()
+
+        self.layers = layers
+
+        self.encoders = MList([])
+        self.decoders = MList([])
+
+        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
+        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
+
+        layer_dims = [dim * mult for mult in layer_mults]
+        dims = (dim, *layer_dims)
+
+        self.encoded_dim = dims[-1]
+
+        dim_pairs = zip(dims[:-1], dims[1:])
+
+        append = lambda arr, t: arr.append(t)
+        prepend = lambda arr, t: arr.insert(0, t)
+
+        if not isinstance(num_blocks, tuple):
+            num_blocks = (*((0,) * (layers - 1)), num_blocks)
+
+        if not isinstance(use_attn, tuple):
+            use_attn = (*((False,) * (layers - 1)), use_attn)
+
+        assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
+        assert len(use_attn) == layers
+
+        for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
+            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
+            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
+
+            if layer_use_attn:
+                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+            for _ in range(layer_num_blocks):
+                append(self.encoders, ConvNext(dim_out))
+                prepend(self.decoders, ConvNext(dim_out))
+
+            if layer_use_attn:
+                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
+        append(self.decoders, nn.Conv2d(dim, channels, 1))
+
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
+    def encode(self, x):
+        for enc in self.encoders:
+            x = enc(x)
+        return x
+
+    def decode(self, x):
+        for dec in self.decoders:
+            x = dec(x)
+        return x
+
 # vqgan attention layer
 
 class VQGanAttention(nn.Module):
@@ -568,6 +670,8 @@ class VQGanVAE(nn.Module):
             enc_dec_klass = ResnetEncDec
         elif vae_type == 'vit':
             enc_dec_klass = ViTEncDec
+        elif vae_type == 'convnext':
+            enc_dec_klass = ConvNextEncDec
         else:
             raise ValueError(f'{vae_type} not valid')
 
diff --git a/setup.py b/setup.py
index 82ae401..40fd70e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.85',
+  version = '0.0.86',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',