make sure learned variance can work for any number of unets in the decoder, defaults to first unet, as suggested was used in the paper

This commit is contained in:
Phil Wang
2022-05-12 14:18:15 -07:00
parent 28b58e568c
commit 2277b47ffd
2 changed files with 130 additions and 19 deletions

View File

@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
from x_clip import CLIP
from coca_pytorch import CoCa
# constants
NAT = 1. / math.log(2.)
# helper functions
def exists(val):
@@ -91,6 +95,9 @@ def freeze_model_and_make_eval_(model):
# tensor helpers
def log(t, eps = 1e-12):
return torch.log(t.clamp(min = eps))
def l2norm(t):
return F.normalize(t, dim = -1)
@@ -297,6 +304,36 @@ def noise_like(shape, device, repeat=False):
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
def approx_standard_normal_cdf(x):
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
return log_probs
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
@@ -1266,6 +1303,7 @@ class Unet(nn.Module):
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1296,6 +1334,7 @@ class Unet(nn.Module):
# determine dimensions
self.channels = channels
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 3 * 2)
@@ -1401,11 +1440,9 @@ class Unet(nn.Module):
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, out_dim, 1)
nn.Conv2d(dim, self.channels_out, 1)
)
# if the current settings for the unet are not correct
@@ -1415,13 +1452,25 @@ class Unet(nn.Module):
*,
lowres_cond,
channels,
channels_out,
cond_on_image_embeds,
cond_on_text_encodings
):
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
if lowres_cond == self.lowres_cond and \
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
channels_out == self.channels_out:
return self
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
updated_kwargs = dict(
lowres_cond = lowres_cond,
channels = channels,
channels_out = channels_out,
cond_on_image_embeds = cond_on_image_embeds,
cond_on_text_encodings = cond_on_text_encodings
)
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
@@ -1621,6 +1670,8 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = True,
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
vb_loss_weight = 0.001,
unconditional = False
):
super().__init__(
@@ -1659,10 +1710,18 @@ class Decoder(BaseGaussianDiffusion):
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -1670,12 +1729,14 @@ class Decoder(BaseGaussianDiffusion):
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first and not unconditional,
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
channels = unet_channels
channels = unet_channels,
channels_out = unet_channels_out
)
self.unets.append(one_unet)
@@ -1738,8 +1799,11 @@ class Decoder(BaseGaussianDiffusion):
yield
unet.cpu()
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start:
x_recon = pred
@@ -1750,19 +1814,31 @@ class Decoder(BaseGaussianDiffusion):
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterio variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
@torch.inference_mode()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.inference_mode()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
@@ -1779,17 +1855,18 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
return img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
pred = unet(
model_output = unet(
x_noisy,
times,
image_embed = image_embed,
@@ -1800,10 +1877,43 @@ class Decoder(BaseGaussianDiffusion):
text_cond_drop_prob = self.text_cond_drop_prob,
)
if learned_variance:
pred, _ = model_output.chunk(2, dim = 1)
else:
pred = model_output
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target)
return loss
if not learned_variance:
# return simple loss if not using learned variance
return loss
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
detached_model_mean = model_mean.detach()
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
kl = meanflat(kl) * NAT
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
decoder_nll = meanflat(decoder_nll) * NAT
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
vb_losses = torch.where(times == 0, decoder_nll, kl)
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
vb_loss = vb_losses.mean() * self.vb_loss_weight
return loss + vb_loss
@torch.inference_mode()
@eval_decorator
@@ -1830,7 +1940,7 @@ class Decoder(BaseGaussianDiffusion):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
@@ -1856,6 +1966,7 @@ class Decoder(BaseGaussianDiffusion):
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img
)
@@ -1885,6 +1996,7 @@ class Decoder(BaseGaussianDiffusion):
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
@@ -1922,7 +2034,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
# main class
@@ -1972,4 +2084,3 @@ class DALLE2(nn.Module):
return images[0]
return images