diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 6a45281..b3a023b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -704,7 +704,7 @@ class Attention(nn.Module): sim = sim - sim.amax(dim = -1, keepdim = True).detach() sim = sim * self.pb_relax_alpha - attn = sim.softmax(dim = -1, dtype = torch.float32) + attn = sim.softmax(dim = -1) attn = self.dropout(attn) # aggregate values @@ -1272,7 +1272,7 @@ class CrossAttention(nn.Module): sim = sim - sim.amax(dim = -1, keepdim = True).detach() sim = sim * self.pb_relax_alpha - attn = sim.softmax(dim = -1, dtype = torch.float32) + attn = sim.softmax(dim = -1) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index f1bc033..05c037a 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.6' +__version__ = '0.16.7'