only pass text encodings conditioning in diffusion prior if specified on initialization

This commit is contained in:
Phil Wang
2022-04-27 19:48:16 -07:00
parent 6700381a37
commit 8c610aad9a
2 changed files with 9 additions and 11 deletions

View File

@@ -736,11 +736,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
text_embed, text_encodings = self.clip.embed_text(text) text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict( text_cond = dict(text_embed = text_embed)
text_embed = text_embed,
text_encodings = text_encodings, if self.condition_on_text_encodings:
mask = text != 0 text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed'] text_embeds = text_cond['text_embed']
@@ -780,11 +779,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
text_embed, text_encodings = self.clip.embed_text(text) text_embed, text_encodings = self.clip.embed_text(text)
text_mask = text != 0 text_mask = text != 0
text_cond = dict( text_cond = dict(text_embed = text_embed)
text_embed = text_embed,
text_encodings = text_encodings, if self.condition_on_text_encodings:
mask = text_mask text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
)
# timestep conditioning from ddpm # timestep conditioning from ddpm

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.57', version = '0.0.58',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',