mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
rescale values in linear attention to mitigate overflows in fp16 setting
@@ -1503,6 +1503,7 @@ class LinearAttention(nn.Module):
         k = k.softmax(dim = -2)

         q = q * self.scale
+        v = v / (x * y)

         context = einsum('b n d, b n e -> b d e', k, v)
         out = einsum('b n d, b d e -> b n e', q, context)
@@ -1 +1 @@
-__version__ = '1.1.0'
+__version__ = '1.2.0'
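For context, the first hunk sits inside the forward pass of a linear attention block. fp16 has a maximum representable value of 65504, so accumulating k^T v over thousands of spatial tokens can overflow; pre-dividing v by the token count x * y keeps that accumulation in range. Below is a minimal, self-contained sketch of the pattern with the new line in place. Only the lines shown in the diff are taken from the commit; the single-head layout, the q softmax, and the to_qkv/to_out projections are simplifying assumptions, not the repository's exact code.

import torch
from torch import nn, einsum

class LinearAttentionSketch(nn.Module):
    # simplified single-head stand-in for the patched LinearAttention
    def __init__(self, dim, dim_head = 64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Conv2d(dim, dim_head * 3, 1, bias = False)
        self.to_out = nn.Conv2d(dim_head, dim, 1, bias = False)

    def forward(self, fmap):
        b, c, x, y = fmap.shape
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
        # flatten the spatial grid into a sequence of n = x * y tokens
        q, k, v = map(lambda t: t.flatten(2).transpose(1, 2), (q, k, v))  # each b n d

        q = q.softmax(dim = -1)   # assumed, by analogy with typical linear attention
        k = k.softmax(dim = -2)

        q = q * self.scale
        # the commit's addition: shrink v so the k^T v sum over all
        # tokens stays well below the fp16 ceiling of 65504
        v = v / (x * y)

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)

        out = out.transpose(1, 2).reshape(b, -1, x, y)
        return self.to_out(out)

# usage sketch (fp32 here for portability; the rescale matters under .half())
attn = LinearAttentionSketch(dim = 32)
out = attn(torch.randn(1, 32, 64, 64))  # b c x y -> b c x y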