mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
only pass text encodings conditioning in diffusion prior if specified on initialization
This commit is contained in:
@@ -736,11 +736,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
text_embed, text_encodings = self.clip.embed_text(text)
|
text_embed, text_encodings = self.clip.embed_text(text)
|
||||||
|
|
||||||
text_cond = dict(
|
text_cond = dict(text_embed = text_embed)
|
||||||
text_embed = text_embed,
|
|
||||||
text_encodings = text_encodings,
|
if self.condition_on_text_encodings:
|
||||||
mask = text != 0
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||||
)
|
|
||||||
|
|
||||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||||
text_embeds = text_cond['text_embed']
|
text_embeds = text_cond['text_embed']
|
||||||
@@ -780,11 +779,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
text_embed, text_encodings = self.clip.embed_text(text)
|
text_embed, text_encodings = self.clip.embed_text(text)
|
||||||
text_mask = text != 0
|
text_mask = text != 0
|
||||||
|
|
||||||
text_cond = dict(
|
text_cond = dict(text_embed = text_embed)
|
||||||
text_embed = text_embed,
|
|
||||||
text_encodings = text_encodings,
|
if self.condition_on_text_encodings:
|
||||||
mask = text_mask
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||||
)
|
|
||||||
|
|
||||||
# timestep conditioning from ddpm
|
# timestep conditioning from ddpm
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user