From d3cded3c6cd76eca7135f669a96658a244d0f49f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 13 Apr 2022 10:52:31 -0700 Subject: [PATCH] complete logic in diffusion prior for sampling more than 1 image embeds, taking top similarity --- dalle2_pytorch/dalle2_pytorch.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 928cf10..52247d9 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -405,12 +405,29 @@ class DiffusionPrior(nn.Module): return img @torch.no_grad() - def sample(self, text): + def sample(self, text, num_samples_per_batch = 2): + # in the paper, what they did was + # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP + text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch) + batch_size = text.shape[0] image_embed_dim = self.image_embed_dim text_cond = self.get_text_cond(text) - return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) + + image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) + text_embeds = text_cond['text_embeds'] + + text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch) + image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch) + + text_image_sims = einsum('b r d, b r d -> b r') + top_sim_indices = text_image_sims.topk(k = 1).indices + + top_sim_indices = repeat(top_sim_indices, 'b 1 -> b d', d = image_embed_dim) + + top_image_embeds = image_embeds.gather(1, top_sim_indices) + return top_image_embeds def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start))