Compare commits

...

2 Commits
1.4.6 ... 1.6.0

Author SHA1 Message Date
Phil Wang
7c5477b26d bet on the new self-conditioning technique out of geoffrey hintons group 2022-08-12 11:36:08 -07:00
Phil Wang
be3bb868bf add gradient checkpointing for all resnet blocks 2022-08-02 19:21:44 -07:00
3 changed files with 138 additions and 35 deletions

View File

@@ -1253,4 +1253,15 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```
```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import nn, einsum
import torchvision.transforms as T
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
return t
return (*t, *((fillvalue,) * remain_length))
# checkpointing helper function
def make_checkpointable(fn, **kwargs):
    # a ModuleList is handled by checkpointing each of its members in turn
    if isinstance(fn, nn.ModuleList):
        return [maybe(make_checkpointable)(layer, **kwargs) for layer in fn]

    # an optional predicate can exempt certain modules from checkpointing
    condition = kwargs.pop('condition', None)
    if exists(condition) and not condition(fn):
        return fn

    @wraps(fn)
    def inner(*args):
        # checkpoint only when at least one tensor input requires grad;
        # otherwise checkpointing is pointless (and would warn) at inference
        needs_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args)
        if not needs_grad:
            return fn(*args)
        return checkpoint(fn, *args)

    return inner
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -847,7 +870,7 @@ class Attention(nn.Module):
# attention
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
# aggregate values
@@ -1134,17 +1157,17 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_recon = pred
x_start = pred
else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
x_start = l2norm(x_start) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
@@ -1548,7 +1571,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1677,6 +1700,7 @@ class Unet(nn.Module):
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
self_cond = False,
sparse_attn = False,
cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,
@@ -1698,6 +1722,7 @@ class Unet(nn.Module):
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
checkpoint_during_training = False,
**kwargs
):
super().__init__()
@@ -1711,12 +1736,21 @@ class Unet(nn.Module):
self.lowres_cond = lowres_cond
# whether to do self conditioning
self.self_cond = self_cond
# determine dimensions
self.channels = channels
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
# initial number of channels depends on
# (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
# (2) self conditioning (bit diffusion paper)
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
@@ -1908,6 +1942,10 @@ class Unet(nn.Module):
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# whether to checkpoint during training
self.checkpoint_during_training = checkpoint_during_training
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
@@ -1965,7 +2003,9 @@ class Unet(nn.Module):
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
blur_kernel_size = None,
disable_checkpoint = False,
self_cond = None
):
batch_size, device = x.shape[0], x.device
@@ -1973,6 +2013,14 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
# concat self conditioning, if needed
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x, self_cond), dim = 1)
# concat low resolution conditioning
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -2087,17 +2135,29 @@ class Unet(nn.Module):
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# gradient checkpointing
can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
# make checkpointable modules
init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
# initial resnet block
if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
if exists(init_resnet_block):
x = init_resnet_block(x, t)
# go through the layers of the unet, down and up
down_hiddens = []
up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
if exists(pre_downsample):
x = pre_downsample(x)
@@ -2113,16 +2173,16 @@ class Unet(nn.Module):
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, t, mid_c)
x = mid_block1(x, t, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
if exists(mid_attn):
x = mid_attn(x)
x = self.mid_block2(x, t, mid_c)
x = mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups:
for init_block, resnet_blocks, attn, upsample in ups:
x = connect_skip(x)
x = init_block(x, t, c)
@@ -2139,7 +2199,7 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
x = final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -2530,23 +2590,23 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start:
x_recon = pred
x_start = pred
else:
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
x_recon = self.dynamic_threshold(x_recon)
x_start = self.dynamic_threshold(x_start)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network
@@ -2562,16 +2622,17 @@ class Decoder(nn.Module):
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad()
def p_sample_loop_ddpm(
@@ -2597,6 +2658,8 @@ class Decoder(nn.Module):
b = shape[0]
img = torch.randn(shape, device = device)
x_start = None # for self-conditioning
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2624,13 +2687,16 @@ class Decoder(nn.Module):
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
img = self.p_sample(
self_cond = x_start if unet.self_cond else None
img, x_start = self.p_sample(
unet,
img,
times,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
self_cond = self_cond,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
@@ -2689,6 +2755,8 @@ class Decoder(nn.Module):
img = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
@@ -2709,7 +2777,9 @@ class Decoder(nn.Module):
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
self_cond = x_start if unet.self_cond else None
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
@@ -2769,13 +2839,35 @@ class Decoder(nn.Module):
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet(
x_noisy,
times,
# unet kwargs
unet_kwargs = dict(
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
)
# self conditioning
self_cond = None
if unet.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = unet(x_noisy, times, **unet_kwargs)
if learned_variance:
self_cond, _ = self_cond.chunk(2, dim = 1)
self_cond = self_cond.detach()
# forward to get model prediction
model_output = unet(
x_noisy,
times,
**unet_kwargs,
self_cond = self_cond,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)
@@ -2806,7 +2898,7 @@ class Decoder(nn.Module):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper

View File

@@ -1 +1 @@
__version__ = '1.4.6'
__version__ = '1.6.0'