Add Huber (smooth L1) loss option and additional beta schedulers (linear, quadratic, sigmoid, jsd)

This commit is contained in:
Kashif Rasul
2022-04-17 15:14:05 +02:00
parent 1e939153fb
commit 42d6e47387

View File

@@ -98,6 +98,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999) return torch.clip(betas, 0, 0.999)
def linear_beta_schedule(timesteps):
    """Linear beta schedule from the DDPM paper (Ho et al. 2020).

    The endpoints are rescaled by ``1000 / timesteps`` so that any number of
    timesteps reproduces the cumulative noise level of the reference
    1000-step run.

    Args:
        timesteps: number of diffusion steps.

    Returns:
        1-D tensor of ``timesteps`` betas, linearly spaced.
    """
    rescale = 1000 / timesteps
    start, end = rescale * 1e-4, rescale * 2e-2
    return torch.linspace(start, end, timesteps)
def quadratic_beta_schedule(timesteps):
    """Quadratic beta schedule: interpolate linearly in sqrt(beta) space and
    square, so betas grow quadratically from ``beta_start`` to ``beta_end``.

    The endpoints are rescaled by ``1000 / timesteps`` to match the
    reference 1000-step run (same convention as the linear schedule).

    Bug fix: the previous version squared the endpoints AND squared the
    resulting linspace (``linspace(beta_start**2, beta_end**2, T) ** 2``),
    producing betas spanning ``beta_start**4 .. beta_end**4`` instead of
    ``beta_start .. beta_end``. Interpolating the square roots and squaring
    once restores the intended endpoints.

    Args:
        timesteps: number of diffusion steps.

    Returns:
        1-D tensor of ``timesteps`` betas, quadratically spaced from
        ``beta_start`` to ``beta_end``.
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
    """Sigmoid beta schedule: betas follow a logistic curve between
    ``beta_start`` and ``beta_end``.

    A sigmoid evaluated over [-6, 6] (where it is ~0 and ~1 respectively)
    is affinely mapped onto the beta range, so the schedule starts and
    ends nearly flat with the fastest growth in the middle. Endpoints are
    rescaled by ``1000 / timesteps`` as in the other schedules.

    Args:
        timesteps: number of diffusion steps.

    Returns:
        1-D tensor of ``timesteps`` betas on a sigmoid curve.
    """
    rescale = 1000 / timesteps
    lo = rescale * 0.0001
    hi = rescale * 0.02
    grid = torch.linspace(-6, 6, timesteps)
    return grid.sigmoid() * (hi - lo) + lo
# diffusion prior # diffusion prior
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
@@ -601,6 +624,8 @@ class DiffusionPrior(nn.Module):
loss = F.l1_loss(to_predict, x_recon) loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2': elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon) loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else: else:
raise NotImplementedError() raise NotImplementedError()
@@ -958,7 +983,18 @@ class Decoder(nn.Module):
self.image_size = clip.image_size self.image_size = clip.image_size
self.cond_drop_prob = cond_drop_prob self.cond_drop_prob = cond_drop_prob
betas = cosine_beta_schedule(timesteps) if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod = torch.cumprod(alphas, axis=0)
@@ -1087,6 +1123,8 @@ class Decoder(nn.Module):
loss = F.l1_loss(noise, x_recon) loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2': elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon) loss = F.mse_loss(noise, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(noise, x_recon)
else: else:
raise NotImplementedError() raise NotImplementedError()