From 625ce23f6b5e91b4fc75464d63ab5ee6a5b7c011 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 28 Apr 2022 07:21:18 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dalle2_pytorch/dalle2_pytorch.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index be861c9..a62171c 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -743,7 +743,7 @@ class DiffusionPrior(BaseGaussianDiffusion): text_cond = dict(text_embed = text_embed) if self.condition_on_text_encodings: - text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} + text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0} image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) text_embeds = text_cond['text_embed'] diff --git a/setup.py b/setup.py index 894f64d..8a3a1e7 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.61', + version = '0.0.62', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',