remove forcing of softmax in f32, in case it is interfering with deepspeed

Phil Wang
2022-07-05 16:53:58 -07:00
parent ec68243479
commit ee75515c7d
2 changed files with 3 additions and 3 deletions


@@ -704,7 +704,7 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
         # aggregate values
@@ -1272,7 +1272,7 @@ class CrossAttention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
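For context, below is a minimal, self-contained sketch of the attention pattern this diff touches, not the repository's actual Attention or CrossAttention modules; the function name stable_attend, the q.shape[-1] ** -0.5 scaling, and the default pb_relax_alpha value are illustrative assumptions. It shows that once the dtype = torch.float32 argument is dropped, the softmax simply inherits the dtype of sim, so under a half-precision engine such as DeepSpeed's the whole attention step stays in reduced precision.

# a minimal sketch of the change above -- not the repository's actual modules;
# `stable_attend` and the default `pb_relax_alpha` value are illustrative only
import torch
from torch import einsum

def stable_attend(q, k, v, pb_relax_alpha = 128.):
    # scale the logits down before the similarity so that multiplying back by
    # alpha after the max-subtraction keeps them in range (PB-relax style)
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * (q.shape[-1] ** -0.5) / pb_relax_alpha
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    sim = sim * pb_relax_alpha
    # after this commit the softmax is no longer forced to float32 and simply
    # inherits the dtype of `sim` (e.g. fp16/bf16 under a half-precision engine)
    attn = sim.softmax(dim = -1)
    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

# half-precision inputs now stay in half precision end to end
q = k = v = torch.randn(1, 8, 16, 64, dtype = torch.bfloat16)
print(stable_attend(q, k, v).dtype)  # torch.bfloat16

The trade-off is that the amax subtraction and the pb_relax_alpha rescaling become the only guard against overflow in reduced precision, rather than an explicit float32 upcast inside the softmax.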