From 36fb46a95e59a415cf0826fa569ed50c1c89f78c Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 28 Jul 2022 08:33:51 -0700 Subject: [PATCH] fix readme and a small bug in DALLE2 class --- README.md | 1 + dalle2_pytorch/dalle2_pytorch.py | 2 +- dalle2_pytorch/version.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 91e1088..a27de1c 100644 --- a/README.md +++ b/README.md @@ -371,6 +371,7 @@ loss.backward() unet1 = Unet( dim = 128, image_embed_dim = 512, + text_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults=(1, 2, 4, 8), diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9c9827b..c912cfb 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2938,7 +2938,7 @@ class DALLE2(nn.Module): image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale) text_cond = text if self.decoder_need_text_cond else None - images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) + images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale) if return_pil_images: images = list(map(self.to_pil, images.unbind(dim = 0))) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 58d478a..3f262a6 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.2.0' +__version__ = '1.2.1'