Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
Merge pull request #14 from kashif/loss-schedule
added huber loss and other schedulers
@@ -98,6 +98,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
+
+
+def linear_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    return torch.linspace(beta_start, beta_end, timesteps)
+
+
+def quadratic_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
+
+
+def sigmoid_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    betas = torch.linspace(-6, 6, timesteps)
+    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
 
 # diffusion prior
 
 class RMSNorm(nn.Module):
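As a quick aside (my illustration, not part of the diff): the three new closed-form schedules can be compared by how much of the original signal survives at the final timestep. A minimal self-contained sketch, with the schedule bodies inlined from the hunk above:

import torch

# schedule bodies inlined from the hunk above so the snippet is self-contained

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps)

def quadratic_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start, beta_end = scale * 0.0001, scale * 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start, beta_end = scale * 0.0001, scale * 0.02
    return torch.sigmoid(torch.linspace(-6, 6, timesteps)) * (beta_end - beta_start) + beta_start

for name, fn in (('linear', linear_beta_schedule), ('quadratic', quadratic_beta_schedule), ('sigmoid', sigmoid_beta_schedule)):
    betas = fn(1000)
    # alpha_bar_T = prod(1 - beta_t): fraction of signal variance left at t = T
    alpha_bar_T = torch.cumprod(1. - betas, dim = 0)[-1].item()
    print(f'{name:9s} beta in [{betas[0].item():.1e}, {betas[-1].item():.1e}], alpha_bar_T = {alpha_bar_T:.2e}')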
@@ -427,10 +450,11 @@ class DiffusionPrior(nn.Module):
         net,
         *,
         clip,
-        timesteps = 1000,
-        cond_drop_prob = 0.2,
-        loss_type = 'l1',
-        predict_x0 = True
+        timesteps=1000,
+        cond_drop_prob=0.2,
+        loss_type="l1",
+        predict_x0=True,
+        beta_schedule="cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -446,7 +470,18 @@ class DiffusionPrior(nn.Module):
         self.predict_x0 = predict_x0
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
-        betas = cosine_beta_schedule(timesteps)
+        if beta_schedule == "cosine":
+            betas = cosine_beta_schedule(timesteps)
+        elif beta_schedule == "linear":
+            betas = linear_beta_schedule(timesteps)
+        elif beta_schedule == "quadratic":
+            betas = quadratic_beta_schedule(timesteps)
+        elif beta_schedule == "jsd":
+            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+        elif beta_schedule == "sigmoid":
+            betas = sigmoid_beta_schedule(timesteps)
+        else:
+            raise NotImplementedError()
 
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
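For context (a hedged sketch, not part of the commit): with this change the noise schedule and loss become plain constructor arguments. A usage example, assuming the CLIP and DiffusionPriorNetwork setup from the repository README of this period; all sizes below are illustrative placeholders:

from dalle2_pytorch import CLIP, DiffusionPriorNetwork, DiffusionPrior

# toy CLIP; hyperparameters are illustrative, not prescribed by this commit
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

# the two new kwargs from this commit: any of the registered schedules
# ("cosine", "linear", "quadratic", "jsd", "sigmoid") and "huber" as a loss
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    cond_drop_prob = 0.2,
    beta_schedule = "linear",
    loss_type = "huber"
)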
@@ -601,6 +636,8 @@ class DiffusionPrior(nn.Module):
             loss = F.l1_loss(to_predict, x_recon)
         elif self.loss_type == 'l2':
             loss = F.mse_loss(to_predict, x_recon)
+        elif self.loss_type == "huber":
+            loss = F.smooth_l1_loss(to_predict, x_recon)
         else:
             raise NotImplementedError()
 
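A side note on the new loss (my gloss): F.smooth_l1_loss with its default beta = 1.0 is the Huber loss, quadratic for small residuals like l2 and linear for large ones like l1, so it is less sensitive to outliers than pure mse. A minimal check of the equivalence:

import torch
import torch.nn.functional as F

pred, target = torch.randn(4, 512), torch.randn(4, 512)

huber = F.smooth_l1_loss(pred, target)  # reduction = 'mean', beta = 1.0 by default

# hand-rolled smooth l1 / Huber with the same beta, for comparison
beta = 1.0
diff = (pred - target).abs()
manual = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta).mean()

assert torch.allclose(huber, manual)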
@@ -944,9 +981,10 @@ class Decoder(nn.Module):
         net,
         *,
         clip,
-        timesteps = 1000,
-        cond_drop_prob = 0.2,
-        loss_type = 'l1'
+        timesteps=1000,
+        cond_drop_prob=0.2,
+        loss_type="l1",
+        beta_schedule="cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
||||||
@@ -958,7 +996,18 @@ class Decoder(nn.Module):
|
|||||||
self.image_size = clip.image_size
|
self.image_size = clip.image_size
|
||||||
self.cond_drop_prob = cond_drop_prob
|
self.cond_drop_prob = cond_drop_prob
|
||||||
|
|
||||||
|
if beta_schedule == "cosine":
|
||||||
betas = cosine_beta_schedule(timesteps)
|
betas = cosine_beta_schedule(timesteps)
|
||||||
|
elif beta_schedule == "linear":
|
||||||
|
betas = linear_beta_schedule(timesteps)
|
||||||
|
elif beta_schedule == "quadratic":
|
||||||
|
betas = quadratic_beta_schedule(timesteps)
|
||||||
|
elif beta_schedule == "jsd":
|
||||||
|
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
|
||||||
|
elif beta_schedule == "sigmoid":
|
||||||
|
betas = sigmoid_beta_schedule(timesteps)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
alphas = 1. - betas
|
alphas = 1. - betas
|
||||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||||
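"jsd" is the one schedule with no helper function: 1.0 / torch.linspace(timesteps, 1, timesteps) gives beta_t = 1 / (T - t + 1), i.e. 1/T, 1/(T-1), ..., 1/2, 1 — as far as I can tell, the schedule from Sohl-Dickstein et al.'s original diffusion paper, and the naming matches the reference DDPM codebase. Note beta_T = 1, so the last step destroys all remaining signal. A tiny check:

import torch

T = 10
betas = 1.0 / torch.linspace(T, 1, T)
print(betas)
# tensor([0.1000, 0.1111, 0.1250, 0.1429, 0.1667, 0.2000, 0.2500, 0.3333, 0.5000, 1.0000])
# i.e. beta_t = 1 / (T - t + 1) for t = 1..T; the final beta is exactly 1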
@@ -1087,6 +1136,8 @@ class Decoder(nn.Module):
             loss = F.l1_loss(noise, x_recon)
         elif self.loss_type == 'l2':
             loss = F.mse_loss(noise, x_recon)
+        elif self.loss_type == "huber":
+            loss = F.smooth_l1_loss(noise, x_recon)
         else:
             raise NotImplementedError()
 
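The same two knobs land on the Decoder. A matching sketch (again a hedged example assuming the Unet and Decoder constructors from the README of this period; clip is the instance from the prior example, and all sizes are illustrative):

from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# same new kwargs as on DiffusionPrior
decoder = Decoder(
    net = unet,
    clip = clip,
    timesteps = 1000,
    cond_drop_prob = 0.2,
    beta_schedule = "sigmoid",
    loss_type = "huber"
)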