diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index f717d1e..c2fca66 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -161,6 +161,44 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.net(x.float())
 
+# relative positional bias for causal transformer
+
+class RelPosBias(nn.Module):
+    def __init__(
+        self,
+        heads = 8,
+        num_buckets = 32,
+        max_distance = 128,
+    ):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position,
+        num_buckets = 32,
+        max_distance = 128
+    ):
+        n = -relative_position
+        n = torch.max(n, torch.zeros_like(n))
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
+        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+        return torch.where(is_small, n, val_if_large)
+
+    def forward(self, i, j, *, device):
+        q_pos = torch.arange(i, dtype = torch.long, device = device)
+        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return rearrange(values, 'i j h -> h i j')
+
 # feedforward
 
 class SwiGLU(nn.Module):
@@ -208,7 +246,7 @@ class Attention(nn.Module):
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device
 
         x = self.norm(x)
@@ -225,6 +263,14 @@ class Attention(nn.Module):
         q = q * self.scale
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+        # relative positional encoding (T5 style)
+
+        if exists(attn_bias):
+            sim = sim + attn_bias
+
+        # masking
+
         max_neg_value = -torch.finfo(sim.dtype).max
 
         if exists(mask):
@@ -237,10 +283,14 @@ class Attention(nn.Module):
             causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(causal_mask, max_neg_value)
 
+        # attention
+
         sim = sim - sim.amax(dim = -1, keepdim = True)
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
+        # aggregate values
+
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
@@ -260,7 +310,7 @@ class CausalTransformer(nn.Module):
         ff_dropout = 0.
     ):
         super().__init__()
-        # todo - bring in rotary embeddings or alibi
+        self.rel_pos_bias = RelPosBias(heads = heads)
 
         self.layers = nn.ModuleList([])
         for _ in range(depth):
@@ -276,8 +326,12 @@ class CausalTransformer(nn.Module):
         x,
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
+        n, device = x.shape[1], x.device
+
+        attn_bias = self.rel_pos_bias(n, n + 1, device = device)
+
         for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         return self.norm(x)
diff --git a/setup.py b/setup.py
index 9abd021..7373c19 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
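
A minimal usage sketch of the RelPosBias module introduced above, not part of the commit: the head count, sequence length, and batch size are illustrative, and the n + 1 key length simply mirrors the rel_pos_bias(n, n + 1, ...) call in CausalTransformer.forward.

# illustrative sketch only -- mirrors how the diff wires RelPosBias into attention
import torch
from dalle2_pytorch.dalle2_pytorch import RelPosBias

heads, n = 8, 16                          # assumed head count and query length
rel_pos_bias = RelPosBias(heads = heads)

# bias for n queries attending over n + 1 keys, shaped (heads, i, j)
attn_bias = rel_pos_bias(n, n + 1, device = torch.device('cpu'))
assert attn_bias.shape == (heads, n, n + 1)

# the bias is added to the pre-softmax logits and broadcasts over the batch,
# as in Attention.forward when attn_bias is passed in
sim = torch.randn(2, heads, n, n + 1)     # (batch, heads, i, j) logits
sim = sim + attn_bias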