bring in rotary embeddings for diffusion prior causal transformer (the most powerful relative positional encoding, used in PaLM) - 0.1.0 because of breaking change

commit ad20a14a4d
parent 0be1e0d64c
Author: Phil Wang
Date:   2022-05-06 08:45:30 -07:00

2 changed files with 22 additions and 5 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -23,6 +23,10 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 from resize_right import resize
 
+# rotary embeddings
+
+from rotary_embedding_torch import RotaryEmbedding
+
 # use x-clip
 
 from x_clip import CLIP
@@ -566,7 +570,8 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False
+        post_norm = False,
+        rotary_emb = None
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -582,6 +587,8 @@ class Attention(nn.Module):
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
 
+        self.rotary_emb = rotary_emb
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
             LayerNorm(dim) if post_norm else nn.Identity()
@@ -594,6 +601,12 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
         q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
 
+        q = q * self.scale
+
+        # rotary embeddings
+
+        if exists(self.rotary_emb):
+            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
+
         # add null key / value for classifier free guidance in prior net
@@ -601,7 +614,7 @@ class Attention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        q = q * self.scale
-
         # calculate query / key similarities
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -651,15 +664,18 @@ class CausalTransformer(nn.Module):
         attn_dropout = 0.,
         ff_dropout = 0.,
         final_proj = True,
-        normformer = False
+        normformer = False,
+        rotary_emb = True
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
 
+        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
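
A minimal sketch of what the new rotary path does at runtime, using the same rotary-embedding-torch API the diff imports. Tensor shapes follow the Attention module above (multi-head queries, single-head keys shared across heads); the batch and sequence sizes are illustrative, not from the diff:

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    dim_head = 64
    rotary_emb = RotaryEmbedding(dim = min(32, dim_head))  # as constructed in CausalTransformer above

    q = torch.randn(1, 8, 256, dim_head)  # (batch, heads, seq, dim_head)
    k = torch.randn(1, 256, dim_head)     # (batch, seq, dim_head) - keys here are shared across heads

    # rotate_queries_or_keys applies position-dependent rotations along the
    # sequence dimension; only the first 32 feature dimensions are rotated
    # (partial rotary), and no learned parameters are introduced
    q, k = map(rotary_emb.rotate_queries_or_keys, (q, k))

Because the rotation is applied to q and k before the similarity einsum, the resulting attention scores depend on relative position, on top of the RelPosBias the transformer already uses.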

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.109',
+  version = '0.1.0',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -31,6 +31,7 @@ setup(
     'kornia>=0.5.4',
     'pillow',
     'resize-right>=0.0.2',
+    'rotary-embedding-torch',
     'torch>=1.10',
     'torchvision',
     'tqdm',
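
For context, a hypothetical usage sketch of the new flag: rotary_emb = True is now the default on CausalTransformer, so existing checkpoints trained without it are incompatible (hence the breaking-change version bump). The dim/depth values below are illustrative, and the direct module import is an assumption based on where the class is defined:

    from dalle2_pytorch.dalle2_pytorch import CausalTransformer

    transformer = CausalTransformer(
        dim = 512,         # illustrative model dimension
        depth = 4,         # illustrative number of layers
        rotary_emb = True  # new default; set False to rely on the existing RelPosBias alone
    )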