From 301a97197f81e3985abfe7256501de89dc4b38a6 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 12 Aug 2022 12:29:25 -0700 Subject: [PATCH] fix self conditioning shape in diffusion prior --- dalle2_pytorch/dalle2_pytorch.py | 6 +++--- dalle2_pytorch/version.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 5de4a90..11622a2 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1004,9 +1004,9 @@ class DiffusionPriorNetwork(nn.Module): # setup self conditioning - self_cond = None if self.self_cond: - self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype)) + self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype)) + self_cond = rearrange(self_cond, 'b d -> b 1 d') # in section 2.2, last paragraph # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction" @@ -1287,7 +1287,7 @@ class DiffusionPrior(nn.Module): image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise) self_cond = None - if self.net.self_cond and random.random() < 0.5: + if self.net.self_cond and random.random() < 1.5: with torch.no_grad(): self_cond = self.net(image_embed_noisy, times, **text_cond).detach() diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index bb64aa4..4a9b978 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.6.1' +__version__ = '1.6.2'