diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 021d5a1..1f94304 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1334,10 +1334,7 @@ class DiffusionPrior(nn.Module): # predict noise - if self.predict_x_start or self.predict_v: - pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start) - else: - pred_noise = pred + pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start) if time_next < 0: image_embed = x_start @@ -2975,10 +2972,7 @@ class Decoder(nn.Module): # predict noise - if predict_x_start or predict_v: - pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start) - else: - pred_noise = pred + pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start) c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 96ddfeb..3081afb 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.12.2' +__version__ = '1.12.3'