Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
remove forcing of softmax in f32, in case it is interfering with deepspeed
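The diff below drops the explicit dtype argument from the softmax calls, so the attention weights stay in the ambient precision instead of always being computed in float32. A minimal sketch of the effect (not the repo's code; the tensor shape and the DeepSpeed framing are assumptions for illustration):

    import torch

    # Hypothetical half-precision attention scores, as they would be under a
    # mixed-precision engine such as DeepSpeed.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sim = torch.randn(2, 8, 16, 16, device=device, dtype=torch.float16)

    # Old behaviour: softmax is computed and returned in float32 regardless
    # of the input dtype.
    attn_forced = sim.softmax(dim=-1, dtype=torch.float32)  # torch.float32

    # New behaviour: softmax stays in the input dtype, leaving precision
    # handling to the training engine.
    attn = sim.softmax(dim=-1)                               # torch.float16

    print(attn_forced.dtype, attn.dtype)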
@@ -704,7 +704,7 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
 
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
         # aggregate values
@@ -1272,7 +1272,7 @@ class CrossAttention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
 
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1 +1 @@
-__version__ = '0.16.6'
+__version__ = '0.16.7'