mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
use t5 relative positional bias in prior network causal transformer, since it makes more sense than rotary embeddings
This commit is contained in:
@@ -161,6 +161,44 @@ class MLP(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x.float())
|
return self.net(x.float())
|
||||||
|
|
||||||
|
# relative positional bias for causal transformer
|
||||||
|
|
||||||
|
class RelPosBias(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
heads = 8,
|
||||||
|
num_buckets = 32,
|
||||||
|
max_distance = 128,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_buckets = num_buckets
|
||||||
|
self.max_distance = max_distance
|
||||||
|
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _relative_position_bucket(
|
||||||
|
relative_position,
|
||||||
|
num_buckets = 32,
|
||||||
|
max_distance = 128
|
||||||
|
):
|
||||||
|
n = -relative_position
|
||||||
|
n = torch.max(n, torch.zeros_like(n))
|
||||||
|
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = n < max_exact
|
||||||
|
|
||||||
|
val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
|
||||||
|
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
||||||
|
return torch.where(is_small, n, val_if_large)
|
||||||
|
|
||||||
|
def forward(self, i, j, *, device):
|
||||||
|
q_pos = torch.arange(i, dtype = torch.long, device = device)
|
||||||
|
k_pos = torch.arange(j, dtype = torch.long, device = device)
|
||||||
|
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
||||||
|
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
||||||
|
values = self.relative_attention_bias(rp_bucket)
|
||||||
|
return rearrange(values, 'i j h -> h i j')
|
||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
|
|
||||||
class SwiGLU(nn.Module):
|
class SwiGLU(nn.Module):
|
||||||
@@ -208,7 +246,7 @@ class Attention(nn.Module):
|
|||||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||||
|
|
||||||
def forward(self, x, mask = None):
|
def forward(self, x, mask = None, attn_bias = None):
|
||||||
b, n, device = *x.shape[:2], x.device
|
b, n, device = *x.shape[:2], x.device
|
||||||
|
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
@@ -225,6 +263,14 @@ class Attention(nn.Module):
|
|||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
|
|
||||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||||
|
|
||||||
|
# relative positional encoding (T5 style)
|
||||||
|
|
||||||
|
if exists(attn_bias):
|
||||||
|
sim = sim + attn_bias
|
||||||
|
|
||||||
|
# masking
|
||||||
|
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
@@ -237,10 +283,14 @@ class Attention(nn.Module):
|
|||||||
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
|
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
|
||||||
sim = sim.masked_fill(causal_mask, max_neg_value)
|
sim = sim.masked_fill(causal_mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||||
attn = sim.softmax(dim = -1)
|
attn = sim.softmax(dim = -1)
|
||||||
attn = self.dropout(attn)
|
attn = self.dropout(attn)
|
||||||
|
|
||||||
|
# aggregate values
|
||||||
|
|
||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
@@ -260,7 +310,7 @@ class CausalTransformer(nn.Module):
|
|||||||
ff_dropout = 0.
|
ff_dropout = 0.
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# todo - bring in rotary embeddings or alibi
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
@@ -276,8 +326,12 @@ class CausalTransformer(nn.Module):
|
|||||||
x,
|
x,
|
||||||
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
||||||
):
|
):
|
||||||
|
n, device = x.shape[1], x.device
|
||||||
|
|
||||||
|
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||||
|
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
x = attn(x, mask = mask) + x
|
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||||
x = ff(x) + x
|
x = ff(x) + x
|
||||||
|
|
||||||
return self.norm(x)
|
return self.norm(x)
|
||||||
|
|||||||
Reference in New Issue
Block a user