Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-20 02:04:19 +01:00

Commit: complete logic in diffusion prior for sampling more than one image embed, taking top similarity
@@ -405,12 +405,29 @@ class DiffusionPrior(nn.Module):
         return img
 
     @torch.no_grad()
-    def sample(self, text):
+    def sample(self, text, num_samples_per_batch = 2):
+        # in the paper, what they did was
+        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
+        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
+
         batch_size = text.shape[0]
         image_embed_dim = self.image_embed_dim
 
         text_cond = self.get_text_cond(text)
-        return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+        text_embeds = text_cond['text_embeds']
+
+        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+
+        text_image_sims = einsum('b r d, b r d -> b r', text_embeds, image_embeds)
+        top_sim_indices = text_image_sims.topk(k = 1).indices
+
+        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
+
+        top_image_embeds = image_embeds.gather(1, top_sim_indices)
+        return rearrange(top_image_embeds, 'b 1 d -> b d')
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
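
For context, here is a minimal standalone sketch of the rerank-and-select step the commit adds, run on dummy tensors rather than the real model. The shapes and einops calls mirror the diff; the prompt count, sample count, and embedding dimension are made-up values, and the similarity here is computed on L2-normalized embeddings so it is a true cosine similarity (the diff's einsum is the raw dot product, which agrees only when the embeddings are already unit-norm).

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

b, r, d = 4, 2, 512   # prompts, samples per prompt, embed dim (dummy values)

text_embeds  = torch.randn(b * r, d)   # stands in for text_cond['text_embeds']
image_embeds = torch.randn(b * r, d)   # stands in for the p_sample_loop output

# group the r samples belonging to each prompt
text_embeds  = rearrange(text_embeds,  '(b r) d -> b r d', r = r)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = r)

# cosine similarity per (prompt, sample) pair
text_image_sims = torch.einsum('b r d, b r d -> b r',
    F.normalize(text_embeds, dim = -1),
    F.normalize(image_embeds, dim = -1))

# pick the best of the r samples for each prompt
top_sim_indices = text_image_sims.topk(k = 1).indices            # (b, 1)
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = d)

top_image_embeds = image_embeds.gather(1, top_sim_indices)       # (b, 1, d)
top_image_embeds = rearrange(top_image_embeds, 'b 1 d -> b d')   # (b, d)

assert top_image_embeds.shape == (b, d)

The gather trick is the key detail: topk returns one index per prompt along the sample axis, and repeating it across the embedding dimension lets gather pull out the full winning embedding in a single call.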