Always re-derive the predicted noise from the clipped x0 when sampling with DDIM under the predict-noise objective

This commit is contained in:
Phil Wang
2023-03-05 10:45:44 -08:00
parent cc58f75474
commit 848e8a480a
2 changed files with 3 additions and 9 deletions

View File

@@ -1334,10 +1334,7 @@ class DiffusionPrior(nn.Module):
# predict noise # predict noise
if self.predict_x_start or self.predict_v:
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start) pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
else:
pred_noise = pred
if time_next < 0: if time_next < 0:
image_embed = x_start image_embed = x_start
@@ -2975,10 +2972,7 @@ class Decoder(nn.Module):
# predict noise # predict noise
if predict_x_start or predict_v:
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start) pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
else:
pred_noise = pred
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()

View File

@@ -1 +1 @@
__version__ = '1.12.2' __version__ = '1.12.3'