mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
do not noise for the last step in ddim
This commit is contained in:
@@ -1059,10 +1059,10 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||||
new_noise = torch.randn_like(image_embed)
|
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
|
||||||
|
|
||||||
img = x_start * alpha_next.sqrt() + \
|
img = x_start * alpha_next.sqrt() + \
|
||||||
c1 * new_noise + \
|
c1 * noise + \
|
||||||
c2 * pred_noise
|
c2 * pred_noise
|
||||||
|
|
||||||
return image_embed
|
return image_embed
|
||||||
@@ -2275,9 +2275,10 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||||
|
noise = torch.randn_like(img) if time_next > 0 else 0.
|
||||||
|
|
||||||
img = x_start * alpha_next.sqrt() + \
|
img = x_start * alpha_next.sqrt() + \
|
||||||
c1 * torch.randn_like(img) + \
|
c1 * noise + \
|
||||||
c2 * pred_noise
|
c2 * pred_noise
|
||||||
|
|
||||||
img = self.unnormalize_img(img)
|
img = self.unnormalize_img(img)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.19.3'
|
__version__ = '0.19.4'
|
||||||
|
|||||||
Reference in New Issue
Block a user