allow for setting beta schedules of unets differently in the decoder, as what was used in the paper was cosine, cosine, linear

This commit is contained in:
Phil Wang
2022-06-20 08:56:32 -07:00
parent f5a906f5d3
commit 138079ca83
4 changed files with 87 additions and 55 deletions

View File

@@ -378,7 +378,7 @@ def sigmoid_beta_schedule(timesteps):
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
class BaseGaussianDiffusion(nn.Module):
class NoiseScheduler(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
super().__init__()
@@ -472,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def sample(self, *args, **kwargs):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)
# diffusion prior
@@ -687,8 +686,7 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
# aggregate values
@@ -862,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
return pred_image_embed
class DiffusionPrior(BaseGaussianDiffusion):
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
@@ -883,7 +881,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
super().__init__()
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
@@ -923,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -933,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
@@ -941,7 +948,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
@@ -955,7 +962,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
@@ -963,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
@@ -972,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
pred = self.net(
image_embed_noisy,
@@ -986,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -997,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img
@@ -1069,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
# scale image embed (Katherine)
@@ -1234,8 +1241,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1739,7 +1745,7 @@ class LowresConditioner(nn.Module):
return cond_fmap
class Decoder(BaseGaussianDiffusion):
class Decoder(nn.Module):
def __init__(
self,
unet,
@@ -1752,7 +1758,7 @@ class Decoder(BaseGaussianDiffusion):
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
beta_schedule = 'cosine',
beta_schedule = None,
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
@@ -1774,13 +1780,7 @@ class Decoder(BaseGaussianDiffusion):
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
):
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
super().__init__()
self.unconditional = unconditional
@@ -1824,6 +1824,8 @@ class Decoder(BaseGaussianDiffusion):
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
num_unets = len(unets)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1859,6 +1861,24 @@ class Decoder(BaseGaussianDiffusion):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# create noise schedulers per unet
if not exists(beta_schedule):
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule in beta_schedule:
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
self.noise_schedulers.append(noise_scheduler)
# unet image sizes
image_sizes = default(image_sizes, (image_size,))
@@ -1908,6 +1928,14 @@ class Decoder(BaseGaussianDiffusion):
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# device tracker
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1931,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
for unet, device in zip(self.unets, devices):
unet.to(device)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
@@ -1942,7 +1970,7 @@ class Decoder(BaseGaussianDiffusion):
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
@@ -1961,14 +1989,14 @@ class Decoder(BaseGaussianDiffusion):
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterio variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
@@ -1980,17 +2008,17 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.betas.device
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
img = torch.randn(shape, device = device)
@@ -1998,7 +2026,7 @@ class Decoder(BaseGaussianDiffusion):
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
img = self.p_sample(
unet,
img,
@@ -2009,6 +2037,7 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
@@ -2016,7 +2045,7 @@ class Decoder(BaseGaussianDiffusion):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2027,7 +2056,7 @@ class Decoder(BaseGaussianDiffusion):
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet(
x_noisy,
@@ -2047,11 +2076,10 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target, reduction = 'none')
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
if self.has_p2_loss_reweighting:
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
loss = noise_scheduler.p2_reweigh_loss(loss, times)
loss = loss.mean()
@@ -2066,8 +2094,8 @@ class Decoder(BaseGaussianDiffusion):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
@@ -2117,7 +2145,7 @@ class Decoder(BaseGaussianDiffusion):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2145,7 +2173,8 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
)
img = vae.decode(img)
@@ -2171,6 +2200,7 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
vae = self.vaes[unet_index]
noise_scheduler = self.noise_schedulers[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
@@ -2180,7 +2210,7 @@ class Decoder(BaseGaussianDiffusion):
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
@@ -2211,7 +2241,7 @@ class Decoder(BaseGaussianDiffusion):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
# main class