From ee75515c7d241c4bb90e81964e865a3577028bb5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 5 Jul 2022 16:53:58 -0700
Subject: [PATCH] remove forcing of softmax in f32, in case it is interfering
 with deepspeed

---
 dalle2_pytorch/dalle2_pytorch.py | 4 ++--
 dalle2_pytorch/version.py        | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 6a45281..b3a023b 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -704,7 +704,7 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
 
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
         # aggregate values
@@ -1272,7 +1272,7 @@ class CrossAttention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
 
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index f1bc033..05c037a 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.16.6'
+__version__ = '0.16.7'