From 42d6e4738711c6ac5c52484a17248d0a4dcf85f2 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sun, 17 Apr 2022 15:14:05 +0200
Subject: [PATCH] added huber loss and other schedulers

---
 dalle2_pytorch/dalle2_pytorch.py | 40 +++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 4b2a164..80a7039 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -98,6 +98,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
 
+
+def linear_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    return torch.linspace(beta_start, beta_end, timesteps)
+
+
+def quadratic_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
+
+
+def sigmoid_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    betas = torch.linspace(-6, 6, timesteps)
+    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
 # diffusion prior
 
 class RMSNorm(nn.Module):
@@ -601,6 +624,8 @@ class DiffusionPrior(nn.Module):
             loss = F.l1_loss(to_predict, x_recon)
         elif self.loss_type == 'l2':
             loss = F.mse_loss(to_predict, x_recon)
+        elif self.loss_type == "huber":
+            loss = F.smooth_l1_loss(to_predict, x_recon)
         else:
             raise NotImplementedError()
 
@@ -958,7 +983,18 @@ class Decoder(nn.Module):
         self.image_size = clip.image_size
         self.cond_drop_prob = cond_drop_prob
 
-        betas = cosine_beta_schedule(timesteps)
+        if beta_schedule == "cosine":
+            betas = cosine_beta_schedule(timesteps)
+        elif beta_schedule == "linear":
+            betas = linear_beta_schedule(timesteps)
+        elif beta_schedule == "quadratic":
+            betas = quadratic_beta_schedule(timesteps)
+        elif beta_schedule == "jsd":
+            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+        elif beta_schedule == "sigmoid":
+            betas = sigmoid_beta_schedule(timesteps)
+        else:
+            raise NotImplementedError()
 
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
@@ -1087,6 +1123,8 @@ class Decoder(nn.Module):
             loss = F.l1_loss(noise, x_recon)
         elif self.loss_type == 'l2':
             loss = F.mse_loss(noise, x_recon)
+        elif self.loss_type == "huber":
+            loss = F.smooth_l1_loss(noise, x_recon)
         else:
             raise NotImplementedError()
 