From 683dd98b96a8ca9ac9b43b40aedd5306587b8289 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 15 Dec 2022 10:54:21 -0800 Subject: [PATCH] extra insurance in case eos id is not there --- dalle2_pytorch/dalle2_pytorch.py | 2 ++ dalle2_pytorch/version.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c3b353d..6c8984f 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -360,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter): is_eos_id = (text == self.eos_id) text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0 text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True) + text_mask = text_mask & (text != 0) assert not self.cleared text_embed = self.clip.encode_text(text) @@ -434,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter): is_eos_id = (text == self.eos_id) text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0 text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True) + text_mask = text_mask & (text != 0) assert not self.cleared text_embed = self.clip.encode_text(text) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 6c371de..2279cc1 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.11.2' +__version__ = '1.11.4'