mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
bet on the new self-conditioning technique out of geoffrey hintons group
This commit is contained in:
@@ -870,7 +870,7 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
@@ -1157,17 +1157,17 @@ class DiffusionPrior(nn.Module):
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
x_start = pred
|
||||
else:
|
||||
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
x_start.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_recon = l2norm(x_recon) * self.image_embed_scale
|
||||
x_start = l2norm(x_start) * self.image_embed_scale
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1571,7 +1571,7 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
@@ -1700,6 +1700,7 @@ class Unet(nn.Module):
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
||||
self_cond = False,
|
||||
sparse_attn = False,
|
||||
cosine_sim_cross_attn = False,
|
||||
cosine_sim_self_attn = False,
|
||||
@@ -1735,12 +1736,21 @@ class Unet(nn.Module):
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
|
||||
# whether to do self conditioning
|
||||
|
||||
self.self_cond = self_cond
|
||||
|
||||
# determine dimensions
|
||||
|
||||
self.channels = channels
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
# initial number of channels depends on
|
||||
# (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
|
||||
# (2) self conditioning (bit diffusion paper)
|
||||
|
||||
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
|
||||
|
||||
init_dim = default(init_dim, dim)
|
||||
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
@@ -1994,7 +2004,8 @@ class Unet(nn.Module):
|
||||
text_cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None,
|
||||
disable_checkpoint = False
|
||||
disable_checkpoint = False,
|
||||
self_cond = None
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
|
||||
@@ -2002,6 +2013,14 @@ class Unet(nn.Module):
|
||||
|
||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||
|
||||
# concat self conditioning, if needed
|
||||
|
||||
if self.self_cond:
|
||||
self_cond = default(self_cond, lambda: torch.zeros_like(x))
|
||||
x = torch.cat((x, self_cond), dim = 1)
|
||||
|
||||
# concat low resolution conditioning
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
@@ -2571,23 +2590,23 @@ class Decoder(nn.Module):
|
||||
x = x.clamp(-s, s) / s
|
||||
return x
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
x_start = pred
|
||||
else:
|
||||
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
x_recon = self.dynamic_threshold(x_recon)
|
||||
x_start = self.dynamic_threshold(x_start)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
|
||||
if learned_variance:
|
||||
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||
@@ -2603,16 +2622,17 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
return model_mean, posterior_variance, posterior_log_variance, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
||||
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return pred, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(
|
||||
@@ -2638,6 +2658,8 @@ class Decoder(nn.Module):
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
is_inpaint = exists(inpaint_image)
|
||||
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||
|
||||
@@ -2665,13 +2687,16 @@ class Decoder(nn.Module):
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
img = self.p_sample(
|
||||
self_cond = x_start if unet.self_cond else None
|
||||
|
||||
img, x_start = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
self_cond = self_cond,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
predict_x_start = predict_x_start,
|
||||
@@ -2730,6 +2755,8 @@ class Decoder(nn.Module):
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if not is_latent_diffusion:
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
@@ -2750,7 +2777,9 @@ class Decoder(nn.Module):
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
self_cond = x_start if unet.self_cond else None
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
@@ -2810,13 +2839,35 @@ class Decoder(nn.Module):
|
||||
|
||||
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
# unet kwargs
|
||||
|
||||
unet_kwargs = dict(
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
)
|
||||
|
||||
# self conditioning
|
||||
|
||||
self_cond = None
|
||||
|
||||
if unet.self_cond and random.random() < 0.5:
|
||||
with torch.no_grad():
|
||||
self_cond = unet(x_noisy, times, **unet_kwargs)
|
||||
|
||||
if learned_variance:
|
||||
self_cond, _ = self_cond.chunk(2, dim = 1)
|
||||
|
||||
self_cond = self_cond.detach()
|
||||
|
||||
# forward to get model prediction
|
||||
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
**unet_kwargs,
|
||||
self_cond = self_cond,
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
@@ -2847,7 +2898,7 @@ class Decoder(nn.Module):
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
|
||||
Reference in New Issue
Block a user