Compare commits

...

12 Commits

4 changed files with 187 additions and 44 deletions

View File

@@ -1,7 +1,7 @@
 import math
 from tqdm import tqdm
 from inspect import isfunction
-from functools import partial
+from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
 from x_clip import CLIP
 from coca_pytorch import CoCa

+# constants
+
+NAT = 1. / math.log(2.)
+
 # helper functions

 def exists(val):
@@ -41,6 +45,14 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t

+def maybe(fn):
+    @wraps(fn)
+    def inner(x):
+        if not exists(x):
+            return x
+        return fn(x)
+    return inner
+
 def default(val, d):
     if exists(val):
         return val
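
The new maybe helper simply lets None pass through a function untouched; later hunks use it to optionally normalize a low-resolution conditioning image that may be absent. A minimal standalone illustration (simplified, without @wraps; not part of the diff):

import torch

def exists(val):
    return val is not None

def maybe(fn):
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

normalize_neg_one_to_one = lambda img: img * 2 - 1

maybe(normalize_neg_one_to_one)(None)              # -> None, no error on a missing conditioning image
maybe(normalize_neg_one_to_one)(torch.rand(1, 3))  # -> tensor rescaled from [0, 1] to [-1, 1]
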
@@ -91,6 +103,9 @@ def freeze_model_and_make_eval_(model):
 # tensor helpers

+def log(t, eps = 1e-12):
+    return torch.log(t.clamp(min = eps))
+
 def l2norm(t):
     return F.normalize(t, dim = -1)
@@ -107,10 +122,10 @@ def resize_image_to(image, target_image_size):
 # ddpms expect images to be in the range of -1 to 1
 # but CLIP may otherwise

-def normalize_img(img):
+def normalize_neg_one_to_one(img):
     return img * 2 - 1

-def unnormalize_img(normed_img):
+def unnormalize_zero_to_one(normed_img):
     return (normed_img + 1) * 0.5

 # clip related adapters
@@ -271,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
     def embed_image(self, image):
         assert not self.cleared
         image = resize_image_to(image, self.image_size)
-        image = self.clip_normalize(unnormalize_img(image))
+        image = self.clip_normalize(image)
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -297,6 +312,36 @@ def noise_like(shape, device, repeat=False):
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()

+def meanflat(x):
+    return x.mean(dim = tuple(range(1, len(x.shape))))
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
+
+def approx_standard_normal_cdf(x):
+    return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
+    assert x.shape == means.shape == log_scales.shape
+
+    centered_x = x - means
+    inv_stdv = torch.exp(-log_scales)
+    plus_in = inv_stdv * (centered_x + 1. / 255.)
+    cdf_plus = approx_standard_normal_cdf(plus_in)
+    min_in = inv_stdv * (centered_x - 1. / 255.)
+    cdf_min = approx_standard_normal_cdf(min_in)
+    log_cdf_plus = log(cdf_plus)
+    log_one_minus_cdf_min = log(1. - cdf_min)
+    cdf_delta = cdf_plus - cdf_min
+
+    log_probs = torch.where(x < -thres,
+        log_cdf_plus,
+        torch.where(x > thres,
+            log_one_minus_cdf_min,
+            log(cdf_delta)))
+
+    return log_probs
+
 def cosine_beta_schedule(timesteps, s = 0.008):
     """
     cosine schedule
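
Multiplying by the new NAT constant (1 / ln 2) converts these quantities from nats to bits. The normal_kl expression is the closed-form KL divergence between two diagonal Gaussians parameterized by mean and log-variance; a small sanity-check sketch against torch.distributions (all names below are local to the sketch, not part of the diff):

import torch
from torch.distributions import Normal, kl_divergence

def normal_kl(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))

mean1, logvar1, mean2, logvar2 = (torch.randn(8) for _ in range(4))

# reference KL from torch.distributions, with scale = exp(0.5 * logvar)
reference = kl_divergence(Normal(mean1, (0.5 * logvar1).exp()), Normal(mean2, (0.5 * logvar2).exp()))
assert torch.allclose(normal_kl(mean1, logvar1, mean2, logvar2), reference, atol = 1e-5)
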
@@ -398,12 +443,6 @@ class BaseGaussianDiffusion(nn.Module):
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

-    def q_mean_variance(self, x_start, t):
-        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
-        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
-        return mean, variance, log_variance
-
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -575,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -585,7 +623,6 @@ class Attention(nn.Module):
         self.causal = causal

         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)

         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -596,7 +633,7 @@ class Attention(nn.Module):
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )

     def forward(self, x, mask = None, attn_bias = None):
@@ -653,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)

 class CausalTransformer(nn.Module):
     def __init__(
@@ -680,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
@@ -1142,7 +1178,11 @@ class CrossAttention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim)
+        )

     def forward(self, x, context, mask = None):
         b, n, device = *x.shape[:2], x.device
@@ -1272,6 +1312,7 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        channels_out = None,
         attn_dim_head = 32,
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1302,6 +1343,7 @@ class Unet(nn.Module):
         # determine dimensions

         self.channels = channels
+        self.channels_out = default(channels_out, channels)

         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
         init_dim = default(init_dim, dim // 3 * 2)
@@ -1407,11 +1449,9 @@ class Unet(nn.Module):
                 Upsample(dim_in)
             ]))

-        out_dim = default(out_dim, channels)
-
         self.final_conv = nn.Sequential(
             ResnetBlock(dim, dim, groups = resnet_groups[0]),
-            nn.Conv2d(dim, out_dim, 1)
+            nn.Conv2d(dim, self.channels_out, 1)
         )

         # if the current settings for the unet are not correct
@@ -1421,13 +1461,25 @@ class Unet(nn.Module):
         *,
         lowres_cond,
         channels,
+        channels_out,
         cond_on_image_embeds,
         cond_on_text_encodings
     ):
-        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
+        if lowres_cond == self.lowres_cond and \
+            channels == self.channels and \
+            cond_on_image_embeds == self.cond_on_image_embeds and \
+            cond_on_text_encodings == self.cond_on_text_encodings and \
+            channels_out == self.channels_out:
             return self

-        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
+        updated_kwargs = dict(
+            lowres_cond = lowres_cond,
+            channels = channels,
+            channels_out = channels_out,
+            cond_on_image_embeds = cond_on_image_embeds,
+            cond_on_text_encodings = cond_on_text_encodings
+        )
+
         return self.__class__(**{**self._locals, **updated_kwargs})

     def forward_with_cond_scale(
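
The extra channels_out parameter exists so a unet can be re-instantiated with a doubled output head when the variance is learned. A hypothetical standalone sketch; the Unet constructor arguments are illustrative, only the cast_model_parameters keywords come from this diff:

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# re-instantiate with twice the output channels, so the unet can predict a variance
# interpolation fraction alongside its noise / x_start prediction
unet = unet.cast_model_parameters(
    lowres_cond = False,
    channels = 3,
    channels_out = 6,
    cond_on_image_embeds = True,
    cond_on_text_encodings = False
)
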
@@ -1627,6 +1679,8 @@ class Decoder(BaseGaussianDiffusion):
         clip_denoised = True,
         clip_x_start = True,
         clip_adapter_overrides = dict(),
+        learned_variance = True,
+        vb_loss_weight = 0.001,
         unconditional = False
     ):
         super().__init__(
@@ -1665,10 +1719,18 @@ class Decoder(BaseGaussianDiffusion):
         unets = cast_tuple(unet)
         vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))

+        # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
+
+        learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
+        self.learned_variance = learned_variance
+        self.vb_loss_weight = vb_loss_weight
+
+        # construct unets and vaes
+
         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])

-        for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
+        for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
             assert isinstance(one_unet, Unet)
             assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -1676,12 +1738,14 @@ class Decoder(BaseGaussianDiffusion):
             latent_dim = one_vae.encoded_dim if exists(one_vae) else None

             unet_channels = default(latent_dim, self.channels)
+            unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)

             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
                 cond_on_image_embeds = is_first and not unconditional,
                 cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
-                channels = unet_channels
+                channels = unet_channels,
+                channels_out = unet_channels_out
             )

             self.unets.append(one_unet)
@@ -1744,8 +1808,11 @@ class Decoder(BaseGaussianDiffusion):
         yield
         unet.cpu()

-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
-        pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
+
+        if learned_variance:
+            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)

         if predict_x_start:
             x_recon = pred
@@ -1756,24 +1823,38 @@ class Decoder(BaseGaussianDiffusion):
             x_recon.clamp_(-1., 1.)

         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+
+        if learned_variance:
+            # if learned variance, posterior variance and posterior log variance are predicted by the network
+            # by an interpolation of the max and min log beta values
+            # eq 15 - https://arxiv.org/abs/2102.09672
+            min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
+            max_log = extract(torch.log(self.betas), t, x.shape)
+            var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
+
+            posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
+            posterior_variance = posterior_log_variance.exp()
+
         return model_mean, posterior_variance, posterior_log_variance

     @torch.inference_mode()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.inference_mode()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device = device)

+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
                 unet,
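
For reference, the learned-variance branch above implements equation 15 of Improved DDPM (https://arxiv.org/abs/2102.09672): the network does not predict the variance directly but an interpolation fraction v, with its raw output rescaled from [-1, 1] to [0, 1] by unnormalize_zero_to_one:

    \Sigma_\theta(x_t, t) = \exp\big( v \log \beta_t + (1 - v) \log \tilde{\beta}_t \big)

where \beta_t is the forward-process variance (max_log) and \tilde{\beta}_t the clipped posterior variance (min_log).
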
@@ -1785,17 +1866,26 @@ class Decoder(BaseGaussianDiffusion):
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
                 predict_x_start = predict_x_start,
+                learned_variance = learned_variance,
                 clip_denoised = clip_denoised
             )

-        return img
+        unnormalize_img = unnormalize_zero_to_one(img)
+        return unnormalize_img

-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))

+        # normalize to [-1, 1]
+
+        x_start = normalize_neg_one_to_one(x_start)
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
+        # get x_t
+
         x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)

-        pred = unet(
+        model_output = unet(
             x_noisy,
             times,
             image_embed = image_embed,
@@ -1806,10 +1896,48 @@ class Decoder(BaseGaussianDiffusion):
             text_cond_drop_prob = self.text_cond_drop_prob,
         )

+        if learned_variance:
+            pred, _ = model_output.chunk(2, dim = 1)
+        else:
+            pred = model_output
+
         target = noise if not predict_x_start else x_start

         loss = self.loss_fn(pred, target)
-        return loss
+
+        if not learned_variance:
+            # return simple loss if not using learned variance
+            return loss
+
+        # most of the code below is transcribed from
+        # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
+        # the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
+        # it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation

+        # if learning the variance, also include the extra weighted kl loss
+
+        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+
+        # kl loss with detached model predicted mean, for stability reasons as in paper
+
+        detached_model_mean = model_mean.detach()
+
+        kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
+        kl = meanflat(kl) * NAT
+
+        decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
+        decoder_nll = meanflat(decoder_nll) * NAT
+
+        # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+
+        vb_losses = torch.where(times == 0, decoder_nll, kl)
+
+        # weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
+
+        vb_loss = vb_losses.mean() * self.vb_loss_weight
+
+        return loss + vb_loss

     @torch.inference_mode()
     @eval_decorator
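
Taken together, the updated p_losses is the hybrid objective of the Improved DDPM paper: the usual simple reconstruction loss plus a down-weighted variational-bound term computed with the predicted mean detached, so the extra term only trains the variance head:

    L = L_{\text{simple}} + \lambda \, L_{\text{vlb}}, \qquad \lambda = \text{vb\_loss\_weight} = 0.001

with L_vlb being the discretized decoder NLL at t = 0 and KL(q(x_{t-1}|x_t, x_0) || p_\theta(x_{t-1}|x_t)) otherwise, both converted from nats to bits via the NAT constant.
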
@@ -1836,7 +1964,7 @@ class Decoder(BaseGaussianDiffusion):
         img = None
         is_cuda = next(self.parameters()).is_cuda

-        for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):

             context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
@@ -1862,6 +1990,7 @@ class Decoder(BaseGaussianDiffusion):
                     text_mask = text_mask,
                     cond_scale = cond_scale,
                     predict_x_start = predict_x_start,
+                    learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img
                 )
@@ -1891,6 +2020,7 @@ class Decoder(BaseGaussianDiffusion):
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
+        learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device

         check_shape(image, 'b c h w', c = self.channels)
@@ -1928,7 +2058,7 @@ class Decoder(BaseGaussianDiffusion):
         if exists(lowres_cond_img):
             lowres_cond_img = vae.encode(lowres_cond_img)

-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)

 # main class
@@ -1978,4 +2108,3 @@ class DALLE2(nn.Module):
             return images[0]

         return images

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
 def get_optimizer(
     params,
-    lr = 3e-4,
+    lr = 2e-5,
     wd = 1e-2,
     betas = (0.9, 0.999),
+    eps = 1e-8,
     filter_by_requires_grad = False
 ):
     if filter_by_requires_grad:
         params = list(filter(lambda t: t.requires_grad, params))

     if wd == 0:
-        return Adam(params, lr = lr, betas = betas)
+        return Adam(params, lr = lr, betas = betas, eps = eps)

     params = set(params)
     wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
         {'params': list(no_wd_params), 'weight_decay': 0},
     ]

-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
+    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
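
A minimal usage sketch of the updated signature; `model` is any nn.Module, get_optimizer is the function shown above, and the values simply echo the new defaults:

import torch.nn as nn

model = nn.Linear(512, 512)

opt = get_optimizer(
    model.parameters(),
    lr = 2e-5,
    wd = 1e-2,
    betas = (0.9, 0.999),
    eps = 1e-8
)

# with wd == 0 the plain Adam branch is taken instead of the grouped AdamW
opt_no_decay = get_optimizer(model.parameters(), lr = 2e-5, wd = 0, eps = 1e-8)
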

View File

@@ -90,7 +90,7 @@ class EMA(nn.Module):
     def __init__(
         self,
         model,
-        beta = 0.99,
+        beta = 0.9999,
         update_after_step = 1000,
         update_every = 10,
     ):
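
For context on the new default: an exponential moving average with decay beta has an effective averaging window of roughly 1 / (1 - beta) updates, so raising beta from 0.99 to 0.9999 stretches that window from about 100 to about 10,000 EMA updates.
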
@@ -105,6 +105,10 @@ class EMA(nn.Module):
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0.]))

+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
     def update(self):
         self.step += 1
@@ -143,6 +147,7 @@ class DiffusionPriorTrainer(nn.Module):
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
+        eps = 1e-6,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -169,6 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
             diffusion_prior.parameters(),
             lr = lr,
             wd = wd,
+            eps = eps,
             **kwargs
         )
@@ -217,8 +223,9 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
-        lr = 3e-4,
+        lr = 2e-5,
         wd = 1e-2,
+        eps = 1e-8,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -243,13 +250,14 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))

-        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
                 wd = unet_wd,
+                eps = unet_eps,
                 **kwargs
             )
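
Because lr, wd and eps are each broadcast with cast_tuple to one value per unet, they may also be passed as tuples to tune each unet separately. A hedged sketch, assuming `decoder` is an already-constructed two-unet Decoder:

trainer = DecoderTrainer(
    decoder,
    lr = (2e-5, 1e-5),   # one learning rate per unet
    wd = 1e-2,           # scalars are broadcast to every unet
    eps = 1e-8
)
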
@@ -305,6 +313,11 @@ class DecoderTrainer(nn.Module):
         if self.use_ema:
             self.decoder.unets = trainable_unets # restore original training unets

+            # cast the ema_model unets back to original device
+
+            for ema in self.ema_unets:
+                ema.restore_ema_model_device()
+
         return output

     def forward(

View File

@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.2.11',
+    version = '0.2.22',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',