Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-14 11:54:22 +01:00
Compare commits (1 commit)
| Author | SHA1 | Date |
|---|---|---|
| | 3bdf85a5e9 | |
@@ -629,10 +629,13 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        rotary_emb = None
+        rotary_emb = None,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -696,6 +699,9 @@ class Attention(nn.Module):
 
         # attention
 
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
         attn = sim.softmax(dim = -1, dtype = torch.float32)
         attn = self.dropout(attn)
 
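For context, a minimal sketch of the rescaling this hunk applies (the function name `toy_pb_relax_attention` and the toy shapes are illustrative, not part of the repository): logits are computed at a scale divided by `pb_relax_alpha`, the detached per-row max is subtracted, and the result is multiplied back by `pb_relax_alpha` before the softmax, so intermediate values stay small enough for half precision.

```python
import torch

def toy_pb_relax_attention(q, k, v, dim_head = 64, pb_relax_alpha = 32 ** 2):
    # illustrative re-implementation of the rescaling shown in this diff,
    # not the repository's actual Attention.forward
    scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)

    # logits are computed at 1/alpha of the usual scale, so large q.k
    # dot products stay representable in fp16
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # subtract the detached per-row max, then undo the 1/alpha scaling
    # before softmax; softmax is shift-invariant, so the weights match
    # standard scaled dot-product attention
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    sim = sim * pb_relax_alpha

    attn = sim.softmax(dim = -1, dtype = torch.float32)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

# toy shapes: batch 1, 8 heads, 16 tokens, 64-dim heads
q = k = v = torch.randn(1, 8, 16, 64)
out = toy_pb_relax_attention(q, k, v)
```

Because softmax is invariant to subtracting a per-row constant, this yields the same attention weights as plain `dim_head ** -0.5` scaling, just with better-conditioned pre-softmax values.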
@@ -1210,10 +1216,12 @@ class CrossAttention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        norm_context = False
+        norm_context = False,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -1259,6 +1267,9 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
         attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
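A quick sanity check on why the rescaling in both attention classes is exact (assuming the usual logits s = q·kᵀ / sqrt(dim_head), which this hunk does not show, and writing alpha for `pb_relax_alpha`):

    softmax(alpha * (s/alpha - max_j(s/alpha))) = softmax(s - max_j(s)) = softmax(s)

so the attention weights match standard scaled dot-product attention; the `.detach()` only stops gradients from flowing through the max term.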
@@ -1 +1 @@
-__version__ = '0.16.0'
+__version__ = '0.16.1'