diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 928cf10..52247d9 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -405,12 +405,29 @@ class DiffusionPrior(nn.Module):
         return img
 
     @torch.no_grad()
-    def sample(self, text):
+    def sample(self, text, num_samples_per_batch = 2):
+        # in the paper, what they did was
+        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
+        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
+
         batch_size = text.shape[0]
         image_embed_dim = self.image_embed_dim
 
         text_cond = self.get_text_cond(text)
-        return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+        text_embeds = text_cond['text_embeds']
+
+        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+
+        text_image_sims = einsum('b r d, b r d -> b r', text_embeds, image_embeds)
+        top_sim_indices = text_image_sims.topk(k = 1).indices
+
+        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
+
+        top_image_embeds = image_embeds.gather(1, top_sim_indices)
+        return rearrange(top_image_embeds, 'b 1 d -> b d')
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
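
For reference, below is a minimal, self-contained sketch (not part of the commit) of the re-ranking step the new sample performs: r candidate image embeddings are produced per prompt, scored against the prompt's text embedding, and only the best candidate is kept. All tensors and shapes here are dummy placeholders standing in for the p_sample_loop output and text_cond['text_embeds'].

import torch
from torch import einsum
from einops import rearrange, repeat

b, r, d = 4, 2, 512                          # batch, samples per prompt, embedding dim
text_embeds = torch.randn(b * r, d)          # placeholder for text_cond['text_embeds']
image_embeds = torch.randn(b * r, d)         # placeholder for the p_sample_loop output

# group the r candidates belonging to each prompt
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = r)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = r)

# dot-product similarity between each prompt and its r candidate image embeddings
sims = einsum('b r d, b r d -> b r', text_embeds, image_embeds)
top = sims.topk(k = 1).indices               # (b, 1) index of the best candidate

top = repeat(top, 'b 1 -> b 1 d', d = d)     # expand so gather can index along the r axis
best = image_embeds.gather(1, top)           # (b, 1, d)
best = rearrange(best, 'b 1 d -> b d')       # one image embedding per prompt
print(best.shape)                            # torch.Size([4, 512])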