fix cosine bbeta schedule, thanks to @Zhengxinyang

This commit is contained in:
Phil Wang
2022-04-19 20:54:28 -07:00
parent c26b77ad20
commit ddde8ca1bf
2 changed files with 3 additions and 3 deletions

View File

@@ -105,8 +105,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.25',
version = '0.0.26',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',