Compare commits

..

12 Commits

Author SHA1 Message Date
Phil Wang
b0cd5f24b6 take care of gradient accumulation automatically for researchers, by passing in a max_batch_size on the decoder or diffusion prior trainer forward 2022-05-14 17:04:09 -07:00
Phil Wang
b494ed81d4 take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block) 2022-05-14 15:49:24 -07:00
Phil Wang
ff3474f05c normalize conditioning tokens outside of cross attention blocks 2022-05-14 14:23:52 -07:00
Phil Wang
d5293f19f1 lineup with paper 2022-05-14 13:57:00 -07:00
Phil Wang
e697183849 be able to customize adam eps 2022-05-14 13:55:04 -07:00
Phil Wang
591d37e266 lower default initial learning rate to what Jonathan Ho had in his original repo 2022-05-14 13:22:43 -07:00
Phil Wang
d1f02e8f49 always use sandwich norm for attention layer 2022-05-14 12:13:41 -07:00
Phil Wang
9faab59b23 use post-attn-branch layernorm in attempt to stabilize cross attention conditioning in decoder 2022-05-14 11:58:09 -07:00
Phil Wang
5d27029e98 make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm 2022-05-14 01:23:54 -07:00
Phil Wang
3115fa17b3 fix everything around normalizing images to -1 to 1 for ddpm training automatically 2022-05-14 01:17:11 -07:00
Phil Wang
124d8577c8 move the inverse normalization function called before image embeddings are derived from clip to within the diffusion prior and decoder classes 2022-05-14 00:37:52 -07:00
Phil Wang
2db0c9794c comments 2022-05-12 14:25:20 -07:00
5 changed files with 145 additions and 35 deletions

View File

@@ -732,8 +732,8 @@ clip = CLIP(
# mock data # mock data
text = torch.randint(0, 49408, (4, 256)).cuda() text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda() images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet) # decoder (with unet)
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
) )
for unet_number in (1, 2): for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward loss = decoder_trainer(
loss.backward() images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
) )
loss = diffusion_prior_trainer(text, images) loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop # after much of the above three lines in a loop
@@ -1017,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt - [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
## Citations ## Citations

View File

@@ -1,7 +1,7 @@
import math import math
from tqdm import tqdm from tqdm import tqdm
from inspect import isfunction from inspect import isfunction
from functools import partial from functools import partial, wraps
from contextlib import contextmanager from contextlib import contextmanager
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs): def identity(t, *args, **kwargs):
return t return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d): def default(val, d):
if exists(val): if exists(val):
return val return val
@@ -114,10 +122,10 @@ def resize_image_to(image, target_image_size):
# ddpms expect images to be in the range of -1 to 1 # ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise # but CLIP may otherwise
def normalize_img(img): def normalize_neg_one_to_one(img):
return img * 2 - 1 return img * 2 - 1
def unnormalize_img(normed_img): def unnormalize_zero_to_one(normed_img):
return (normed_img + 1) * 0.5 return (normed_img + 1) * 0.5
# clip related adapters # clip related adapters
@@ -278,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
def embed_image(self, image): def embed_image(self, image):
assert not self.cleared assert not self.cleared
image = resize_image_to(image, self.image_size) image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image)) image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image) image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None) return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -606,7 +614,6 @@ class Attention(nn.Module):
heads = 8, heads = 8,
dropout = 0., dropout = 0.,
causal = False, causal = False,
post_norm = False,
rotary_emb = None rotary_emb = None
): ):
super().__init__() super().__init__()
@@ -616,7 +623,6 @@ class Attention(nn.Module):
self.causal = causal self.causal = causal
self.norm = LayerNorm(dim) self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -627,7 +633,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False), nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity() LayerNorm(dim)
) )
def forward(self, x, mask = None, attn_bias = None): def forward(self, x, mask = None, attn_bias = None):
@@ -684,8 +690,7 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v) out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out) return self.to_out(out)
return self.post_norm(out)
class CausalTransformer(nn.Module): class CausalTransformer(nn.Module):
def __init__( def __init__(
@@ -711,7 +716,7 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
for _ in range(depth): for _ in range(depth):
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb), Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer) FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
])) ]))
@@ -1158,6 +1163,7 @@ class CrossAttention(nn.Module):
dim_head = 64, dim_head = 64,
heads = 8, heads = 8,
dropout = 0., dropout = 0.,
norm_context = False
): ):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
@@ -1167,13 +1173,17 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim) context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim) self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim) self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None): def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device b, n, device = *x.shape[:2], x.device
@@ -1369,6 +1379,9 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens) Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity() ) if image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional) # text encoding conditioning (optional)
self.text_to_cond = None self.text_to_cond = None
@@ -1584,6 +1597,11 @@ class Unet(nn.Module):
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2) mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# go through the layers of the unet, down and up # go through the layers of the unet, down and up
hiddens = [] hiddens = []
@@ -1821,7 +1839,7 @@ class Decoder(BaseGaussianDiffusion):
# eq 15 - https://arxiv.org/abs/2102.09672 # eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape) min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape) max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_img(var_interp_frac_unnormalized) var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp() posterior_variance = posterior_log_variance.exp()
@@ -1844,6 +1862,8 @@ class Decoder(BaseGaussianDiffusion):
b = shape[0] b = shape[0]
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample( img = self.p_sample(
unet, unet,
@@ -1859,11 +1879,19 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = clip_denoised clip_denoised = clip_denoised
) )
return img unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet( model_output = unet(
@@ -1890,6 +1918,11 @@ class Decoder(BaseGaussianDiffusion):
# return simple loss if not using learned variance # return simple loss if not using learned variance
return loss return loss
# most of the code below is transcribed from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
# if learning the variance, also include the extra weight kl loss # if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times) true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
def get_optimizer( def get_optimizer(
params, params,
lr = 3e-4, lr = 2e-5,
wd = 1e-2, wd = 1e-2,
betas = (0.9, 0.999), betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False filter_by_requires_grad = False
): ):
if filter_by_requires_grad: if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params)) params = list(filter(lambda t: t.requires_grad, params))
if wd == 0: if wd == 0:
return Adam(params, lr = lr, betas = betas) return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params) params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params) wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
{'params': list(no_wd_params), 'weight_decay': 0}, {'params': list(no_wd_params), 'weight_decay': 0},
] ]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas) return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -1,6 +1,8 @@
import time import time
import copy import copy
from math import ceil
from functools import partial from functools import partial
from collections.abc import Iterable
import torch import torch
from torch import nn from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
def exists(val): def exists(val):
return val is not None return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1): def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length) return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,46 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs return kwargs_without_prefix, kwargs
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
batch_size = len(x)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
all_args = (x, *args, *kwargs.values())
len_all_args = len(all_args)
split_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_index], chunked_all_args[split_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
yield chunk_size, (chunked_args, chunked_kwargs)
# print helpers # print helpers
def print_ribbon(s, symbol = '=', repeat = 40): def print_ribbon(s, symbol = '=', repeat = 40):
@@ -90,7 +135,7 @@ class EMA(nn.Module):
def __init__( def __init__(
self, self,
model, model,
beta = 0.99, beta = 0.9999,
update_after_step = 1000, update_after_step = 1000,
update_every = 10, update_every = 10,
): ):
@@ -147,6 +192,7 @@ class DiffusionPriorTrainer(nn.Module):
use_ema = True, use_ema = True,
lr = 3e-4, lr = 3e-4,
wd = 1e-2, wd = 1e-2,
eps = 1e-6,
max_grad_norm = None, max_grad_norm = None,
amp = False, amp = False,
**kwargs **kwargs
@@ -173,6 +219,7 @@ class DiffusionPriorTrainer(nn.Module):
diffusion_prior.parameters(), diffusion_prior.parameters(),
lr = lr, lr = lr,
wd = wd, wd = wd,
eps = eps,
**kwargs **kwargs
) )
@@ -206,13 +253,25 @@ class DiffusionPriorTrainer(nn.Module):
def forward( def forward(
self, self,
x,
*args, *args,
divisor = 1, max_batch_size = None,
**kwargs **kwargs
): ):
with autocast(enabled = self.amp): batch_size = x.shape[0]
loss = self.diffusion_prior(*args, **kwargs) total_samples = 0
return self.scaler.scale(loss / divisor) total_loss = 0.
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
self.scaler.scale(loss * (chunk_size / batch_size)).backward()
return total_loss / total_samples
# decoder trainer # decoder trainer
@@ -221,8 +280,9 @@ class DecoderTrainer(nn.Module):
self, self,
decoder, decoder,
use_ema = True, use_ema = True,
lr = 3e-4, lr = 2e-5,
wd = 1e-2, wd = 1e-2,
eps = 1e-8,
max_grad_norm = None, max_grad_norm = None,
amp = False, amp = False,
**kwargs **kwargs
@@ -247,13 +307,14 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay # be able to finely customize learning rate, weight decay
# per unet # per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd)) lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)): for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer( optimizer = get_optimizer(
unet.parameters(), unet.parameters(),
lr = unet_lr, lr = unet_lr,
wd = unet_wd, wd = unet_wd,
eps = unet_eps,
**kwargs **kwargs
) )
@@ -321,9 +382,20 @@ class DecoderTrainer(nn.Module):
x, x,
*, *,
unet_number, unet_number,
divisor = 1, max_batch_size = None,
**kwargs **kwargs
): ):
with autocast(enabled = self.amp): batch_size = x.shape[0]
loss = self.decoder(x, unet_number = unet_number, **kwargs) total_samples = 0
return self.scale(loss / divisor, unet_number = unet_number) total_loss = 0.
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
self.scale(loss * (chunk_size / batch_size), unet_number = unet_number).backward()
return total_loss / total_samples

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.14', version = '0.2.26',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',