mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 03:24:22 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9440411954 | ||
|
|
981d407792 | ||
|
|
7c5477b26d |
11
README.md
11
README.md
@@ -1253,4 +1253,15 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2022analog,
|
||||
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
|
||||
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
|
||||
year = {2022},
|
||||
eprint = {2208.04202},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -870,7 +870,7 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
@@ -937,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
num_image_embeds = 1,
|
||||
num_text_embeds = 1,
|
||||
max_text_len = 256,
|
||||
self_cond = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.num_time_embeds = num_time_embeds
|
||||
self.num_image_embeds = num_image_embeds
|
||||
self.num_text_embeds = num_text_embeds
|
||||
@@ -967,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||
|
||||
# whether to use self conditioning, Hinton's group's new ddpm technique
|
||||
|
||||
self.self_cond = self_cond
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
*args,
|
||||
@@ -988,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
*,
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
self_cond = None,
|
||||
cond_drop_prob = 0.
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
|
||||
|
||||
# setup self conditioning
|
||||
|
||||
self_cond = None
|
||||
if self.self_cond:
|
||||
self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
@@ -1043,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
if self.self_cond:
|
||||
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
|
||||
|
||||
tokens = torch.cat((
|
||||
text_encodings,
|
||||
text_embed,
|
||||
@@ -1151,45 +1168,50 @@ class DiffusionPrior(nn.Module):
|
||||
def l2norm_clamp_embed(self, image_embed):
|
||||
return l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
x_start = pred
|
||||
else:
|
||||
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
x_start.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_recon = l2norm(x_recon) * self.image_embed_scale
|
||||
x_start = l2norm(x_start) * self.image_embed_scale
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return pred, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
|
||||
batch, device = shape[0], self.device
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
|
||||
times = torch.full((batch,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
self_cond = x_start if self.net.self_cond else None
|
||||
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
|
||||
|
||||
if self.sampling_final_clamp_l2norm and self.predict_x_start:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
@@ -1207,6 +1229,8 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
|
||||
self_cond = x_start if self.net.self_cond else None
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
@@ -1260,9 +1286,15 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
|
||||
self_cond = None
|
||||
if self.net.self_cond and random.random() < 0.5:
|
||||
with torch.no_grad():
|
||||
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
||||
|
||||
pred = self.net(
|
||||
image_embed_noisy,
|
||||
times,
|
||||
self_cond = self_cond,
|
||||
cond_drop_prob = self.cond_drop_prob,
|
||||
**text_cond
|
||||
)
|
||||
@@ -1571,7 +1603,7 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
@@ -1700,6 +1732,7 @@ class Unet(nn.Module):
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
||||
self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
|
||||
sparse_attn = False,
|
||||
cosine_sim_cross_attn = False,
|
||||
cosine_sim_self_attn = False,
|
||||
@@ -1735,12 +1768,21 @@ class Unet(nn.Module):
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
|
||||
# whether to do self conditioning
|
||||
|
||||
self.self_cond = self_cond
|
||||
|
||||
# determine dimensions
|
||||
|
||||
self.channels = channels
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
# initial number of channels depends on
|
||||
# (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
|
||||
# (2) self conditioning (bit diffusion paper)
|
||||
|
||||
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
|
||||
|
||||
init_dim = default(init_dim, dim)
|
||||
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
@@ -1994,7 +2036,8 @@ class Unet(nn.Module):
|
||||
text_cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None,
|
||||
disable_checkpoint = False
|
||||
disable_checkpoint = False,
|
||||
self_cond = None
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
|
||||
@@ -2002,6 +2045,14 @@ class Unet(nn.Module):
|
||||
|
||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||
|
||||
# concat self conditioning, if needed
|
||||
|
||||
if self.self_cond:
|
||||
self_cond = default(self_cond, lambda: torch.zeros_like(x))
|
||||
x = torch.cat((x, self_cond), dim = 1)
|
||||
|
||||
# concat low resolution conditioning
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
@@ -2571,23 +2622,23 @@ class Decoder(nn.Module):
|
||||
x = x.clamp(-s, s) / s
|
||||
return x
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
x_start = pred
|
||||
else:
|
||||
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
x_recon = self.dynamic_threshold(x_recon)
|
||||
x_start = self.dynamic_threshold(x_start)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
|
||||
if learned_variance:
|
||||
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||
@@ -2603,16 +2654,17 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
return model_mean, posterior_variance, posterior_log_variance, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
||||
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return pred, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(
|
||||
@@ -2638,6 +2690,8 @@ class Decoder(nn.Module):
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
is_inpaint = exists(inpaint_image)
|
||||
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||
|
||||
@@ -2665,13 +2719,16 @@ class Decoder(nn.Module):
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
img = self.p_sample(
|
||||
self_cond = x_start if unet.self_cond else None
|
||||
|
||||
img, x_start = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
self_cond = self_cond,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
predict_x_start = predict_x_start,
|
||||
@@ -2730,6 +2787,8 @@ class Decoder(nn.Module):
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if not is_latent_diffusion:
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
@@ -2750,7 +2809,9 @@ class Decoder(nn.Module):
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
self_cond = x_start if unet.self_cond else None
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
@@ -2810,13 +2871,35 @@ class Decoder(nn.Module):
|
||||
|
||||
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
# unet kwargs
|
||||
|
||||
unet_kwargs = dict(
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
)
|
||||
|
||||
# self conditioning
|
||||
|
||||
self_cond = None
|
||||
|
||||
if unet.self_cond and random.random() < 0.5:
|
||||
with torch.no_grad():
|
||||
self_cond = unet(x_noisy, times, **unet_kwargs)
|
||||
|
||||
if learned_variance:
|
||||
self_cond, _ = self_cond.chunk(2, dim = 1)
|
||||
|
||||
self_cond = self_cond.detach()
|
||||
|
||||
# forward to get model prediction
|
||||
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
**unet_kwargs,
|
||||
self_cond = self_cond,
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
@@ -2847,7 +2930,7 @@ class Decoder(nn.Module):
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.5.0'
|
||||
__version__ = '1.6.1'
|
||||
|
||||
Reference in New Issue
Block a user