mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
hack around some inplace error, also make sure for openai clip text encoding, only tokens after eos_id is masked out
This commit is contained in:
@@ -278,6 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
import clip
|
import clip
|
||||||
openai_clip, preprocess = clip.load(name)
|
openai_clip, preprocess = clip.load(name)
|
||||||
super().__init__(openai_clip)
|
super().__init__(openai_clip)
|
||||||
|
self.eos_id = 49407 # for handling 0 being also '!'
|
||||||
|
|
||||||
text_attention_final = self.find_layer('ln_final')
|
text_attention_final = self.find_layer('ln_final')
|
||||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||||
@@ -316,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def embed_text(self, text):
|
def embed_text(self, text):
|
||||||
text = text[..., :self.max_text_len]
|
text = text[..., :self.max_text_len]
|
||||||
text_mask = text != 0
|
|
||||||
|
is_eos_id = (text == self.eos_id)
|
||||||
|
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
|
||||||
|
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
|
||||||
assert not self.cleared
|
assert not self.cleared
|
||||||
|
|
||||||
text_embed = self.clip.encode_text(text)
|
text_embed = self.clip.encode_text(text)
|
||||||
@@ -900,7 +904,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
||||||
|
|
||||||
text_encodings = torch.where(
|
text_encodings = torch.where(
|
||||||
rearrange(mask, 'b n -> b n 1'),
|
rearrange(mask, 'b n -> b n 1').clone(),
|
||||||
text_encodings,
|
text_encodings,
|
||||||
null_text_embeds
|
null_text_embeds
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.23.7'
|
__version__ = '0.23.8'
|
||||||
|
|||||||
Reference in New Issue
Block a user