From 42d6e4738711c6ac5c52484a17248d0a4dcf85f2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 17 Apr 2022 15:14:05 +0200 Subject: [PATCH 1/3] added huber loss and other schedulers --- dalle2_pytorch/dalle2_pytorch.py | 40 +++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4b2a164..80a7039 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -98,6 +98,29 @@ def cosine_beta_schedule(timesteps, s = 0.008): betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) + +def linear_beta_schedule(timesteps): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps) + + +def quadratic_beta_schedule(timesteps): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2 + + +def sigmoid_beta_schedule(timesteps): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + betas = torch.linspace(-6, 6, timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + # diffusion prior class RMSNorm(nn.Module): @@ -601,6 +624,8 @@ class DiffusionPrior(nn.Module): loss = F.l1_loss(to_predict, x_recon) elif self.loss_type == 'l2': loss = F.mse_loss(to_predict, x_recon) + elif self.loss_type == "huber": + loss = F.smooth_l1_loss(to_predict, x_recon) else: raise NotImplementedError() @@ -958,7 +983,18 @@ class Decoder(nn.Module): self.image_size = clip.image_size self.cond_drop_prob = cond_drop_prob - betas = cosine_beta_schedule(timesteps) + if beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + elif beta_schedule == "linear": + betas = linear_beta_schedule(timesteps) + elif beta_schedule == "quadratic": + betas = quadratic_beta_schedule(timesteps) + elif beta_schedule == "jsd": + betas = 1.0 / torch.linspace(timesteps, 1, timesteps) + elif beta_schedule == "sigmoid": + betas = sigmoid_beta_schedule(timesteps) + else: + raise NotImplementedError() alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) @@ -1087,6 +1123,8 @@ class Decoder(nn.Module): loss = F.l1_loss(noise, x_recon) elif self.loss_type == 'l2': loss = F.mse_loss(noise, x_recon) + elif self.loss_type == "huber": + loss = F.smooth_l1_loss(noise, x_recon) else: raise NotImplementedError() From 51361c2d15478ab8b6b9a240e1940fda74e0914a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 17 Apr 2022 15:19:33 +0200 Subject: [PATCH 2/3] added beta_schedule argument --- dalle2_pytorch/dalle2_pytorch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 80a7039..686db85 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -969,9 +969,10 @@ class Decoder(nn.Module): net, *, clip, - timesteps = 1000, - cond_drop_prob = 0.2, - loss_type = 'l1' + timesteps=1000, + cond_drop_prob=0.2, + loss_type="l1", + beta_schedule="cosine", ): super().__init__() assert isinstance(clip, CLIP) From b0f2fbaa9521f0e2f35bfdf20df6092006ebba1b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 17 Apr 2022 15:21:47 +0200 Subject: [PATCH 3/3] schedule to Prior --- dalle2_pytorch/dalle2_pytorch.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 686db85..84a7528 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -450,10 +450,11 @@ class DiffusionPrior(nn.Module): net, *, clip, - timesteps = 1000, - cond_drop_prob = 0.2, - loss_type = 'l1', - predict_x0 = True + timesteps=1000, + cond_drop_prob=0.2, + loss_type="l1", + predict_x0=True, + beta_schedule="cosine", ): super().__init__() assert isinstance(clip, CLIP) @@ -469,7 +470,18 @@ class DiffusionPrior(nn.Module): self.predict_x0 = predict_x0 # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. - betas = cosine_beta_schedule(timesteps) + if beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + elif beta_schedule == "linear": + betas = linear_beta_schedule(timesteps) + elif beta_schedule == "quadratic": + betas = quadratic_beta_schedule(timesteps) + elif beta_schedule == "jsd": + betas = 1.0 / torch.linspace(timesteps, 1, timesteps) + elif beta_schedule == "sigmoid": + betas = sigmoid_beta_schedule(timesteps) + else: + raise NotImplementedError() alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0)