schedule to Prior

This commit is contained in:
Kashif Rasul
2022-04-17 15:21:47 +02:00
parent 51361c2d15
commit b0f2fbaa95

View File

@@ -450,10 +450,11 @@ class DiffusionPrior(nn.Module):
net,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
predict_x0 = True
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -469,7 +470,18 @@ class DiffusionPrior(nn.Module):
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
betas = cosine_beta_schedule(timesteps)
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)