diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c2fca66..f85a61f 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -450,7 +450,7 @@ class DiffusionPrior(nn.Module): alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) @@ -941,7 +941,7 @@ class Decoder(nn.Module): alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) diff --git a/setup.py b/setup.py index 7373c19..86c0053 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.15', + version = '0.0.16', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',