Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
bring in rotary embeddings for diffusion prior causal transformer (the most powerful relative positional encoding, used in PaLM) - 0.1.0 because of breaking change
dalle2_pytorch/dalle2_pytorch.py
@@ -23,6 +23,10 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 from resize_right import resize
 
+# rotary embeddings
+
+from rotary_embedding_torch import RotaryEmbedding
+
 # use x-clip
 
 from x_clip import CLIP
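For orientation, a minimal sketch of what the new import provides, assuming only the rotary-embedding-torch package added to setup.py below (tensor shapes are illustrative): RotaryEmbedding rotates the first dim feature dimensions of each position by an angle proportional to its index, so query/key dot products end up depending only on relative offsets, and it adds no learned parameters.

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32)                  # frequencies for the first 32 feature dims; the rest pass through unrotated

q = torch.randn(1, 8, 16, 64)                       # (batch, heads, seq, dim_head) -- illustrative shapes
k = torch.randn(1, 16, 64)                          # single-headed keys, matching the Attention class below

q, k = map(rotary.rotate_queries_or_keys, (q, k))   # rotation is applied along the sequence dimension (second to last)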
@@ -566,7 +570,8 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False
+        post_norm = False,
+        rotary_emb = None
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -582,6 +587,8 @@ class Attention(nn.Module):
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
 
+        self.rotary_emb = rotary_emb
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
             LayerNorm(dim) if post_norm else nn.Identity()
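The constructor only stores the module; the rotation itself happens in forward (next hunk). That means one RotaryEmbedding instance can be shared across every layer, which is what CausalTransformer does further down. A small illustrative sketch of that sharing, using a hypothetical ToyAttention stand-in rather than the real class:

import torch.nn as nn
from rotary_embedding_torch import RotaryEmbedding

class ToyAttention(nn.Module):
    def __init__(self, rotary_emb = None):
        super().__init__()
        self.rotary_emb = rotary_emb                 # registered as a submodule; no learned parameters by default

shared = RotaryEmbedding(dim = 32)                   # one frequency table
layers = [ToyAttention(rotary_emb = shared) for _ in range(4)]

assert all(layer.rotary_emb is shared for layer in layers)

With the default rotary_emb = None, the layer behaves exactly as before.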
@@ -594,6 +601,12 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
 
         q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
+        q = q * self.scale
+
+        # rotary embeddings
+
+        if exists(self.rotary_emb):
+            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
 
         # add null key / value for classifier free guidance in prior net
 
@@ -601,7 +614,7 @@ class Attention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        q = q * self.scale
+        # calculate query / key similarities
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
 
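Reading the two forward-pass hunks together: queries are now scaled before rotation, rotary is applied to both queries and keys, and the null key/value used for classifier-free guidance is concatenated afterwards so it is never rotated. A hedged, self-contained sketch of that ordering with made-up shapes (zeros stand in for the learned null key/value, and the real method goes on to mask and softmax the similarities):

import torch
from torch import einsum
from rotary_embedding_torch import RotaryEmbedding

b, h, n, d = 2, 8, 16, 64
rotary = RotaryEmbedding(dim = min(32, d))              # same sizing rule as CausalTransformer below

q = torch.randn(b, h, n, d) * (d ** -0.5)               # scale queries first, per the reordering above
k = torch.randn(b, n, d)                                # single-headed keys / values
v = torch.randn(b, n, d)

q, k = map(rotary.rotate_queries_or_keys, (q, k))       # rotary embeddings

nk = torch.zeros(b, 1, d)                               # stand-ins for the learned null key / value
nv = torch.zeros(b, 1, d)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)

sim = einsum('b h i d, b j d -> b h i j', q, k)         # query / key similarities, as in the new hunk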
@@ -651,15 +664,18 @@ class CausalTransformer(nn.Module):
         attn_dropout = 0.,
         ff_dropout = 0.,
         final_proj = True,
-        normformer = False
+        normformer = False,
+        rotary_emb = True
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
 
+        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
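At the transformer level the feature is a single flag: rotary_emb = True (the new default) builds one RotaryEmbedding(dim = min(32, dim_head)) and hands it to every Attention layer, on top of the existing relative position bias. A hedged usage sketch; the import path and the dim/depth values are assumptions, not taken from this diff, and only the keyword names shown in the hunk above are confirmed:

from dalle2_pytorch.dalle2_pytorch import CausalTransformer   # assumed import path

transformer = CausalTransformer(
    dim = 512,           # illustrative sizes
    depth = 6,
    dim_head = 64,
    heads = 8,
    rotary_emb = True    # new default; pass False to keep only the RelPosBias positional signal
)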
setup.py (3 changed lines)
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.109',
+  version = '0.1.1',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -31,6 +31,7 @@ setup(
     'kornia>=0.5.4',
     'pillow',
     'resize-right>=0.0.2',
+    'rotary-embedding-torch',
     'torch>=1.10',
     'torchvision',
     'tqdm',