rescale values in linear attention to mitigate overflows in fp16 setting

Phil Wang
2022-07-27 12:27:32 -07:00
parent 2e35a9967d
commit 07abfcf45b
2 changed files with 2 additions and 1 deletion


@@ -1503,6 +1503,7 @@ class LinearAttention(nn.Module):
         k = k.softmax(dim = -2)
         q = q * self.scale
+        v = v / (x * y)
         context = einsum('b n d, b n e -> b d e', k, v)
         out = einsum('b n d, b d e -> b n e', q, context)
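
For context, a minimal, self-contained PyTorch sketch of the pattern this commit patches is given below: a single-head linear-attention forward pass with the new value rescale in place. Only the lines around v = v / (x * y) mirror the diff; the class skeleton, the 1x1-conv projections, and the q softmax are assumptions based on the usual shape of such modules, not taken from this commit.

from torch import nn, einsum

class LinearAttention(nn.Module):
    # minimal single-head sketch; the real module's heads, norms and
    # residual wiring are omitted (assumptions, not from the diff)
    def __init__(self, dim, dim_head = 64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Conv2d(dim, dim_head * 3, 1, bias = False)
        self.to_out = nn.Conv2d(dim_head, dim, 1)

    def forward(self, fmap):
        b, c, x, y = fmap.shape
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)

        # flatten spatial dims: (b, d, x, y) -> (b, n, d) with n = x * y
        q, k, v = map(lambda t: t.flatten(2).transpose(1, 2), (q, k, v))

        q = q.softmax(dim = -1)  # assumed, per the usual linear-attention recipe
        k = k.softmax(dim = -2)
        q = q * self.scale

        # the commit's fix: shrink v by the number of positions so the
        # products and running sums inside the next einsum stay far
        # below fp16's maximum representable value of ~65504
        v = v / (x * y)

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)

        out = out.transpose(1, 2).reshape(b, -1, x, y)
        return self.to_out(out)

Because k.softmax(dim = -2) makes each column of k sum to 1 over the n positions, context is a weighted average of the rows of v, so dividing v by x * y only scales the output by a constant; a trailing normalization layer in the full model would presumably absorb that change in magnitude while the smaller intermediates avoid overflowing during half-precision training.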