hack around an inplace error; also make sure that for openai clip text encoding, only tokens after eos_id are masked out

Phil Wang
2022-07-13 12:56:02 -07:00
parent b2073219f0
commit f988207718
2 changed files with 7 additions and 3 deletions

@@ -278,6 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         import clip
         openai_clip, preprocess = clip.load(name)
         super().__init__(openai_clip)
+        self.eos_id = 49407 # for handling 0 being also '!'
         text_attention_final = self.find_layer('ln_final')
         self.handle = text_attention_final.register_forward_hook(self._hook)
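
Note on the hard-coded id: in OpenAI CLIP's BPE vocabulary, 49407 is '<|endoftext|>', while id 0 is not a reserved padding token but the character '!' in non-word-final position. CLIP still pads token sequences with 0, so a mask built as `text != 0` can drop a genuine '!' token. A minimal check, assuming the `clip` package from openai/CLIP is installed:

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
print(tokenizer.decoder[0])                # '!'   -- id 0 is a real vocab entry, not a reserved pad
print(tokenizer.encoder['<|endoftext|>'])  # 49407 -- matches the eos_id hard-coded above
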
@@ -316,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_text(self, text):
         text = text[..., :self.max_text_len]
-        text_mask = text != 0
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
         assert not self.cleared
         text_embed = self.clip.encode_text(text)
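
Why this works: the cumulative sum over the EOS indicator is zero strictly before the first `eos_id`, and `F.pad(..., (1, -1), value = True)` shifts that mask one position to the right, so the final `text_mask` is True up to and including the EOS token and False for everything after it, regardless of any padding-id collisions. A toy check in plain PyTorch (every id below other than 49406/49407 is illustrative):

import torch
import torch.nn.functional as F

eos_id = 49407
#                     <sot>  tok  '!'  <eos>  pad  pad
text = torch.tensor([[49406, 320,  0,  49407,  0,   0]])

is_eos_id = (text == eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0          # True strictly before the first eos
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)  # shift right so the eos itself is kept

print(text_mask)  # tensor([[ True,  True,  True,  True, False, False]])
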
@@ -900,7 +904,7 @@ class DiffusionPriorNetwork(nn.Module):
         null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1'),
+            rearrange(mask, 'b n -> b n 1').clone(),
             text_encodings,
             null_text_embeds
         )
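
The `.clone()` is the "hack around an inplace error" from the commit message. `rearrange(mask, 'b n -> b n 1')` is a pure reshape, so einops hands back a view that shares storage with `mask`; the commit does not record the exact traceback, but cloning the condition gives `torch.where` memory of its own and removes the aliasing. A quick illustration (names here are illustrative, not from the repo):

import torch
from einops import rearrange

mask = torch.ones(2, 3, dtype = torch.bool)

view = rearrange(mask, 'b n -> b n 1')
print(view.data_ptr() == mask.data_ptr())    # True  -- same storage, just a view of `mask`

copied = rearrange(mask, 'b n -> b n 1').clone()
print(copied.data_ptr() == mask.data_ptr())  # False -- the clone owns its memory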

@@ -1 +1 @@
-__version__ = '0.23.7'
+__version__ = '0.23.8'