Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
rescale values in linear attention to mitigate overflows in fp16 setting
@@ -1503,6 +1503,7 @@ class LinearAttention(nn.Module):
         k = k.softmax(dim = -2)
 
         q = q * self.scale
+        v = v / (x * y)
 
         context = einsum('b n d, b n e -> b d e', k, v)
         out = einsum('b n d, b d e -> b n e', q, context)
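For context, below is a minimal sketch of how a linear-attention forward pass of this shape fits together and where the newly added rescaling line lands. This is an assumed illustration, not the repository's exact module: the class name, the qkv/out projections, and the use of the feature-map height and width (x, y) as the divisor are inferred from the diff above.

# Minimal sketch (assumed, not the repository's exact module) of linear attention
# with the fp16 rescaling from the commit above.
import torch
from torch import nn, einsum
from einops import rearrange

class LinearAttentionSketch(nn.Module):
    def __init__(self, dim, dim_head = 32, heads = 8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale
        # the added line: shrink the values by the number of spatial positions
        # so the k^T v accumulation stays within fp16 range
        v = v / (x * y)

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)

        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
        return self.to_out(out)

# usage sketch
# attn = LinearAttentionSketch(dim = 64)
# fmap = torch.randn(1, 64, 32, 32)
# out = attn(fmap)   # same shape as fmap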
@@ -1 +1 @@
-__version__ = '1.1.0'
+__version__ = '1.2.0'