From 3b520dfa857df3a7ed775ecf64b5c513668895b3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 25 Apr 2022 17:27:45 -0700
Subject: [PATCH] bring in attention-based upsampling to strengthen vqgan-vae,
 seems to work as advertised in initial GAN experiments

---
 README.md                   |  10 ++++
 dalle2_pytorch/vqgan_vae.py | 108 +++++++++++++++++++++++++++++++++++-
 setup.py                    |   2 +-
 3 files changed, 118 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 27dc5ab..f89cd03 100644
--- a/README.md
+++ b/README.md
@@ -577,4 +577,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Arar2021LearnedQF,
+    title   = {Learned Queries for Efficient Local Attention},
+    author  = {Moab Arar and Ariel Shamir and Amit H. Bermano},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2112.11435}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 8441237..3b69355 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -243,6 +243,112 @@ class ResBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x
 
+# attention-based upsampling
+# from https://arxiv.org/abs/2112.11435
+
+class QueryAndAttend(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        num_queries = 1,
+        dim_head = 32,
+        heads = 8,
+        window_size = 3
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.dim_head = dim_head
+        self.window_size = window_size
+        self.num_queries = num_queries
+
+        self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
+
+        self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
+        self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
+        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, x):
+        """
+        einstein notation
+        b - batch
+        h - heads
+        l - num queries
+        d - head dimension
+        x - height
+        y - width
+        j - source sequence being attended to (kernel size squared in this case)
+        """
+
+        wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
+        batch, _, height, width = x.shape
+
+        is_one_query = self.num_queries == 1
+
+        # queries, keys, values
+
+        q = self.queries * self.scale
+        k, v = self.to_kv(x).chunk(2, dim = 1)
+
+        # similarities
+
+        sim = einsum('h l d, b d x y -> b h l x y', q, k)
+        sim = rearrange(sim, 'b ... x y -> b (...) x y')
+
+        # unfold the similarity scores, with float(-inf) as padding value
+
+        mask_value = -torch.finfo(sim.dtype).max
+        sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
+        sim = F.unfold(sim, kernel_size = wsz)
+        sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
+
+        # rel pos bias
+
+        sim = sim + self.rel_pos_bias
+
+        # numerically stable attention
+
+        sim = sim - sim.amax(dim = -3, keepdim = True).detach()
+        attn = sim.softmax(dim = -3)
+
+        # unfold values
+
+        v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
+        v = F.unfold(v, kernel_size = wsz)
+        v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
+
+        # aggregate values
+
+        out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
+
+        # combine heads
+
+        out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
+        out = self.to_out(out)
+        out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
+
+        # squeeze out the query dimension if there is only a single query
+
+        if is_one_query:
+            out = rearrange(out, 'b 1 ... -> b ...')
+
+        return out
+
+class QueryAttnUpsample(nn.Module):
+    def __init__(self, dim, **kwargs):
+        super().__init__()
+        self.norm = LayerNormChan(dim)
+        self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
+
+    def forward(self, x):
+        x = self.norm(x)
+        out = self.qna(x)
+        out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
+        return out
+
+# vqgan attention layer
 class VQGanAttention(nn.Module):
     def __init__(
         self,
@@ -375,7 +481,7 @@ class VQGanVAE(nn.Module):
 
         for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
             append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
-            prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
+            prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
 
             if layer_use_attn:
                 prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
diff --git a/setup.py b/setup.py
index 520aecc..c49d40f 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.43',
+  version = '0.0.44',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
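As a sanity check on the new module's interface, here is a minimal usage sketch. It assumes this patch has been applied so that `QueryAttnUpsample` is importable from `dalle2_pytorch.vqgan_vae`; the tensor sizes are illustrative only:

```python
import torch
from dalle2_pytorch.vqgan_vae import QueryAttnUpsample

# a decoder feature map: (batch, channels, height, width)
feats = torch.randn(1, 64, 16, 16)

# attention-based upsampling: 4 learned queries attend within a local
# window at each position, and the 4 outputs are pixel-shuffled into a
# 2x2 block, doubling the spatial resolution
upsample = QueryAttnUpsample(64)

out = upsample(feats)
print(out.shape)  # torch.Size([1, 64, 32, 32])
```

Because `to_out` projects back to `dim`, the channel count is preserved, which is what lets the decoder swap `QueryAttnUpsample(dim_out)` in for `nn.Upsample` without changing the convolution that follows it.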