diff --git a/README.md b/README.md index 91e1088..a27de1c 100644 --- a/README.md +++ b/README.md @@ -371,6 +371,7 @@ loss.backward() unet1 = Unet( dim = 128, image_embed_dim = 512, + text_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults=(1, 2, 4, 8), diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9c9827b..c912cfb 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2938,7 +2938,7 @@ class DALLE2(nn.Module): image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale) text_cond = text if self.decoder_need_text_cond else None - images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) + images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale) if return_pil_images: images = list(map(self.to_pil, images.unbind(dim = 0))) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 58d478a..3f262a6 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.2.0' +__version__ = '1.2.1'