mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
🐛
This commit is contained in:
@@ -743,7 +743,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
|
||||
Reference in New Issue
Block a user