From 7762edd0ff45788ed45c54c96fb080b28522a082 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 19 Aug 2022 11:28:58 -0700 Subject: [PATCH] make it work for @ethancohen123 --- dalle2_pytorch/dalle2_pytorch.py | 10 ++++++++-- dalle2_pytorch/version.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index ce7339e..7920e69 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -250,9 +250,15 @@ class XClipAdapter(BaseClipAdapter): text = text[..., :self.max_text_len] text_mask = text != 0 encoder_output = self.clip.text_transformer(text) - text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:] + + encoder_output_is_cls = encoder_output.ndim == 3 + + text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None) text_embed = self.clip.to_text_latent(text_cls) - text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) + + if exists(text_encodings): + text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) + return EmbeddedText(l2norm(text_embed), text_encodings) @torch.no_grad() diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index aa1a8c4..cfe6447 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.8.2' +__version__ = '1.8.3'