cast attention matrix back to its original (pre-softmax) dtype in attention

Phil Wang
2022-08-20 10:56:01 -07:00
parent 7762edd0ff
commit 083508ff8e
2 changed files with 4 additions and 1 deletion


@@ -879,6 +879,8 @@ class Attention(nn.Module):
         # attention

         attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.type(sim.dtype)
+
         attn = self.dropout(attn)

         # aggregate values
@@ -1637,6 +1639,7 @@ class CrossAttention(nn.Module):
         sim = sim.masked_fill(~mask, max_neg_value)

         attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.type(sim.dtype)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
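
The technique is easy to reproduce outside the diff. Below is a minimal standalone sketch, not the library's actual Attention or CrossAttention modules: attention logits are computed in a low precision dtype, the softmax is run with float32 accumulation, and the attention matrix is cast back to the logits' original dtype before aggregating the values. The tensor shapes and the dtype/device selection are illustrative assumptions.

import torch
from torch import einsum

# hypothetical standalone example, not the library's Attention / CrossAttention classes
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# (batch, heads, seq, dim_head) -- illustrative shapes
q = torch.randn(1, 8, 128, 64, dtype = dtype, device = device)
k = torch.randn(1, 8, 128, 64, dtype = dtype, device = device)
v = torch.randn(1, 8, 128, 64, dtype = dtype, device = device)

# similarity logits stay in the low precision dtype
sim = einsum('b h i d, b h j d -> b h i j', q, k) * (64 ** -0.5)

# softmax is computed with float32 accumulation for numerical stability,
# then the attention matrix is cast back to its original (pre-softmax) dtype
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)

# aggregate values in the original low precision dtype
out = einsum('b h i j, b h j d -> b h i d', attn, v)
assert out.dtype == sim.dtype

Doing only the softmax in float32 keeps the large attention matrix and the value aggregation in half precision, while avoiding the overflow and precision loss a half precision softmax can suffer over long sequences.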


@@ -1 +1 @@
-__version__ = '1.8.3'
+__version__ = '1.8.4'