diff --git a/README.md b/README.md
index 19a0a6c..f461d80 100644
--- a/README.md
+++ b/README.md
@@ -644,6 +644,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
 - [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
 - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
+- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 - [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
@@ -651,7 +652,6 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] train on a toy task, offer in colab
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
 - [ ] bring in tools to train vqgan-vae
-- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 
 ## Citations
 
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 39fd45c..dceff14 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -12,6 +12,8 @@ from torch.autograd import grad as torch_grad
 import torchvision
 
 from einops import rearrange, reduce, repeat
+from einops_exts import rearrange_many
+from einops.layers.torch import Rearrange
 
 from dalle2_pytorch.attention import QueryAttnUpsample
 
@@ -146,6 +148,8 @@ class LayerNormChan(nn.Module):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.gamma
 
+# discriminator
+
 class Discriminator(nn.Module):
     def __init__(
         self,
@@ -179,6 +183,8 @@ class Discriminator(nn.Module):
 
         return self.to_logits(x)
 
+# positional encoding
+
 class ContinuousPositionBias(nn.Module):
     """ from https://arxiv.org/abs/2111.09883 """
 
@@ -213,6 +219,84 @@ class ContinuousPositionBias(nn.Module):
         bias = rearrange(rel_pos, 'i j h -> h i j')
         return x + bias
 
+# resnet encoder / decoder
+
+class ResnetEncDec(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        channels = 3,
+        layers = 4,
+        layer_mults = None,
+        num_resnet_blocks = 1,
+        resnet_groups = 16,
+        first_conv_kernel_size = 5,
+        use_attn = True,
+        attn_dim_head = 64,
+        attn_heads = 8,
+        attn_dropout = 0.,
+    ):
+        super().__init__()
+        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
+
+        self.layers = layers
+
+        self.encoders = MList([])
+        self.decoders = MList([])
+
+        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
+        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
+
+        layer_dims = [dim * mult for mult in layer_mults]
+        dims = (dim, *layer_dims)
+
+        self.encoded_dim = dims[-1]
+
+        dim_pairs = zip(dims[:-1], dims[1:])
+
+        append = lambda arr, t: arr.append(t)
+        prepend = lambda arr, t: arr.insert(0, t)
+
+        if not isinstance(num_resnet_blocks, tuple):
+            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
+
+        if not isinstance(use_attn, tuple):
+            use_attn = (*((False,) * (layers - 1)), use_attn)
+
+        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
+        assert len(use_attn) == layers
+
+        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
+            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
+            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
+
+            if layer_use_attn:
+                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+            for _ in range(layer_num_resnet_blocks):
+                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
+                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
+
+            if layer_use_attn:
+                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
+        append(self.decoders, nn.Conv2d(dim, channels, 1))
+
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
+    def encode(self, x):
+        for enc in self.encoders:
+            x = enc(x)
+        return x
+
+    def decode(self, x):
+        for dec in self.decoders:
+            x = dec(x)
+        return x
+
 class GLUResBlock(nn.Module):
     def __init__(self, chan, groups = 16):
         super().__init__()
@@ -246,6 +330,7 @@ class ResBlock(nn.Module):
         return self.net(x) + x
 
 # vqgan attention layer
+
 class VQGanAttention(nn.Module):
     def __init__(
         self,
@@ -290,6 +375,145 @@ class VQGanAttention(nn.Module):
 
         return out + residual
 
+# ViT encoder / decoder
+
+class RearrangeImage(nn.Module):
+    def forward(self, x):
+        n = x.shape[1]
+        w = h = int(sqrt(n))
+        return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        heads = 8,
+        dim_head = 32
+    ):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        inner_dim = dim_head * heads
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim)
+
+    def forward(self, x):
+        h = self.heads
+
+        x = self.norm(x)
+
+        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+
+        q = q * self.scale
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+def FeedForward(dim, mult = 4):
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, dim * mult, bias = False),
+        nn.GELU(),
+        nn.Linear(dim * mult, dim, bias = False)
+    )
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        layers,
+        dim_head = 32,
+        heads = 8,
+        ff_mult = 4
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(layers):
+            self.layers.append(nn.ModuleList([
+                Attention(dim = dim, dim_head = dim_head, heads = heads),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+
+        return self.norm(x)
+
+class ViTEncDec(nn.Module):
+    def __init__(
+        self,
+        dim,
+        channels = 3,
+        layers = 4,
+        patch_size = 8,
+        dim_head = 32,
+        heads = 8,
+        ff_mult = 4
+    ):
+        super().__init__()
+        self.encoded_dim = dim
+        self.patch_size = patch_size
+
+        input_dim = channels * (patch_size ** 2)
+
+        self.encoder = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.Linear(input_dim, dim),
+            Transformer(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                ff_mult = ff_mult,
+                layers = layers
+            ),
+            RearrangeImage(),
+            Rearrange('b h w c -> b c h w')
+        )
+
+        self.decoder = nn.Sequential(
+            Rearrange('b c h w -> b (h w) c'),
+            Transformer(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                ff_mult = ff_mult,
+                layers = layers
+            ),
+            nn.Sequential(
+                nn.Linear(dim, dim * 4, bias = False),
+                nn.Tanh(),
+                nn.Linear(dim * 4, input_dim, bias = False),
+            ),
+            RearrangeImage(),
+            Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
+        )
+
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // self.patch_size
+
+    def encode(self, x):
+        return self.encoder(x)
+
+    def decode(self, x):
+        return self.decoder(x)
+
+# main vqgan-vae classes
+
 class NullVQGanVAE(nn.Module):
     def __init__(
         self,
@@ -320,81 +544,43 @@ class VQGanVAE(nn.Module):
         image_size,
         channels = 3,
         layers = 4,
-        layer_mults = None,
         l2_recon_loss = False,
         use_hinge_loss = True,
-        num_resnet_blocks = 1,
         vgg = None,
         vq_codebook_size = 512,
        vq_decay = 0.8,
         vq_commitment_weight = 1.,
         vq_kmeans_init = True,
         vq_use_cosine_sim = True,
-        use_attn = True,
-        attn_dim_head = 64,
-        attn_heads = 8,
-        resnet_groups = 16,
-        attn_dropout = 0.,
-        first_conv_kernel_size = 5,
         use_vgg_and_gan = True,
+        vae_type = 'resnet',
+        discr_layers = 4,
         **kwargs
     ):
         super().__init__()
-        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
-
         vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
+        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
 
         self.image_size = image_size
         self.channels = channels
-        self.layers = layers
-        self.fmap_size = image_size // (layers ** 2)
         self.codebook_size = vq_codebook_size
 
-        self.encoders = MList([])
-        self.decoders = MList([])
+        if vae_type == 'resnet':
+            enc_dec_klass = ResnetEncDec
+        elif vae_type == 'vit':
+            enc_dec_klass = ViTEncDec
+        else:
+            raise ValueError(f'{vae_type} not valid')
 
-        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
-        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
-
-        layer_dims = [dim * mult for mult in layer_mults]
-        dims = (dim, *layer_dims)
-        codebook_dim = layer_dims[-1]
-
-        self.encoded_dim = dims[-1]
-
-        dim_pairs = zip(dims[:-1], dims[1:])
-
-        append = lambda arr, t: arr.append(t)
-        prepend = lambda arr, t: arr.insert(0, t)
-
-        if not isinstance(num_resnet_blocks, tuple):
-            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
-
-        if not isinstance(use_attn, tuple):
-            use_attn = (*((False,) * (layers - 1)), use_attn)
-
-        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
-        assert len(use_attn) == layers
-
-        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
-            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
-            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
-
-            if layer_use_attn:
-                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
-
-            for _ in range(layer_num_resnet_blocks):
-                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
-                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
-
-            if layer_use_attn:
-                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
-
-        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
-        append(self.decoders, nn.Conv2d(dim, channels, 1))
+        self.enc_dec = enc_dec_klass(
+            dim = dim,
+            channels = channels,
+            layers = layers,
+            **encdec_kwargs
+        )
 
         self.vq = VQ(
-            dim = codebook_dim,
+            dim = self.enc_dec.encoded_dim,
             codebook_size = vq_codebook_size,
             decay = vq_decay,
             commitment_weight = vq_commitment_weight,
@@ -427,13 +613,21 @@ class VQGanVAE(nn.Module):
 
         # gan related losses
 
+        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
+        layer_dims = [dim * mult for mult in layer_mults]
+        dims = (dim, *layer_dims)
+
         self.discr = Discriminator(dims = dims, channels = channels)
 
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
 
+    @property
+    def encoded_dim(self):
+        return self.enc_dec.encoded_dim
+
     def get_encoded_fmap_size(self, image_size):
-        return image_size // (2 ** self.layers)
+        return self.enc_dec.get_encoded_fmap_size(image_size)
 
     def copy_for_eval(self):
         device = next(self.parameters()).device
@@ -459,16 +653,13 @@ class VQGanVAE(nn.Module):
         return self.vq.codebook
 
     def encode(self, fmap):
-        for enc in self.encoders:
-            fmap = enc(fmap)
-
+        fmap = self.enc_dec.encode(fmap)
         return fmap
 
     def decode(self, fmap, return_indices_and_loss = False):
         fmap, indices, commit_loss = self.vq(fmap)
 
-        for dec in self.decoders:
-            fmap = dec(fmap)
+        fmap = self.enc_dec.decode(fmap)
 
         if not return_indices_and_loss:
             return fmap
diff --git a/setup.py b/setup.py
index 36a7e2b..316375e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.52',
+  version = '0.0.54',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
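
Below is a minimal usage sketch of the `vae_type` switch introduced by this patch. The import path, class name, and the `encdec_` keyword routing (via `groupby_prefix_and_trim`) come from the diff above; the specific dimensions, patch size, and expected shapes are illustrative assumptions rather than values prescribed by the patch.

```python
import torch
from dalle2_pytorch.vqgan_vae import VQGanVAE

# hypothetical hyperparameters - only the vae_type switch and 'encdec_' prefix routing come from this patch
vae = VQGanVAE(
    dim = 256,
    image_size = 256,
    channels = 3,
    layers = 4,
    vae_type = 'vit',         # selects ViTEncDec; 'resnet' keeps the original ResnetEncDec path
    vq_codebook_size = 512,
    use_vgg_and_gan = False,  # skip the perceptual / adversarial losses for a quick shape check
    encdec_patch_size = 8     # stripped of its 'encdec_' prefix and forwarded to ViTEncDec as patch_size
)

images = torch.randn(1, 3, 256, 256)

fmap = vae.encode(images)    # ViT path: (1, dim, image_size // patch_size, image_size // patch_size)
recon = vae.decode(fmap)     # vector-quantizes the feature map, then decodes back to (1, 3, 256, 256)

print(vae.get_encoded_fmap_size(256))  # 32 for the ViT variant, i.e. image_size // patch_size
print(vae.encoded_dim)                 # equals dim, via the new encoded_dim property
```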