Compare commits


2 Commits
1.8.3 ... 1.8.4

Author     SHA1        Message                                                                 Date
Phil Wang  083508ff8e  cast attention matrix back to original dtype pre-softmax in attention  2022-08-20 10:56:01 -07:00
Phil Wang  7762edd0ff  make it work for @ethancohen123                                         2022-08-19 11:28:58 -07:00
2 changed files with 7 additions and 2 deletions

@@ -251,7 +251,9 @@ class XClipAdapter(BaseClipAdapter):
  text_mask = text != 0
  encoder_output = self.clip.text_transformer(text)
- text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output.ndim == 3 else (encoder_output, None)
+ encoder_output_is_cls = encoder_output.ndim == 3
+ text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
  text_embed = self.clip.to_text_latent(text_cls)
  if exists(text_encodings):
@@ -877,6 +879,8 @@ class Attention(nn.Module):
  # attention
  attn = sim.softmax(dim = -1, dtype = torch.float32)
+ attn = attn.type(sim.dtype)
  attn = self.dropout(attn)
  # aggregate values
@@ -1635,6 +1639,7 @@ class CrossAttention(nn.Module):
  sim = sim.masked_fill(~mask, max_neg_value)
  attn = sim.softmax(dim = -1, dtype = torch.float32)
+ attn = attn.type(sim.dtype)
  out = einsum('b h i j, b h j d -> b h i d', attn, v)
  out = rearrange(out, 'b h n d -> b n (h d)')
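
The two attention hunks above (commit 083508ff8e) add attn = attn.type(sim.dtype): the softmax is computed in float32 for numerical stability, then cast back to the similarity matrix's original dtype so the aggregation einsum sees a single dtype when the model runs in half precision. A minimal sketch of that pattern with toy shapes (bfloat16 is used here so the sketch also runs on CPU; the shapes and dtypes are illustrative assumptions, not the repository's configuration):

import torch
from torch import einsum

# toy attention tensors in half precision
b, h, i, j, d = 1, 2, 4, 4, 8
q = torch.randn(b, h, i, d, dtype = torch.bfloat16)
k = torch.randn(b, h, j, d, dtype = torch.bfloat16)
v = torch.randn(b, h, j, d, dtype = torch.bfloat16)

sim = einsum('b h i d, b h j d -> b h i j', q, k)     # half-precision attention logits

attn = sim.softmax(dim = -1, dtype = torch.float32)   # softmax in float32 for stability
attn = attn.type(sim.dtype)                           # cast back, as in the diff above

out = einsum('b h i j, b h j d -> b h i d', attn, v)
assert out.dtype == torch.bfloat16                    # aggregation stays in one dtype

Without the cast back, attn would remain float32 while v stays in half precision, so the value aggregation would no longer run entirely in the reduced precision used by the rest of the model.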

@@ -1 +1 @@
- __version__ = '1.8.3'
+ __version__ = '1.8.4'