Compare commits

..

10 Commits

5 changed files with 61 additions and 22 deletions

View File

@@ -1017,6 +1017,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
## Citations

View File

@@ -1,7 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
@@ -114,10 +122,10 @@ def resize_image_to(image, target_image_size):
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_img(img):
def normalize_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_img(normed_img):
def unnormalize_zero_to_one(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters
@@ -278,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -606,7 +614,6 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False,
rotary_emb = None
):
super().__init__()
@@ -616,7 +623,6 @@ class Attention(nn.Module):
self.causal = causal
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -627,7 +633,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
LayerNorm(dim)
)
def forward(self, x, mask = None, attn_bias = None):
@@ -684,8 +690,7 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.post_norm(out)
return self.to_out(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -711,7 +716,7 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -1158,6 +1163,7 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -1167,13 +1173,17 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
@@ -1369,6 +1379,9 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
@@ -1584,6 +1597,11 @@ class Unet(nn.Module):
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# go through the layers of the unet, down and up
hiddens = []
@@ -1821,7 +1839,7 @@ class Decoder(BaseGaussianDiffusion):
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
@@ -1844,6 +1862,8 @@ class Decoder(BaseGaussianDiffusion):
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
@@ -1859,11 +1879,19 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = clip_denoised
)
return img
unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet(
@@ -1890,6 +1918,11 @@ class Decoder(BaseGaussianDiffusion):
# return simple loss if not using learned variance
return loss
# most of the code below is transcribed from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -90,7 +90,7 @@ class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
@@ -147,6 +147,7 @@ class DiffusionPriorTrainer(nn.Module):
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
**kwargs
@@ -173,6 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
@@ -221,8 +223,9 @@ class DecoderTrainer(nn.Module):
self,
decoder,
use_ema = True,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = None,
amp = False,
**kwargs
@@ -247,13 +250,14 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
**kwargs
)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.14',
version = '0.2.23',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',