mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
extra insurance in case eos id is not there
This commit is contained in:
@@ -360,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
is_eos_id = (text == self.eos_id)
|
is_eos_id = (text == self.eos_id)
|
||||||
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
|
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
|
||||||
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
|
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
|
||||||
|
text_mask = text_mask & (text != 0)
|
||||||
assert not self.cleared
|
assert not self.cleared
|
||||||
|
|
||||||
text_embed = self.clip.encode_text(text)
|
text_embed = self.clip.encode_text(text)
|
||||||
@@ -434,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
|
|||||||
is_eos_id = (text == self.eos_id)
|
is_eos_id = (text == self.eos_id)
|
||||||
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
|
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
|
||||||
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
|
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
|
||||||
|
text_mask = text_mask & (text != 0)
|
||||||
assert not self.cleared
|
assert not self.cleared
|
||||||
|
|
||||||
text_embed = self.clip.encode_text(text)
|
text_embed = self.clip.encode_text(text)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.11.2'
|
__version__ = '1.11.4'
|
||||||
|
|||||||
Reference in New Issue
Block a user