mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
complete logic in diffusion prior for sampling more than 1 image embed, taking top similarity
@@ -405,12 +405,29 @@ class DiffusionPrior(nn.Module):
         return img

     @torch.no_grad()
-    def sample(self, text):
+    def sample(self, text, num_samples_per_batch = 2):
+        # in the paper, what they did was
+        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
+        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
+
         batch_size = text.shape[0]
         image_embed_dim = self.image_embed_dim

         text_cond = self.get_text_cond(text)
-        return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+        text_embeds = text_cond['text_embeds']
+
+        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+
+        text_image_sims = einsum('b r d, b r d -> b r', text_embeds, image_embeds)
+        top_sim_indices = text_image_sims.topk(k = 1).indices
+
+        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
+
+        top_image_embeds = image_embeds.gather(1, top_sim_indices)
+        return rearrange(top_image_embeds, 'b 1 d -> b d')

     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
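For reference, the rerank step in this hunk is easy to exercise in isolation. The sketch below replays the same pattern — draw r candidate image embeddings per prompt, score each against its text embedding, gather the best — with random tensors standing in for the prior's p_sample_loop output; rerank_by_similarity is a hypothetical helper written for this illustration, not a function in the repository.

import torch
from torch import einsum
from einops import rearrange, repeat

def rerank_by_similarity(text_embeds, image_embeds, num_samples_per_batch):
    # both inputs: ((b * r), d), where r = num_samples_per_batch candidates per prompt
    text_embeds  = rearrange(text_embeds,  '(b r) d -> b r d', r = num_samples_per_batch)
    image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

    # dot-product similarity per candidate; l2-normalize both sides first
    # if a true cosine similarity is wanted
    sims = einsum('b r d, b r d -> b r', text_embeds, image_embeds)

    # index of the best candidate per batch element, shape (b, 1)
    top = sims.topk(k = 1).indices

    # gather along the candidate axis needs indices of shape (b, 1, d)
    top = repeat(top, 'b 1 -> b 1 d', d = image_embeds.shape[-1])
    return image_embeds.gather(1, top).squeeze(1)  # (b, d)

# dummy usage: 4 prompts, 2 candidates each, 512-dim embeddings
b, r, d = 4, 2, 512
best = rerank_by_similarity(torch.randn(b * r, d), torch.randn(b * r, d), r)
assert best.shape == (b, d)

The gather-with-expanded-indices trick keeps the selection fully batched; indexing with torch.arange(b) and the squeezed top-1 indices would work just as well.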