fix a bug with numerical stability in attention, sorry! 🐛

commit db805e73e1
parent cb07b37970
Author: Phil Wang
Date:   2022-05-09 16:23:12 -07:00

2 changed files with 3 additions and 3 deletions


@@ -677,7 +677,7 @@ class Attention(nn.Module):
         # attention
-        sim = sim - sim.amax(dim = -1, keepdim = True)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)

@@ -1204,7 +1204,7 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
-        sim = sim - sim.amax(dim = -1, keepdim = True)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)
         out = einsum('b h i j, b h j d -> b h i d', attn, v)


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',