Compare commits


1 commit
1.8.4 ... 1.8.3

Author     SHA1        Message                           Date
Phil Wang  3df86acc8b  make it work for @ethancohen123   2022-08-19 11:25:34 -07:00
2 changed files with 2 additions and 7 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -251,9 +251,7 @@ class XClipAdapter(BaseClipAdapter):
         text_mask = text != 0
         encoder_output = self.clip.text_transformer(text)
 
-        encoder_output_is_cls = encoder_output.ndim == 3
-
-        text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
+        text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output.ndim == 3 else (encoder_output, None)
         text_embed = self.clip.to_text_latent(text_cls)
 
         if exists(text_encodings):
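
Both sides of this hunk compute the same thing; 1.8.4 merely names the condition before using it. What the conditional handles: the wrapped text transformer may return per-token encodings with a leading CLS token, shape (batch, seq, dim), or an already pooled embedding, shape (batch, dim), and the ndim check picks the right unpacking. A minimal runnable sketch of that dispatch (split_cls is a hypothetical helper, not from the repo; shapes are illustrative):

import torch

def split_cls(encoder_output: torch.Tensor):
    # hypothetical helper mirroring the conditional in the hunk above
    if encoder_output.ndim == 3:
        # (batch, seq, dim): first token is the CLS embedding, the rest are per-token encodings
        return encoder_output[:, 0], encoder_output[:, 1:]
    # (batch, dim): already pooled, no per-token encodings to return
    return encoder_output, None

text_cls, text_encodings = split_cls(torch.randn(2, 77, 512))
assert text_cls.shape == (2, 512) and text_encodings.shape == (2, 76, 512)

text_cls, text_encodings = split_cls(torch.randn(2, 512))
assert text_cls.shape == (2, 512) and text_encodings is None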
@@ -879,8 +877,6 @@ class Attention(nn.Module):
         # attention
 
         attn = sim.softmax(dim = -1, dtype = torch.float32)
-        attn = attn.type(sim.dtype)
-
         attn = self.dropout(attn)
 
         # aggregate values
@@ -1639,7 +1635,6 @@ class CrossAttention(nn.Module):
         sim = sim.masked_fill(~mask, max_neg_value)
 
         attn = sim.softmax(dim = -1, dtype = torch.float32)
-        attn = attn.type(sim.dtype)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
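
The two hunks above drop the same one-line cast from both attention classes when stepping from 1.8.4 back to 1.8.3; in other words, 1.8.4 introduced it. The cast matters under half precision: softmax(dim = -1, dtype = torch.float32) upcasts internally for numerical stability and returns a float32 tensor, so the attention weights would no longer match the dtype of the fp16 values in the einsum that follows unless they are cast back. A small sketch of the dtype behavior (shapes are illustrative, not from the repo):

import torch

sim = torch.randn(1, 8, 4, 4, dtype = torch.float16)  # attention logits in fp16

attn = sim.softmax(dim = -1, dtype = torch.float32)   # softmax computed in fp32...
assert attn.dtype == torch.float32                    # ...and returned in fp32

attn = attn.type(sim.dtype)                           # the cast these hunks touch
assert attn.dtype == torch.float16                    # weights match v's dtype again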

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '1.8.4'
+__version__ = '1.8.3'