Compare commits


7 Commits

4 changed files with 46 additions and 22 deletions

View File

@@ -1,7 +1,7 @@
 import math
 from tqdm import tqdm
 from inspect import isfunction
-from functools import partial
+from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def maybe(fn):
+    @wraps(fn)
+    def inner(x):
+        if not exists(x):
+            return x
+        return fn(x)
+    return inner
+
 def default(val, d):
     if exists(val):
         return val
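The new `maybe` helper wraps a function so that `None` passes straight through; it is used further down to normalize the optional `lowres_cond_img` only when one is supplied. A quick sketch of its behavior, assuming the `maybe` defined above is in scope:

double = maybe(lambda x: x * 2)

double(3)     # 6
double(None)  # None; the wrapped function is never called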
@@ -114,10 +122,10 @@ def resize_image_to(image, target_image_size):
 # ddpms expect images to be in the range of -1 to 1
 # but CLIP may otherwise
 
-def normalize_img(img):
+def normalize_neg_one_to_one(img):
     return img * 2 - 1
 
-def unnormalize_img(normed_img):
+def unnormalize_zero_to_one(normed_img):
     return (normed_img + 1) * 0.5
 
 # clip related adapters
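The renames make the direction of each mapping explicit: `normalize_neg_one_to_one` takes [0, 1] pixels into the [-1, 1] range DDPMs are trained in, and `unnormalize_zero_to_one` inverts it. A round-trip sketch, assuming both helpers are in scope:

import torch

x = torch.rand(8)                 # pixels in [0, 1]
y = normalize_neg_one_to_one(x)   # mapped to [-1, 1]
assert torch.allclose(unnormalize_zero_to_one(y), x, atol = 1e-6)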
@@ -606,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -616,7 +623,6 @@ class Attention(nn.Module):
         self.causal = causal
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
 
         self.dropout = nn.Dropout(dropout)
 
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -627,7 +633,7 @@ class Attention(nn.Module):
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )
 
     def forward(self, x, mask = None, attn_bias = None):
@@ -684,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)
 
 class CausalTransformer(nn.Module):
     def __init__(
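Taken together, the `Attention` hunks drop the configurable sandwich norm: the `post_norm` flag, the extra `self.post_norm` LayerNorm, and the second normalization in `forward` are all removed, and the output projection now always ends in a LayerNorm (the `CausalTransformer` call site below drops the keyword to match). A minimal sketch of the resulting output head, substituting torch's stock `nn.LayerNorm` for the repo's variant:

import torch
from torch import nn

inner_dim, dim = 512, 256
to_out = nn.Sequential(
    nn.Linear(inner_dim, dim, bias = False),
    nn.LayerNorm(dim)  # applied unconditionally now
)
out = to_out(torch.randn(2, 10, inner_dim))  # shape (2, 10, 256)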
@@ -711,7 +716,7 @@ class CausalTransformer(nn.Module):
 
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
@@ -1037,7 +1042,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
 
         if exists(image):
-            image_embed, _ = self.clip.embed_image(unnormalize_img(image))
+            image_embed, _ = self.clip.embed_image(image)
 
         # calculate text conditionings, based on what is passed in
@@ -1173,7 +1178,11 @@ class CrossAttention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim)
+        )
 
     def forward(self, x, context, mask = None):
         b, n, device = *x.shape[:2], x.device
@@ -1821,7 +1830,7 @@ class Decoder(BaseGaussianDiffusion):
         # eq 15 - https://arxiv.org/abs/2102.09672
         min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
         max_log = extract(torch.log(self.betas), t, x.shape)
-        var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
+        var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
         posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
         posterior_variance = posterior_log_variance.exp()
@@ -1844,6 +1853,8 @@ class Decoder(BaseGaussianDiffusion):
         b = shape[0]
         img = torch.randn(shape, device = device)
 
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
                 unet,
@@ -1859,11 +1870,19 @@ class Decoder(BaseGaussianDiffusion):
                 clip_denoised = clip_denoised
             )
 
-        return img
+        unnormalize_img = unnormalize_zero_to_one(img)
+        return unnormalize_img
 
     def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
+        # normalize to [-1, 1]
+
+        x_start = normalize_neg_one_to_one(x_start)
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
+        # get x_t
+
         x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
 
         model_output = unet(
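These hunks move the [0, 1] to [-1, 1] bookkeeping inside the `Decoder`: `p_losses` normalizes `x_start` (and, via `maybe`, the optional `lowres_cond_img`) before noising, and `p_sample_loop` unnormalizes its result, so callers deal only in [0, 1] images. That is also why the CLIP embedding call sites (in `DiffusionPrior` above and the `Decoder` hunk below) no longer wrap the image in `unnormalize_img`. A sketch of the boundary convention, assuming the two helpers are in scope:

import torch

img01 = torch.rand(1, 3, 64, 64)            # caller supplies [0, 1] pixels
x_start = normalize_neg_one_to_one(img01)   # p_losses: into [-1, 1] before q_sample
# ... forward and reverse diffusion run in [-1, 1] ...
img_out = unnormalize_zero_to_one(x_start)  # p_sample_loop: back to [0, 1] on return
assert torch.allclose(img01, img_out, atol = 1e-6)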
@@ -2011,7 +2030,7 @@ class Decoder(BaseGaussianDiffusion):
         if not exists(image_embed):
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
-            image_embed, _ = self.clip.embed_image(unnormalize_img(image))
+            image_embed, _ = self.clip.embed_image(image)
 
         text_encodings = text_mask = None
 
         if exists(text) and not exists(text_encodings):

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
 
 def get_optimizer(
     params,
-    lr = 3e-4,
+    lr = 2e-5,
     wd = 1e-2,
     betas = (0.9, 0.999),
+    eps = 1e-8,
     filter_by_requires_grad = False
 ):
     if filter_by_requires_grad:
         params = list(filter(lambda t: t.requires_grad, params))
 
     if wd == 0:
-        return Adam(params, lr = lr, betas = betas)
+        return Adam(params, lr = lr, betas = betas, eps = eps)
 
     params = set(params)
     wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
         {'params': list(no_wd_params), 'weight_decay': 0},
     ]
 
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
+    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
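With `eps` threaded through, both paths honor it: plain `Adam` when `wd == 0`, grouped `AdamW` otherwise (note the default `lr` also drops from 3e-4 to 2e-5). A usage sketch against the `get_optimizer` above:

from torch import nn

model = nn.Linear(10, 10)

# wd > 0: AdamW with decay / no-decay parameter groups
opt = get_optimizer(model.parameters(), lr = 2e-5, wd = 1e-2, eps = 1e-8)

# wd == 0: plain Adam, eps still forwarded
opt_plain = get_optimizer(model.parameters(), lr = 2e-5, wd = 0, eps = 1e-8)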

View File

@@ -90,7 +90,7 @@ class EMA(nn.Module):
     def __init__(
         self,
         model,
-        beta = 0.99,
+        beta = 0.9999,
         update_after_step = 1000,
         update_every = 10,
     ):
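Raising `beta` from 0.99 to 0.9999 makes the EMA far slower-moving: the effective averaging window grows from roughly 1/(1 - beta) = 100 updates to about 10,000, a value typical for diffusion model weight averaging. A toy sketch of the standard update this implies:

# per-parameter update each EMA step:
#   ema_param = beta * ema_param + (1 - beta) * online_param
beta, ema_w = 0.9999, 0.0
for _ in range(10_000):
    ema_w = beta * ema_w + (1 - beta) * 1.0
# after ~1/(1 - beta) steps, ema_w is only ~0.63 of the way to the target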
@@ -147,6 +147,7 @@ class DiffusionPriorTrainer(nn.Module):
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
+        eps = 1e-6,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -173,6 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
             diffusion_prior.parameters(),
             lr = lr,
             wd = wd,
+            eps = eps,
             **kwargs
         )
@@ -221,8 +223,9 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
-        lr = 3e-4,
+        lr = 2e-5,
         wd = 1e-2,
+        eps = 1e-8,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -247,13 +250,14 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
 
-        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
                 wd = unet_wd,
+                eps = unet_eps,
                 **kwargs
             )
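`cast_tuple` broadcasts each hyperparameter to one value per unet, so `lr`, `wd`, and now `eps` can each be passed as a single number or as a per-unet tuple. A sketch of the broadcasting, with a stand-in `cast_tuple` modeled on the repo's helper:

def cast_tuple(val, length = 1):
    # pass tuples through, broadcast scalars
    return val if isinstance(val, tuple) else (val,) * length

lr  = cast_tuple(2e-5, length = 3)       # (2e-05, 2e-05, 2e-05)
eps = cast_tuple((1e-8, 1e-8, 1e-7), 3)  # per-unet values pass through unchanged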

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.15',
+  version = '0.2.22',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',