From b0f2fbaa9521f0e2f35bfdf20df6092006ebba1b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 17 Apr 2022 15:21:47 +0200 Subject: [PATCH] schedule to Prior --- dalle2_pytorch/dalle2_pytorch.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 686db85..84a7528 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -450,10 +450,11 @@ class DiffusionPrior(nn.Module): net, *, clip, - timesteps = 1000, - cond_drop_prob = 0.2, - loss_type = 'l1', - predict_x0 = True + timesteps=1000, + cond_drop_prob=0.2, + loss_type="l1", + predict_x0=True, + beta_schedule="cosine", ): super().__init__() assert isinstance(clip, CLIP) @@ -469,7 +470,18 @@ class DiffusionPrior(nn.Module): self.predict_x0 = predict_x0 # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. - betas = cosine_beta_schedule(timesteps) + if beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + elif beta_schedule == "linear": + betas = linear_beta_schedule(timesteps) + elif beta_schedule == "quadratic": + betas = quadratic_beta_schedule(timesteps) + elif beta_schedule == "jsd": + betas = 1.0 / torch.linspace(timesteps, 1, timesteps) + elif beta_schedule == "sigmoid": + betas = sigmoid_beta_schedule(timesteps) + else: + raise NotImplementedError() alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0)