This commit is contained in:
Phil Wang
2022-04-28 07:21:18 -07:00
parent dbf4a281f1
commit 625ce23f6b
2 changed files with 2 additions and 2 deletions

View File

@@ -743,7 +743,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']