complete logic in diffusion prior for sampling more than 1 image embeds, taking top similarity

This commit is contained in:
Phil Wang
2022-04-13 10:52:31 -07:00
parent d573c82f8c
commit d3cded3c6c

View File

@@ -405,12 +405,29 @@ class DiffusionPrior(nn.Module):
return img return img
@torch.no_grad() @torch.no_grad()
def sample(self, text): def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0] batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text) text_cond = self.get_text_cond(text)
return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embeds']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r')
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return top_image_embeds
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))