From a59882001267caef4ccf2879dc9b7efcccb6d93a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 9 Jul 2022 18:38:40 -0700 Subject: [PATCH] do not noise for the last step in ddim --- dalle2_pytorch/dalle2_pytorch.py | 7 ++++--- dalle2_pytorch/version.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b5fa986..ed4ebe4 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1059,10 +1059,10 @@ class DiffusionPrior(nn.Module): c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() - new_noise = torch.randn_like(image_embed) + noise = torch.randn_like(image_embed) if time_next > 0 else 0. img = x_start * alpha_next.sqrt() + \ - c1 * new_noise + \ + c1 * noise + \ c2 * pred_noise return image_embed @@ -2275,9 +2275,10 @@ class Decoder(nn.Module): c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() + noise = torch.randn_like(img) if time_next > 0 else 0. img = x_start * alpha_next.sqrt() + \ - c1 * torch.randn_like(img) + \ + c1 * noise + \ c2 * pred_noise img = self.unnormalize_img(img) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index b699138..2c49b72 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.19.3' +__version__ = '0.19.4'