mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix bug with classifier free guidance for prior network, even though it seems it may not be used
This commit is contained in:
@@ -316,8 +316,15 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||||
|
|
||||||
|
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||||
|
# but let's just do it right
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
all_masked_out = mask.any(dim = -1)
|
||||||
|
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||||
|
|
||||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||||
|
|||||||
Reference in New Issue
Block a user