From 083508ff8e8861f1b1670e807bda25b538ea032c Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 20 Aug 2022 10:56:01 -0700
Subject: [PATCH] cast attention matrix back to original dtype pre-softmax in attention

---
 dalle2_pytorch/dalle2_pytorch.py | 3 +++
 dalle2_pytorch/version.py        | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 7920e69..2b11928 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -879,6 +879,8 @@ class Attention(nn.Module):
         # attention
 
         attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.type(sim.dtype)
+
         attn = self.dropout(attn)
 
         # aggregate values
@@ -1637,6 +1639,7 @@ class CrossAttention(nn.Module):
             sim = sim.masked_fill(~mask, max_neg_value)
 
         attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.type(sim.dtype)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index cfe6447..fa2822c 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.3'
+__version__ = '1.8.4'
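
For context, a minimal standalone sketch (not part of the patch) of the dtype behavior the change addresses: Tensor.softmax(dim, dtype=torch.float32) computes the softmax in float32 and returns a float32 tensor, so under half precision the attention matrix silently upcasts; the added attn.type(sim.dtype) casts it back so the value aggregation stays in the original dtype. The tensor shape below is illustrative only.

    import torch

    # attention logits in half precision, as under mixed-precision training
    sim = torch.randn(1, 8, 4, 4, dtype = torch.float16)

    # softmax computed in float32 for numerical stability returns a float32 tensor
    attn = sim.softmax(dim = -1, dtype = torch.float32)
    assert attn.dtype == torch.float32

    # cast back to the original dtype, as the patch does
    attn = attn.type(sim.dtype)
    assert attn.dtype == torch.float16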