Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
bring in rotary embeddings for diffusion prior causal transformer (the most powerful relative positional encoding, used in PaLM) - 0.1.0 because of breaking change
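The claim that rotary embeddings act as a relative positional encoding can be checked directly with the rotary-embedding-torch package this commit imports. The snippet below is not part of the commit, just an illustrative check: it places the same content vector at every position, so any variation in the attention logits comes from position alone, and after rotation the logits depend only on the offset between query and key positions.

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32)     # rotates the first 32 features of each head

# identical content at every one of 8 positions: (batch, heads, seq, dim_head)
q = torch.randn(64).repeat(1, 1, 8, 1)
k = torch.randn(64).repeat(1, 1, 8, 1)

rq = rotary.rotate_queries_or_keys(q)  # rotate by absolute position along dim -2
rk = rotary.rotate_queries_or_keys(k)

sim = torch.einsum('b h i d, b h j d -> b h i j', rq, rk)[0, 0]

# similarity depends only on the relative offset j - i
assert torch.allclose(sim[0, 1], sim[3, 4], atol = 1e-5)
assert torch.allclose(sim[2, 5], sim[4, 7], atol = 1e-5)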
@@ -23,6 +23,10 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 from resize_right import resize
 
+# rotary embeddings
+
+from rotary_embedding_torch import RotaryEmbedding
+
 # use x-clip
 
 from x_clip import CLIP
 
@@ -566,7 +570,8 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False
+        post_norm = False,
+        rotary_emb = None
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -582,6 +587,8 @@ class Attention(nn.Module):
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
 
+        self.rotary_emb = rotary_emb
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
             LayerNorm(dim) if post_norm else nn.Identity()
@@ -594,6 +601,12 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
 
         q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
+        q = q * self.scale
+
+        # rotary embeddings
+
+        if exists(self.rotary_emb):
+            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
 
         # add null key / value for classifier free guidance in prior net
 
@@ -601,7 +614,7 @@ class Attention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        q = q * self.scale
+        # calculate query / key similarities
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
 
@@ -651,15 +664,18 @@ class CausalTransformer(nn.Module):
         attn_dropout = 0.,
         ff_dropout = 0.,
         final_proj = True,
-        normformer = False
+        normformer = False,
+        rotary_emb = True
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
 
+        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
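Taken together, the hunks above build a single RotaryEmbedding once in CausalTransformer, with dim = min(32, dim_head) so only part of each head dimension is rotated (a common partial-rotary choice that leaves the remaining features position-agnostic), and hand it to every attention layer, where it rotates queries and keys just before the similarity logits. A minimal standalone sketch of that flow, using made-up tensor shapes rather than the repository's modules:

import torch
from torch import einsum
from rotary_embedding_torch import RotaryEmbedding

dim_head, heads = 64, 8

# one rotary module shared by all layers, as in CausalTransformer above
rotary_emb = RotaryEmbedding(dim = min(32, dim_head))

q = torch.randn(2, heads, 128, dim_head)  # (batch, heads, seq, dim_head)
k = torch.randn(2, heads, 128, dim_head)

q = q * dim_head ** -0.5                  # scale queries, as the Attention block does

# rotate queries and keys by absolute position; their dot products then
# encode only relative position
q, k = map(rotary_emb.rotate_queries_or_keys, (q, k))

sim = einsum('b h i d, b h j d -> b h i j', q, k)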