Compare commits

...

24 Commits

Author SHA1 Message Date
Phil Wang
f1739267e4 simplify more 2022-05-14 17:13:13 -07:00
Phil Wang
b0cd5f24b6 take care of gradient accumulation automatically for researchers, by passing in a max_batch_size on the decoder or diffusion prior trainer forward 2022-05-14 17:04:09 -07:00
Phil Wang
b494ed81d4 take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block) 2022-05-14 15:49:24 -07:00
Phil Wang
ff3474f05c normalize conditioning tokens outside of cross attention blocks 2022-05-14 14:23:52 -07:00
Phil Wang
d5293f19f1 lineup with paper 2022-05-14 13:57:00 -07:00
Phil Wang
e697183849 be able to customize adam eps 2022-05-14 13:55:04 -07:00
Phil Wang
591d37e266 lower default initial learning rate to what Jonathan Ho had in his original repo 2022-05-14 13:22:43 -07:00
Phil Wang
d1f02e8f49 always use sandwich norm for attention layer 2022-05-14 12:13:41 -07:00
Phil Wang
9faab59b23 use post-attn-branch layernorm in attempt to stabilize cross attention conditioning in decoder 2022-05-14 11:58:09 -07:00
Phil Wang
5d27029e98 make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm 2022-05-14 01:23:54 -07:00
Phil Wang
3115fa17b3 fix everything around normalizing images to -1 to 1 for ddpm training automatically 2022-05-14 01:17:11 -07:00
Phil Wang
124d8577c8 move the inverse normalization function called before image embeddings are derived from clip to within the diffusion prior and decoder classes 2022-05-14 00:37:52 -07:00
Phil Wang
2db0c9794c comments 2022-05-12 14:25:20 -07:00
Phil Wang
2277b47ffd make sure learned variance can work for any number of unets in the decoder, defaults to first unet, as suggested was used in the paper 2022-05-12 14:18:15 -07:00
Phil Wang
28b58e568c cleanup in preparation of option for learned variance 2022-05-12 12:04:52 -07:00
Phil Wang
924455d97d align the ema model device back after sampling from the cascading ddpm in the decoder 2022-05-11 19:56:54 -07:00
Phil Wang
6021945fc8 default to l2 loss 2022-05-11 19:24:51 -07:00
Light-V
6f76652d11 fix typo in README.md (#85)
The default config for clip from openai should be ViT-B/32
2022-05-11 13:38:16 -07:00
Phil Wang
3dda2570ed fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82 2022-05-11 08:21:39 -07:00
Phil Wang
2f3c02dba8 numerical accuracy for noise schedule parameters 2022-05-10 15:28:46 -07:00
Phil Wang
908088cfea wrap up cross embed layer feature 2022-05-10 12:19:34 -07:00
Phil Wang
8dc8a3de0d product management 2022-05-10 11:51:38 -07:00
Phil Wang
35f89556ba bring in the cross embed layer from Crossformer paper for initial convolution in unet 2022-05-10 11:50:38 -07:00
Phil Wang
2b55f753b9 fix new issue with github actions and auto pypi package uploading 2022-05-10 10:51:15 -07:00
5 changed files with 357 additions and 84 deletions

View File

@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()
@@ -732,8 +732,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet)
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
loss = decoder_trainer(
images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
@@ -1002,12 +1005,13 @@ Once built, images will be saved to the same directory the command is invoked
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
@@ -1016,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
## Citations
@@ -1093,4 +1098,15 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,7 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
from x_clip import CLIP
from coca_pytorch import CoCa
# constants
NAT = 1. / math.log(2.)
# helper functions
def exists(val):
@@ -41,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
@@ -91,6 +103,9 @@ def freeze_model_and_make_eval_(model):
# tensor helpers
def log(t, eps = 1e-12):
return torch.log(t.clamp(min = eps))
def l2norm(t):
return F.normalize(t, dim = -1)
@@ -107,10 +122,10 @@ def resize_image_to(image, target_image_size):
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_img(img):
def normalize_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_img(normed_img):
def unnormalize_zero_to_one(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters
@@ -271,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -297,13 +312,43 @@ def noise_like(shape, device, repeat=False):
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
def approx_standard_normal_cdf(x):
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
return log_probs
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
@@ -314,21 +359,21 @@ def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps)
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps)
betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
@@ -368,17 +413,21 @@ class BaseGaussianDiffusion(nn.Module):
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# register buffer helper function to cast double back to float
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
@@ -386,19 +435,13 @@ class BaseGaussianDiffusion(nn.Module):
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
@@ -571,7 +614,6 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False,
rotary_emb = None
):
super().__init__()
@@ -581,7 +623,6 @@ class Attention(nn.Module):
self.causal = causal
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -592,7 +633,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
LayerNorm(dim)
)
def forward(self, x, mask = None, attn_bias = None):
@@ -649,8 +690,7 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.post_norm(out)
return self.to_out(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -676,7 +716,7 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -827,7 +867,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_channels = 3,
timesteps = 1000,
cond_drop_prob = 0.,
loss_type = "l1",
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
@@ -1123,6 +1163,7 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -1132,13 +1173,17 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
@@ -1228,6 +1273,33 @@ class LinearAttention(nn.Module):
out = self.nonlin(out)
return self.to_out(out)
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
kernel_sizes,
dim_out = None,
stride = 2
):
super().__init__()
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
dim_out = default(dim_out, dim_in)
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# calculate the dimension at each scale
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class Unet(nn.Module):
def __init__(
self,
@@ -1241,6 +1313,7 @@ class Unet(nn.Module):
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1252,6 +1325,9 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
**kwargs
):
super().__init__()
@@ -1268,12 +1344,12 @@ class Unet(nn.Module):
# determine dimensions
self.channels = channels
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2)
init_dim = default(init_dim, dim // 3 * 2)
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -1303,6 +1379,9 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
@@ -1333,6 +1412,12 @@ class Unet(nn.Module):
assert len(resnet_groups) == len(in_out)
# downsample klass
downsample_klass = Downsample
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# layers
self.downs = nn.ModuleList([])
@@ -1348,7 +1433,7 @@ class Unet(nn.Module):
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Downsample(dim_out) if not is_last else nn.Identity()
downsample_klass(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
@@ -1368,11 +1453,9 @@ class Unet(nn.Module):
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, out_dim, 1)
nn.Conv2d(dim, self.channels_out, 1)
)
# if the current settings for the unet are not correct
@@ -1382,13 +1465,25 @@ class Unet(nn.Module):
*,
lowres_cond,
channels,
channels_out,
cond_on_image_embeds,
cond_on_text_encodings
):
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
if lowres_cond == self.lowres_cond and \
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
channels_out == self.channels_out:
return self
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
updated_kwargs = dict(
lowres_cond = lowres_cond,
channels = channels,
channels_out = channels_out,
cond_on_image_embeds = cond_on_image_embeds,
cond_on_text_encodings = cond_on_text_encodings
)
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
@@ -1453,11 +1548,12 @@ class Unet(nn.Module):
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where(
image_keep_mask,
image_tokens,
self.null_image_embed
null_image_embed
)
# take care of text encodings (optional)
@@ -1481,10 +1577,12 @@ class Unet(nn.Module):
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
text_tokens = torch.where(
text_keep_mask,
text_tokens,
self.null_text_embed
null_text_embed
)
# main conditioning tokens (c)
@@ -1499,6 +1597,11 @@ class Unet(nn.Module):
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# go through the layers of the unet, down and up
hiddens = []
@@ -1572,7 +1675,7 @@ class Decoder(BaseGaussianDiffusion):
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1',
loss_type = 'l2',
beta_schedule = 'cosine',
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
@@ -1585,6 +1688,8 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = True,
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
vb_loss_weight = 0.001,
unconditional = False
):
super().__init__(
@@ -1623,10 +1728,18 @@ class Decoder(BaseGaussianDiffusion):
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -1634,12 +1747,14 @@ class Decoder(BaseGaussianDiffusion):
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first and not unconditional,
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
channels = unet_channels
channels = unet_channels,
channels_out = unet_channels_out
)
self.unets.append(one_unet)
@@ -1702,8 +1817,11 @@ class Decoder(BaseGaussianDiffusion):
yield
unet.cpu()
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start:
x_recon = pred
@@ -1714,24 +1832,38 @@ class Decoder(BaseGaussianDiffusion):
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterio variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
@torch.inference_mode()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.inference_mode()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
@@ -1743,17 +1875,26 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
return img
unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
pred = unet(
model_output = unet(
x_noisy,
times,
image_embed = image_embed,
@@ -1764,10 +1905,48 @@ class Decoder(BaseGaussianDiffusion):
text_cond_drop_prob = self.text_cond_drop_prob,
)
if learned_variance:
pred, _ = model_output.chunk(2, dim = 1)
else:
pred = model_output
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target)
return loss
if not learned_variance:
# return simple loss if not using learned variance
return loss
# most of the code below is transcribed from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
detached_model_mean = model_mean.detach()
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
kl = meanflat(kl) * NAT
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
decoder_nll = meanflat(decoder_nll) * NAT
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
vb_losses = torch.where(times == 0, decoder_nll, kl)
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
vb_loss = vb_losses.mean() * self.vb_loss_weight
return loss + vb_loss
@torch.inference_mode()
@eval_decorator
@@ -1794,7 +1973,7 @@ class Decoder(BaseGaussianDiffusion):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
@@ -1820,6 +1999,7 @@ class Decoder(BaseGaussianDiffusion):
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img
)
@@ -1849,6 +2029,7 @@ class Decoder(BaseGaussianDiffusion):
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
@@ -1886,7 +2067,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
# main class
@@ -1936,4 +2117,3 @@ class DALLE2(nn.Module):
return images[0]
return images

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -1,6 +1,8 @@
import time
import copy
from math import ceil
from functools import partial
from collections.abc import Iterable
import torch
from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,47 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
batch_size = len(x)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
all_args = (x, *args, *kwargs.values())
len_all_args = len(all_args)
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
@@ -90,7 +136,7 @@ class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
@@ -105,6 +151,10 @@ class EMA(nn.Module):
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def update(self):
self.step += 1
@@ -143,6 +193,7 @@ class DiffusionPriorTrainer(nn.Module):
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
**kwargs
@@ -169,6 +220,7 @@ class DiffusionPriorTrainer(nn.Module):
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
@@ -202,13 +254,21 @@ class DiffusionPriorTrainer(nn.Module):
def forward(
self,
x,
*args,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
total_loss += loss.item() * chunk_size_frac
self.scaler.scale(loss * chunk_size_frac).backward()
return total_loss
# decoder trainer
@@ -217,8 +277,9 @@ class DecoderTrainer(nn.Module):
self,
decoder,
use_ema = True,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = None,
amp = False,
**kwargs
@@ -243,13 +304,14 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
**kwargs
)
@@ -305,6 +367,11 @@ class DecoderTrainer(nn.Module):
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
def forward(
@@ -312,9 +379,17 @@ class DecoderTrainer(nn.Module):
x,
*,
unet_number,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
total_loss += loss.item() * chunk_size_frac
self.scale(loss * chunk_size_frac, unet_number = unet_number).backward()
return total_loss

View File

@@ -10,11 +10,12 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.6',
version = '0.2.27',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle2-pytorch',
keywords = [
'artificial intelligence',