diff --git a/README.md b/README.md
index f461d80..3637dd8 100644
--- a/README.md
+++ b/README.md
@@ -697,16 +697,6 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
-```bibtex
-@article{Arar2021LearnedQF,
-    title   = {Learned Queries for Efficient Local Attention},
-    author  = {Moab Arar and Ariel Shamir and Amit H. Bermano},
-    journal = {ArXiv},
-    year    = {2021},
-    volume  = {abs/2112.11435}
-}
-```
-
 ```bibtex
 @article{Yu2021VectorquantizedIM,
     title   = {Vector-quantized Image Modeling with Improved VQGAN},
diff --git a/dalle2_pytorch/attention.py b/dalle2_pytorch/attention.py
deleted file mode 100644
index 7b6145d..0000000
--- a/dalle2_pytorch/attention.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-
-from einops import rearrange, repeat
-
-class LayerNormChan(nn.Module):
-    def __init__(
-        self,
-        dim,
-        eps = 1e-5
-    ):
-        super().__init__()
-        self.eps = eps
-        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.gamma
-
-# attention-based upsampling
-# from https://arxiv.org/abs/2112.11435
-
-class QueryAndAttend(nn.Module):
-    def __init__(
-        self,
-        *,
-        dim,
-        num_queries = 1,
-        dim_head = 32,
-        heads = 8,
-        window_size = 3
-    ):
-        super().__init__()
-        self.scale = dim_head ** -0.5
-        inner_dim = dim_head * heads
-        self.heads = heads
-        self.dim_head = dim_head
-        self.window_size = window_size
-        self.num_queries = num_queries
-
-        self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
-
-        self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
-        self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
-
-        self.to_out = nn.Sequential(
-            nn.Conv2d(inner_dim, dim * 2, 1, bias = False),
-            nn.Tanh(),
-            nn.Conv2d(dim * 2, dim, 1, bias = False)
-        )
-
-    def forward(self, x):
-        """
-        einstein notation
-        b - batch
-        h - heads
-        l - num queries
-        d - head dimension
-        x - height
-        y - width
-        j - source sequence for attending to (kernel size squared in this case)
-        """
-
-        wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
-        batch, _, height, width = x.shape
-
-        is_one_query = self.num_queries == 1
-
-        # queries, keys, values
-
-        q = self.queries * self.scale
-        k, v = self.to_kv(x).chunk(2, dim = 1)
-
-        # similarities
-
-        sim = einsum('h l d, b d x y -> b h l x y', q, k)
-        sim = rearrange(sim, 'b ... x y -> b (...) x y')
-
-        # unfold the similarity scores, with float(-inf) as padding value
-
-        mask_value = -torch.finfo(sim.dtype).max
-        sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
-        sim = F.unfold(sim, kernel_size = wsz)
-        sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
-
-        # rel pos bias
-
-        sim = sim + self.rel_pos_bias
-
-        # numerically stable attention
-
-        sim = sim - sim.amax(dim = -3, keepdim = True).detach()
-        attn = sim.softmax(dim = -3)
-
-        # unfold values
-
-        v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
-        v = F.unfold(v, kernel_size = wsz)
-        v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
-
-        # aggregate values
-
-        out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
-
-        # combine heads
-
-        out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
-        out = self.to_out(out)
-        out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
-
-        # return original input if one query
-
-        if is_one_query:
-            out = rearrange(out, 'b 1 ... -> b ...')
-
-        return out
-
-class QueryAttnUpsample(nn.Module):
-    def __init__(self, dim, **kwargs):
-        super().__init__()
-        self.norm = LayerNormChan(dim)
-        self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
-
-    def forward(self, x):
-        x = self.norm(x)
-        out = self.qna(x)
-        out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
-        return out
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 42e3f88..3d87c82 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -17,7 +17,6 @@ from kornia.filters import gaussian_blur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
 from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
-from dalle2_pytorch.attention import QueryAttnUpsample
 
 # use x-clip
 
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index dceff14..7e59805 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -15,8 +15,6 @@ from einops import rearrange, reduce, repeat
 from einops_exts import rearrange_many
 from einops.layers.torch import Rearrange
 
-from dalle2_pytorch.attention import QueryAttnUpsample
-
 # constants
 
 MList = nn.ModuleList
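
For context, here is a minimal sketch (not part of the patch) of how the deleted `QueryAttnUpsample` behaved as a 2x attention-based upsampler; it only runs against a checkout from before this patch, and the tensor shapes are illustrative assumptions.

```python
# Sketch only: exercises the QueryAttnUpsample module that this patch removes.
# Requires a pre-patch checkout where dalle2_pytorch/attention.py still exists.
import torch
from dalle2_pytorch.attention import QueryAttnUpsample  # deleted by this patch

upsample = QueryAttnUpsample(dim = 64)   # 4 learned queries -> 2x2 pixel shuffle

feats = torch.randn(1, 64, 16, 16)       # (batch, channels, height, width)
out = upsample(feats)                     # channel layernorm -> QnA attention -> rearrange

print(out.shape)                          # torch.Size([1, 64, 32, 32]), spatial dims doubled
```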