mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix cosine bbeta schedule, thanks to @Zhengxinyang
This commit is contained in:
@@ -105,8 +105,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
|||||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||||
"""
|
"""
|
||||||
steps = timesteps + 1
|
steps = timesteps + 1
|
||||||
x = torch.linspace(0, steps, steps)
|
x = torch.linspace(0, timesteps, steps)
|
||||||
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||||
return torch.clip(betas, 0, 0.999)
|
return torch.clip(betas, 0, 0.999)
|
||||||
|
|||||||
Reference in New Issue
Block a user