mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
only pass text encodings conditioning in diffusion prior if specified on initialization
This commit is contained in:
@@ -736,11 +736,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(
|
||||
text_embed = text_embed,
|
||||
text_encodings = text_encodings,
|
||||
mask = text != 0
|
||||
)
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
@@ -780,11 +779,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
text_mask = text != 0
|
||||
|
||||
text_cond = dict(
|
||||
text_embed = text_embed,
|
||||
text_encodings = text_encodings,
|
||||
mask = text_mask
|
||||
)
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
|
||||
Reference in New Issue
Block a user