mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-06 16:54:20 +01:00
schedule to Prior
This commit is contained in:
@@ -450,10 +450,11 @@ class DiffusionPrior(nn.Module):
|
||||
net,
|
||||
*,
|
||||
clip,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
predict_x0 = True
|
||||
timesteps=1000,
|
||||
cond_drop_prob=0.2,
|
||||
loss_type="l1",
|
||||
predict_x0=True,
|
||||
beta_schedule="cosine",
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -469,7 +470,18 @@ class DiffusionPrior(nn.Module):
|
||||
self.predict_x0 = predict_x0
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "quadratic":
|
||||
betas = quadratic_beta_schedule(timesteps)
|
||||
elif beta_schedule == "jsd":
|
||||
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
|
||||
elif beta_schedule == "sigmoid":
|
||||
betas = sigmoid_beta_schedule(timesteps)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
|
||||
Reference in New Issue
Block a user