use T5 relative positional bias in the prior network's causal transformer, since it makes more sense than rotary embeddings

Phil Wang
2022-04-14 12:01:09 -07:00
parent 9f55c24db6
commit 6e27f617f1
2 changed files with 58 additions and 4 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -161,6 +161,44 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.net(x.float())
 
+# relative positional bias for causal transformer
+
+class RelPosBias(nn.Module):
+    def __init__(
+        self,
+        heads = 8,
+        num_buckets = 32,
+        max_distance = 128,
+    ):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position,
+        num_buckets = 32,
+        max_distance = 128
+    ):
+        n = -relative_position
+        n = torch.max(n, torch.zeros_like(n))
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
+        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+        return torch.where(is_small, n, val_if_large)
+
+    def forward(self, i, j, *, device):
+        q_pos = torch.arange(i, dtype = torch.long, device = device)
+        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return rearrange(values, 'i j h -> h i j')
+
 # feedforward
 
 class SwiGLU(nn.Module):
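For intuition, here is a small usage sketch (not part of the commit), assuming the RelPosBias class above is in scope along with the module's torch / einops imports: offsets up to num_buckets // 2 keep their own exact buckets, farther offsets share logarithmically spaced buckets up to max_distance, and future keys (which the causal mask hides anyway) collapse into bucket 0 via the torch.max clamp.

import torch

# sketch only: assumes RelPosBias (defined in the diff above) is in scope

rel_bias = RelPosBias(heads = 8, num_buckets = 32, max_distance = 128)

# per-head bias map for 6 queries over 6 keys, broadcastable over the batch dimension
bias = rel_bias(6, 6, device = torch.device('cpu'))
print(bias.shape)                                     # torch.Size([8, 6, 6])

# bucket assignment for the last query looking back at keys 0..5
rel_pos = torch.arange(6) - 5                         # [-5, -4, -3, -2, -1, 0]
print(RelPosBias._relative_position_bucket(rel_pos))  # tensor([5, 4, 3, 2, 1, 0]) - exact buckets for near offsets

# a far-away key falls into a shared, log-spaced bucket
print(RelPosBias._relative_position_bucket(torch.tensor([-100])))  # tensor([30])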
@@ -208,7 +246,7 @@ class Attention(nn.Module):
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device
 
         x = self.norm(x)
@@ -225,6 +263,14 @@ class Attention(nn.Module):
         q = q * self.scale
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
 
+        # relative positional encoding (T5 style)
+
+        if exists(attn_bias):
+            sim = sim + attn_bias
+
+        # masking
+
         max_neg_value = -torch.finfo(sim.dtype).max
+
         if exists(mask):
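As a quick shape check (a sketch, not from the commit): the bias returned by RelPosBias is (heads, i, j), so the addition above broadcasts it across the batch dimension of the attention logits.

import torch

b, h, i, j = 2, 8, 6, 6
sim = torch.randn(b, h, i, j)      # attention logits: (batch, heads, queries, keys)
attn_bias = torch.randn(h, i, j)   # per-head relative position bias, as produced by RelPosBias

sim = sim + attn_bias              # broadcasts over the batch dimension
print(sim.shape)                   # torch.Size([2, 8, 6, 6])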
@@ -237,10 +283,14 @@ class Attention(nn.Module):
             causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(causal_mask, max_neg_value)
 
+        # attention
+
         sim = sim - sim.amax(dim = -1, keepdim = True)
         attn = sim.softmax(dim = -1)
 
         attn = self.dropout(attn)
 
+        # aggregate values
+
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -260,7 +310,7 @@ class CausalTransformer(nn.Module):
ff_dropout = 0. ff_dropout = 0.
): ):
super().__init__() super().__init__()
# todo - bring in rotary embeddings or alibi self.rel_pos_bias = RelPosBias(heads = heads)
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
for _ in range(depth): for _ in range(depth):
@@ -276,8 +326,12 @@ class CausalTransformer(nn.Module):
         x,
         mask = None  # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
+        n, device = x.shape[1], x.device
+
+        attn_bias = self.rel_pos_bias(n, n + 1, device = device)
+
         for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         return self.norm(x)
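Note the bias is requested for n queries against n + 1 keys. Judging from the triu(j - i + 1) causal mask in Attention, the key/value sequence appears to carry one extra (e.g. null) entry, so the bias map is sized to match. A rough shape sketch under that assumption, with RelPosBias from above in scope:

import torch

# sketch only: assumes RelPosBias (defined in the diff above) is in scope
rel_pos_bias = RelPosBias(heads = 8)

n = 6                                                  # sequence length fed to the causal transformer
attn_bias = rel_pos_bias(n, n + 1, device = torch.device('cpu'))
print(attn_bias.shape)                                 # torch.Size([8, 6, 7]) - one extra key column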

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',