diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 59fc151..00d5a6d 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -736,11 +736,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         text_embed, text_encodings = self.clip.embed_text(text)
 
-        text_cond = dict(
-            text_embed = text_embed,
-            text_encodings = text_encodings,
-            mask = text != 0
-        )
+        text_cond = dict(text_embed = text_embed)
+
+        if self.condition_on_text_encodings:
+            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
 
         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
         text_embeds = text_cond['text_embed']
@@ -780,11 +779,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         text_embed, text_encodings = self.clip.embed_text(text)
         text_mask = text != 0
 
-        text_cond = dict(
-            text_embed = text_embed,
-            text_encodings = text_encodings,
-            mask = text_mask
-        )
+        text_cond = dict(text_embed = text_embed)
+
+        if self.condition_on_text_encodings:
+            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
 
         # timestep conditioning from ddpm
 
diff --git a/setup.py b/setup.py
index b479399..54ba910 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.57',
+  version = '0.0.58',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',