diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index ca3eba7..faa7bc1 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -373,7 +373,7 @@ def quadratic_beta_schedule(timesteps): scale = 1000 / timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 - return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2 + return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2 def sigmoid_beta_schedule(timesteps):