diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 8999f28..df3c8fb 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -23,6 +23,10 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE from resize_right import resize +# rotary embeddings + +from rotary_embedding_torch import RotaryEmbedding + # use x-clip from x_clip import CLIP @@ -566,7 +570,8 @@ class Attention(nn.Module): heads = 8, dropout = 0., causal = False, - post_norm = False + post_norm = False, + rotary_emb = None ): super().__init__() self.scale = dim_head ** -0.5 @@ -582,6 +587,8 @@ class Attention(nn.Module): self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) + self.rotary_emb = rotary_emb + self.to_out = nn.Sequential( nn.Linear(inner_dim, dim, bias = False), LayerNorm(dim) if post_norm else nn.Identity() @@ -594,6 +601,12 @@ class Attention(nn.Module): q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + q = q * self.scale + + # rotary embeddings + + if exists(self.rotary_emb): + q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k)) # add null key / value for classifier free guidance in prior net @@ -601,7 +614,7 @@ class Attention(nn.Module): k = torch.cat((nk, k), dim = -2) v = torch.cat((nv, v), dim = -2) - q = q * self.scale + # calculate query / key similarities sim = einsum('b h i d, b j d -> b h i j', q, k) @@ -651,15 +664,18 @@ class CausalTransformer(nn.Module): attn_dropout = 0., ff_dropout = 0., final_proj = True, - normformer = False + normformer = False, + rotary_emb = True ): super().__init__() self.rel_pos_bias = RelPosBias(heads = heads) + rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None + self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ - Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer), + Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb), FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer) ])) diff --git a/setup.py b/setup.py index cfe6fcc..d20eb44 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.109', + version = '0.1.0', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',