From 07abfcf45bb00224af73f7b0dd36d35dd67d750a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 27 Jul 2022 12:27:32 -0700
Subject: [PATCH] rescale values in linear attention to mitigate overflows in
 fp16 setting

---
 dalle2_pytorch/dalle2_pytorch.py | 1 +
 dalle2_pytorch/version.py        | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 71c92e1..9c9827b 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1503,6 +1503,7 @@ class LinearAttention(nn.Module):
         k = k.softmax(dim = -2)
 
         q = q * self.scale
+        v = v / (x * y)
 
         context = einsum('b n d, b n e -> b d e', k, v)
         out = einsum('b n d, b d e -> b n e', q, context)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 1a72d32..58d478a 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.1.0'
+__version__ = '1.2.0'
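
Note on the change (annotation, not part of the patch): the einsum
context = einsum('b n d, b n e -> b d e', k, v) sums over the n = x * y
spatial positions, and in fp16, whose largest representable value is
about 65504, that accumulation can overflow on large feature maps.
Pre-dividing v by x * y keeps the sum bounded. The sketch below
reconstructs a plausible surrounding module so the new line can be read
in context; only the hunk's context lines and the added line come from
the patch, while the constructor, the qkv projection, and the output
norm are assumptions modeled on similar lucidrains attention blocks.

import torch
from torch import nn, einsum
from einops import rearrange

class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 32):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        # assumption: a channel norm after the output projection, as in
        # related lucidrains blocks; it absorbs any constant rescaling
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            nn.GroupNorm(1, dim)
        )

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale
        # the patched line: shrink values by the sequence length n = x * y
        # so the position-wise sum in the next einsum stays inside fp16 range
        v = v / (x * y)

        context = einsum('b n d, b n e -> b d e', k, v)   # sums over n = x * y
        out = einsum('b n d, b d e -> b n e', q, context)

        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
        return self.to_out(out)

Since k is softmax-normalized over the position axis, the rescale
uniformly shrinks context and out by 1 / (x * y); with a norm after the
output projection, that constant factor washes out, so the change
targets only the fp16-critical intermediate sums.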