Compare commits

...

1 Commits
1.6.4 ... 1.6.5

Author SHA1 Message Date
Phil Wang
34806663e3 make it so diffusion prior p_sample_loop returns unnormalized image embeddings 2022-08-13 10:03:40 -07:00
2 changed files with 6 additions and 5 deletions

View File

@@ -1279,9 +1279,12 @@ class DiffusionPrior(nn.Module):
is_ddim = timesteps < self.noise_scheduler.num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, **kwargs)
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
else:
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
image_embed = normalized_image_embed / self.image_embed_scale
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1350,8 +1353,6 @@ class DiffusionPrior(nn.Module):
# retrieve original unscaled image embed
image_embeds /= self.image_embed_scale
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

View File

@@ -1 +1 @@
__version__ = '1.6.4'
__version__ = '1.6.5'