Compare commits

...

7 Commits

4 changed files with 40 additions and 16 deletions

View File

@@ -1,7 +1,7 @@
 import math
 from tqdm import tqdm
 from inspect import isfunction
-from functools import partial
+from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def maybe(fn):
+    @wraps(fn)
+    def inner(x):
+        if not exists(x):
+            return x
+        return fn(x)
+    return inner
+
 def default(val, d):
     if exists(val):
         return val
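Review note: the new `maybe` helper wraps a function so that `None` passes through untouched; the decoder uses it further down to normalize `lowres_cond_img` only when it is actually supplied. A minimal sketch of its behavior (reusing the `exists` helper already defined in this file):

    from functools import wraps

    def exists(val):
        return val is not None

    def maybe(fn):
        @wraps(fn)
        def inner(x):
            if not exists(x):
                return x          # None passes straight through
            return fn(x)
        return inner

    double = maybe(lambda t: t * 2)
    assert double(3) == 6
    assert double(None) is None   # no-op when the optional value is absent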
@@ -114,10 +122,10 @@ def resize_image_to(image, target_image_size):
 # ddpms expect images to be in the range of -1 to 1
 # but CLIP may otherwise
 
-def normalize_img(img):
+def normalize_neg_one_to_one(img):
     return img * 2 - 1
 
-def unnormalize_img(normed_img):
+def unnormalize_zero_to_one(normed_img):
     return (normed_img + 1) * 0.5
 
 # clip related adapters
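Review note: the renames make the direction of each mapping explicit, since DDPMs operate on images in [-1, 1] while raw image tensors (and CLIP preprocessing) live in [0, 1]. A quick round-trip check of the two helpers as defined above:

    import torch

    def normalize_neg_one_to_one(img):
        return img * 2 - 1              # [0, 1] -> [-1, 1]

    def unnormalize_zero_to_one(normed_img):
        return (normed_img + 1) * 0.5   # [-1, 1] -> [0, 1]

    img = torch.rand(1, 3, 8, 8)        # values in [0, 1]
    assert torch.allclose(unnormalize_zero_to_one(normalize_neg_one_to_one(img)), img)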
@@ -278,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
     def embed_image(self, image):
         assert not self.cleared
         image = resize_image_to(image, self.image_size)
-        image = self.clip_normalize(unnormalize_img(image))
+        image = self.clip_normalize(image)
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
 
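Review note: the adapter no longer undoes a [-1, 1] encoding before CLIP's own normalization; `embed_image` now assumes the incoming image is already in [0, 1]. A hedged sketch of the resulting preprocessing (the `clip_normalize` below is only a stand-in, built from CLIP's published mean/std with torchvision):

    import torch
    from torchvision import transforms as T

    # stand-in for the adapter's clip_normalize
    clip_normalize = T.Normalize(
        mean = (0.48145466, 0.4578275, 0.40821073),
        std  = (0.26862954, 0.26130258, 0.27577711)
    )

    image = torch.rand(3, 224, 224)   # already in [0, 1] after this change
    image = clip_normalize(image)     # previously: clip_normalize(unnormalize_img(image))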
@@ -606,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -616,7 +623,6 @@ class Attention(nn.Module):
         self.causal = causal
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)
 
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -627,7 +633,7 @@ class Attention(nn.Module):
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )
 
     def forward(self, x, mask = None, attn_bias = None):
@@ -684,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)
 
 class CausalTransformer(nn.Module):
     def __init__(
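Review note: the Attention changes above drop the optional sandwich/post-norm path, so the output LayerNorm is always part of `to_out` and `forward` simply returns `self.to_out(out)`. A minimal sketch of the resulting output projection (sizes are illustrative, and `nn.LayerNorm` stands in for the repo's custom LayerNorm):

    import torch
    from torch import nn

    dim, inner_dim = 512, 512   # illustrative sizes, not the actual config

    # output projection after the change: always linear followed by layernorm
    to_out = nn.Sequential(
        nn.Linear(inner_dim, dim, bias = False),
        nn.LayerNorm(dim)       # stand-in for the repo's LayerNorm
    )

    out = torch.randn(2, 16, inner_dim)        # (batch, sequence, inner_dim)
    assert to_out(out).shape == (2, 16, dim)

The same projection shape is adopted for CrossAttention further down.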
@@ -711,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
@@ -1173,7 +1178,11 @@ class CrossAttention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim)
+        )
 
     def forward(self, x, context, mask = None):
         b, n, device = *x.shape[:2], x.device
@@ -1821,7 +1830,7 @@ class Decoder(BaseGaussianDiffusion):
             # eq 15 - https://arxiv.org/abs/2102.09672
             min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
             max_log = extract(torch.log(self.betas), t, x.shape)
-            var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
+            var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
             posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
             posterior_variance = posterior_log_variance.exp()
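Review note: only the helper name changes here; the network's variance-interpolation output is still mapped from [-1, 1] to a fraction in [0, 1] and plugged into eq. 15 of Improved DDPM (https://arxiv.org/abs/2102.09672). A small numeric sketch of that interpolation with made-up schedule values:

    import torch

    betas = torch.tensor([1e-4, 2e-4, 4e-4])                   # made-up beta schedule
    posterior_log_variance_clipped = torch.log(betas * 0.5)    # stand-in for the clipped posterior variance

    v_raw = torch.tensor([0.2, -0.5, 0.9])                     # network output in [-1, 1]
    v = (v_raw + 1) * 0.5                                      # unnormalize_zero_to_one: fraction in [0, 1]

    # eq 15: log variance interpolates between log beta_t and the clipped posterior log variance
    posterior_log_variance = v * torch.log(betas) + (1 - v) * posterior_log_variance_clipped
    posterior_variance = posterior_log_variance.exp()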
@@ -1844,6 +1853,8 @@ class Decoder(BaseGaussianDiffusion):
         b = shape[0]
         img = torch.randn(shape, device = device)
 
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
                 unet,
@@ -1859,11 +1870,19 @@ class Decoder(BaseGaussianDiffusion):
                 clip_denoised = clip_denoised
             )
 
-        return img
+        unnormalize_img = unnormalize_zero_to_one(img)
+        return unnormalize_img
 
     def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
+        # normalize to [-1, 1]
+
+        x_start = normalize_neg_one_to_one(x_start)
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
+        # get x_t
+
         x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
 
         model_output = unet(
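Review note: taken together, these changes move the [-1, 1] handling inside the diffusion code. Callers now hand `p_losses` (and the low-res conditioning path) images in [0, 1], normalization happens internally, and `p_sample_loop` maps its result back to [0, 1] before returning. A minimal sketch of that contract using the helpers from this diff (tensors are placeholders):

    import torch

    def normalize_neg_one_to_one(img):
        return img * 2 - 1

    def unnormalize_zero_to_one(normed_img):
        return (normed_img + 1) * 0.5

    # training side (p_losses): caller supplies images in [0, 1]
    x_start = torch.rand(4, 3, 64, 64)
    x_start = normalize_neg_one_to_one(x_start)      # -> [-1, 1] before q_sample

    # sampling side (p_sample_loop): denoising runs in [-1, 1] ...
    img = torch.randn(4, 3, 64, 64).clamp(-1., 1.)
    # ... and the caller receives images back in [0, 1]
    img = unnormalize_zero_to_one(img)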
@@ -1890,6 +1909,11 @@ class Decoder(BaseGaussianDiffusion):
             # return simple loss if not using learned variance
             return loss
 
+        # most of the code below is transcribed from
+        # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
+        # the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
+        # it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
+
         # if learning the variance, also include the extra weight kl loss
 
         true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
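Review note: the transcribed section these comments refer to computes a KL term between the true posterior and a model posterior whose predicted mean is detached, so the extra term only trains the learned variance while the l1/l2 "simple" loss remains responsible for the mean. A rough sketch of that idea (tensors are made-up placeholders; `normal_kl` follows Ho et al.'s diffusion_utils):

    import torch

    def normal_kl(mean1, logvar1, mean2, logvar2):
        # KL divergence between two diagonal Gaussians
        return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                      + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))

    # placeholders for the true and predicted posterior statistics
    true_mean, true_log_var = torch.zeros(4), torch.zeros(4)
    model_mean, model_log_var = torch.randn(4), torch.randn(4)

    # detaching the mean means this term only carries gradients into the variance head
    vb = normal_kl(true_mean, true_log_var, model_mean.detach(), model_log_var).mean()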

View File

@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
 
 def get_optimizer(
     params,
-    lr = 3e-4,
+    lr = 2e-5,
     wd = 1e-2,
     betas = (0.9, 0.999),
     filter_by_requires_grad = False
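Review note: the default learning rate drops from 3e-4 to 2e-5 here (mirrored in DecoderTrainer below). Call sites are unchanged; anyone relying on the old default should now pass lr explicitly. A hypothetical usage sketch (the import path and `model` are assumptions for illustration):

    from torch import nn
    from dalle2_pytorch.optimizer import get_optimizer   # assumed module path for this file

    model = nn.Linear(8, 8)                               # placeholder module
    opt = get_optimizer(model.parameters())               # now defaults to lr = 2e-5, wd = 1e-2
    opt = get_optimizer(model.parameters(), lr = 3e-4)    # previous default, passed explicitly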

View File

@@ -221,7 +221,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
-        lr = 3e-4,
+        lr = 2e-5,
         wd = 1e-2,
         max_grad_norm = None,
         amp = False,

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.14',
+  version = '0.2.20',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',