From ddde8ca1bf3d26e3274b5b452a480271a1446e34 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 19 Apr 2022 20:54:28 -0700 Subject: [PATCH] fix cosine bbeta schedule, thanks to @Zhengxinyang --- dalle2_pytorch/dalle2_pytorch.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index d1a4538..8326f9a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -105,8 +105,8 @@ def cosine_beta_schedule(timesteps, s = 0.008): as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 - x = torch.linspace(0, steps, steps) - alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + x = torch.linspace(0, timesteps, steps) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) diff --git a/setup.py b/setup.py index 28513c8..c1e4f9a 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.25', + version = '0.0.26', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',