Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 01:34:19 +01:00
cast the attention matrix back to its pre-softmax dtype after the float32 softmax in attention
@@ -879,6 +879,8 @@ class Attention(nn.Module):
         # attention
 
         attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.type(sim.dtype)
+
         attn = self.dropout(attn)
 
         # aggregate values
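For context, the line added above completes a common mixed-precision pattern: the softmax runs in float32 even when the attention logits are half precision, and the result is then cast back so the rest of the block keeps its original dtype. A minimal sketch of what the two lines guarantee (not from the repo; the tensor shapes are illustrative assumptions):

import torch

# hypothetical fp16 attention logits: (batch, heads, query len, key len)
sim = torch.randn(1, 8, 16, 16, dtype = torch.float16)

# passing dtype = torch.float32 upcasts the input before the softmax,
# so the reduction runs in float32 and avoids fp16 overflow/underflow
attn = sim.softmax(dim = -1, dtype = torch.float32)
assert attn.dtype == torch.float32

# the cast this commit adds: return to the pre-softmax dtype so that
# downstream ops (dropout, value aggregation) stay in half precision
attn = attn.type(sim.dtype)
assert attn.dtype == torch.float16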
@@ -1637,6 +1639,7 @@ class CrossAttention(nn.Module):
             sim = sim.masked_fill(~mask, max_neg_value)
 
         attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.type(sim.dtype)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
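This second hunk sits on the masked path of CrossAttention: masked keys are filled with the dtype's most negative finite value so they receive near-zero weight after the softmax, and the new cast restores the original dtype before the values are aggregated. A hedged end-to-end sketch of that path, using bfloat16 so the einsum also runs on CPU; the mask, shapes, and the inline max_neg_value expression are assumptions for illustration, not repo code:

import torch
from torch import einsum
from einops import rearrange

b, h, i, j, d = 1, 8, 4, 6, 32                     # batch, heads, queries, keys, head dim
sim  = torch.randn(b, h, i, j, dtype = torch.bfloat16)
v    = torch.randn(b, h, j, d, dtype = torch.bfloat16)
mask = torch.rand(b, h, i, j) > 0.1                # True = may attend (illustrative)

max_neg_value = -torch.finfo(sim.dtype).max        # most negative finite value for the dtype
sim = sim.masked_fill(~mask, max_neg_value)        # masked keys get ~zero weight post-softmax

attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)                        # the cast this commit adds

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')       # merge heads: (b, n, h * d)
assert out.shape == (b, i, h * d) and out.dtype == sim.dtype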
@@ -1 +1 @@
-__version__ = '1.8.3'
+__version__ = '1.8.4'