mirror of https://github.com/lucidrains/DALLE2-pytorch.git
hack around an inplace error; also make sure that for openai clip text encoding, only tokens after eos_id are masked out
@@ -278,6 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         import clip
         openai_clip, preprocess = clip.load(name)
         super().__init__(openai_clip)
+        self.eos_id = 49407 # for handling 0 being also '!'
 
         text_attention_final = self.find_layer('ln_final')
         self.handle = text_attention_final.register_forward_hook(self._hook)
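For context on the hardcoded id: in OpenAI CLIP's BPE vocabulary, token id 0 decodes to '!' rather than a dedicated padding symbol, so a zero token is ambiguous between padding and a real character, while <|endoftext|> sits at id 49407. A quick sketch verifying both, assuming the tokenizer shipped with the openai/CLIP package:

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
print(tokenizer.encoder['<|endoftext|>'])  # 49407, the eos_id hardcoded above
print(tokenizer.decoder[0])                # '!' -- id 0 is a real token, not reserved padding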
@@ -316,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_text(self, text):
         text = text[..., :self.max_text_len]
-        text_mask = text != 0
+
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
         assert not self.cleared
 
         text_embed = self.clip.encode_text(text)
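How the new mask works: is_eos_id.cumsum(dim = -1) stays 0 strictly before the first <|endoftext|> token, so comparing against 0 masks in everything preceding it; F.pad with (1, -1) prepends one True and drops the last element, shifting the mask right by one so the eos token itself is kept while everything after it is masked out. A standalone sketch with hypothetical token ids:

import torch
import torch.nn.functional as F

eos_id = 49407
# hypothetical batch: <|startoftext|>, two word tokens, <|endoftext|>, zero padding
text = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])

is_eos_id = (text == eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0          # True strictly before eos
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)  # shift right: include eos

print(text_mask)  # tensor([[True, True, True, True, False, False]])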
@@ -900,7 +904,7 @@ class DiffusionPriorNetwork(nn.Module):
         null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
 
         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1'),
+            rearrange(mask, 'b n -> b n 1').clone(),
             text_encodings,
             null_text_embeds
         )
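A plausible reading of the "hack" (the commit doesn't pin down the root cause): rearrange(mask, 'b n -> b n 1') returns a view sharing storage with mask, and torch.where saves its condition tensor for the backward pass; if mask is later mutated in place, the shared version counter is bumped and backward raises. .clone() gives the condition its own storage. A generic reproduction of that error class, not this repo's code:

import torch

x = torch.randn(3, requires_grad = True)
y = x.sigmoid()   # sigmoid saves its output for the backward pass
y.mul_(2)         # in-place edit bumps the saved tensor's version counter

try:
    y.sum().backward()
except RuntimeError as e:
    print(e)      # "... has been modified by an inplace operation"

# cloning first gives the mutation its own storage, so backward succeeds
y = x.sigmoid()
y.clone().mul_(2)
y.sum().backward()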
@@ -1 +1 @@
-__version__ = '0.23.7'
+__version__ = '0.23.8'