From f9882077188e9846ebf67bd78b2f8621754232ab Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 13 Jul 2022 12:56:02 -0700
Subject: [PATCH] hack around some inplace error; also make sure that for
 openai clip text encoding, only tokens after eos_id are masked out

---
 dalle2_pytorch/dalle2_pytorch.py | 8 ++++++--
 dalle2_pytorch/version.py        | 2 +-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 003b9d3..3e111a4 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -278,6 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         import clip
         openai_clip, preprocess = clip.load(name)
         super().__init__(openai_clip)
+        self.eos_id = 49407 # for handling 0 being also '!'
 
         text_attention_final = self.find_layer('ln_final')
         self.handle = text_attention_final.register_forward_hook(self._hook)
@@ -316,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_text(self, text):
         text = text[..., :self.max_text_len]
-        text_mask = text != 0
+
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
 
         assert not self.cleared
         text_embed = self.clip.encode_text(text)
@@ -900,7 +904,7 @@ class DiffusionPriorNetwork(nn.Module):
         null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
 
         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1').clone(),
+            rearrange(mask, 'b n -> b n 1').clone(),
             text_encodings,
             null_text_embeds
        )
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index caf9513..3a1985b 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.7'
+__version__ = '0.23.8'
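
Note on the embed_text change: the old mask, text != 0, treated token id 0 as padding, but in the openai CLIP vocabulary id 0 is also the token for '!', so legitimate tokens could be masked out. The new logic keys off eos_id instead, keeping everything up to and including the first end-of-text token. A minimal standalone sketch of that logic follows; the toy token sequence is an illustrative assumption (49406 and 49407 are CLIP's start-of-text and end-of-text ids):

import torch
import torch.nn.functional as F

eos_id = 49407

# toy batch: <sos>, two word tokens, <eos>, then 0-padding
text = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])

is_eos_id = (text == eos_id)

# cumsum is 0 strictly before the first <eos>, so this marks every token before it
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0

# shift right by one (pad a True on the left, drop the last column)
# so the <eos> token itself is also kept in the mask
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)

print(text_mask) # tensor([[ True,  True,  True,  True, False, False]])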
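
Note on the DiffusionPriorNetwork change: rearrange returns a view of mask, and .clone() materializes it into a fresh tensor before it is handed to torch.where; this is the "hack around some inplace error" from the subject line. A sketch of the pattern, with illustrative shapes and a zero null embedding assumed for the example:

import torch
from einops import rearrange

b, n, d = 2, 4, 8
mask = torch.rand(b, n) > 0.5
text_encodings = torch.randn(b, n, d)
null_text_embeds = torch.zeros(d)

# cloning the rearranged view gives torch.where an independent, contiguous
# condition tensor, sidestepping the inplace error seen without it
text_encodings = torch.where(
    rearrange(mask, 'b n -> b n 1').clone(),
    text_encodings,
    null_text_embeds
)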