mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9549bd43b7 | ||
|
|
aee92dba4a | ||
|
|
b0cd5f24b6 | ||
|
|
b494ed81d4 | ||
|
|
ff3474f05c | ||
|
|
d5293f19f1 | ||
|
|
e697183849 | ||
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 | ||
|
|
5d27029e98 | ||
|
|
3115fa17b3 | ||
|
|
124d8577c8 | ||
|
|
2db0c9794c | ||
|
|
2277b47ffd | ||
|
|
28b58e568c | ||
|
|
924455d97d | ||
|
|
6021945fc8 | ||
|
|
6f76652d11 |
16
README.md
16
README.md
@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||
|
||||
# openai pretrained clip - defaults to ViT/B-32
|
||||
# openai pretrained clip - defaults to ViT-B/32
|
||||
|
||||
clip = OpenAIClipAdapter()
|
||||
|
||||
@@ -732,8 +732,8 @@ clip = CLIP(
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||
images = torch.randn(32, 3, 256, 256).cuda()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
|
||||
)
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||
loss.backward()
|
||||
loss = decoder_trainer(
|
||||
images,
|
||||
text = text,
|
||||
unet_number = unet_number, # which unet to train on
|
||||
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||
)
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images)
|
||||
loss.backward()
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
@@ -1017,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# constants
|
||||
|
||||
NAT = 1. / math.log(2.)
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -41,6 +45,14 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def maybe(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
if not exists(x):
|
||||
return x
|
||||
return fn(x)
|
||||
return inner
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -91,6 +103,9 @@ def freeze_model_and_make_eval_(model):
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def log(t, eps = 1e-12):
|
||||
return torch.log(t.clamp(min = eps))
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
@@ -107,10 +122,10 @@ def resize_image_to(image, target_image_size):
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
# but CLIP may otherwise
|
||||
|
||||
def normalize_img(img):
|
||||
def normalize_neg_one_to_one(img):
|
||||
return img * 2 - 1
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
def unnormalize_zero_to_one(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
|
||||
# clip related adapters
|
||||
@@ -271,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
@@ -297,6 +312,36 @@ def noise_like(shape, device, repeat=False):
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
def meanflat(x):
|
||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
|
||||
|
||||
def approx_standard_normal_cdf(x):
|
||||
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = log(cdf_plus)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min)
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(x < -thres,
|
||||
log_cdf_plus,
|
||||
torch.where(x > thres,
|
||||
log_one_minus_cdf_min,
|
||||
log(cdf_delta)))
|
||||
|
||||
return log_probs
|
||||
|
||||
def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
@@ -398,12 +443,6 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
@@ -575,7 +614,6 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
post_norm = False,
|
||||
rotary_emb = None
|
||||
):
|
||||
super().__init__()
|
||||
@@ -585,7 +623,6 @@ class Attention(nn.Module):
|
||||
|
||||
self.causal = causal
|
||||
self.norm = LayerNorm(dim)
|
||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -596,7 +633,7 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
LayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
@@ -653,8 +690,7 @@ class Attention(nn.Module):
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
return self.post_norm(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
@@ -680,7 +716,7 @@ class CausalTransformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
@@ -831,7 +867,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.,
|
||||
loss_type = "l1",
|
||||
loss_type = "l2",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
@@ -1127,6 +1163,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -1136,13 +1173,17 @@ class CrossAttention(nn.Module):
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, x, context, mask = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
@@ -1272,6 +1313,7 @@ class Unet(nn.Module):
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
channels_out = None,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
@@ -1302,6 +1344,7 @@ class Unet(nn.Module):
|
||||
# determine dimensions
|
||||
|
||||
self.channels = channels
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
@@ -1336,6 +1379,9 @@ class Unet(nn.Module):
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
# text encoding conditioning (optional)
|
||||
|
||||
self.text_to_cond = None
|
||||
@@ -1407,11 +1453,9 @@ class Unet(nn.Module):
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
@@ -1421,13 +1465,25 @@ class Unet(nn.Module):
|
||||
*,
|
||||
lowres_cond,
|
||||
channels,
|
||||
channels_out,
|
||||
cond_on_image_embeds,
|
||||
cond_on_text_encodings
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
|
||||
if lowres_cond == self.lowres_cond and \
|
||||
channels == self.channels and \
|
||||
cond_on_image_embeds == self.cond_on_image_embeds and \
|
||||
cond_on_text_encodings == self.cond_on_text_encodings and \
|
||||
channels_out == self.channels_out:
|
||||
return self
|
||||
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
|
||||
updated_kwargs = dict(
|
||||
lowres_cond = lowres_cond,
|
||||
channels = channels,
|
||||
channels_out = channels_out,
|
||||
cond_on_image_embeds = cond_on_image_embeds,
|
||||
cond_on_text_encodings = cond_on_text_encodings
|
||||
)
|
||||
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1541,6 +1597,11 @@ class Unet(nn.Module):
|
||||
|
||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# normalize conditioning tokens
|
||||
|
||||
c = self.norm_cond(c)
|
||||
mid_c = self.norm_mid_cond(mid_c)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
@@ -1614,7 +1675,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
timesteps = 1000,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
loss_type = 'l1',
|
||||
loss_type = 'l2',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
@@ -1627,6 +1688,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False
|
||||
):
|
||||
super().__init__(
|
||||
@@ -1665,10 +1728,18 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
|
||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||
self.learned_variance = learned_variance
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# construct unets and vaes
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
@@ -1676,12 +1747,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
channels = unet_channels
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
@@ -1744,8 +1817,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
@@ -1756,24 +1832,38 @@ class Decoder(BaseGaussianDiffusion):
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
|
||||
if learned_variance:
|
||||
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||
# by an interpolation of the max and min log beta values
|
||||
# eq 15 - https://arxiv.org/abs/2102.09672
|
||||
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
@@ -1785,17 +1875,26 @@ class Decoder(BaseGaussianDiffusion):
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
return img
|
||||
unnormalize_img = unnormalize_zero_to_one(img)
|
||||
return unnormalize_img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
# normalize to [-1, 1]
|
||||
|
||||
x_start = normalize_neg_one_to_one(x_start)
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
|
||||
# get x_t
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
pred = unet(
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
@@ -1806,10 +1905,48 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = model_output.chunk(2, dim = 1)
|
||||
else:
|
||||
pred = model_output
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
if not learned_variance:
|
||||
# return simple loss if not using learned variance
|
||||
return loss
|
||||
|
||||
# most of the code below is transcribed from
|
||||
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
|
||||
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
|
||||
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
|
||||
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
detached_model_mean = model_mean.detach()
|
||||
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
|
||||
kl = meanflat(kl) * NAT
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
|
||||
decoder_nll = meanflat(decoder_nll) * NAT
|
||||
|
||||
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
|
||||
vb_losses = torch.where(times == 0, decoder_nll, kl)
|
||||
|
||||
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
|
||||
|
||||
vb_loss = vb_losses.mean() * self.vb_loss_weight
|
||||
|
||||
return loss + vb_loss
|
||||
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
@@ -1836,7 +1973,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
@@ -1862,6 +1999,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
@@ -1891,6 +2029,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
learned_variance = self.learned_variance[unet_index]
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
@@ -1928,7 +2067,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1978,4 +2117,3 @@ class DALLE2(nn.Module):
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
lr = 2e-5,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
@@ -26,4 +27,4 @@ def get_optimizer(
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
import copy
|
||||
from math import ceil
|
||||
from functools import partial
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
@@ -40,6 +45,47 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# gradient accumulation functions
|
||||
|
||||
def split_iterable(it, split_size):
|
||||
accum = []
|
||||
for ind in range(ceil(len(it) / split_size)):
|
||||
start_index = ind * split_size
|
||||
accum.append(it[start_index: (start_index + split_size)])
|
||||
return accum
|
||||
|
||||
def split(t, split_size = None):
|
||||
if not exists(split_size):
|
||||
return t
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.split(split_size, dim = 0)
|
||||
|
||||
if isinstance(t, Iterable):
|
||||
return split_iterable(t, split_size)
|
||||
|
||||
return TypeError
|
||||
|
||||
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
||||
batch_size = len(x)
|
||||
split_size = default(split_size, batch_size)
|
||||
chunk_size = ceil(batch_size / split_size)
|
||||
|
||||
dict_len = len(kwargs)
|
||||
dict_keys = kwargs.keys()
|
||||
all_args = (x, *args, *kwargs.values())
|
||||
len_all_args = len(all_args)
|
||||
split_kwargs_index = len_all_args - dict_len
|
||||
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||
|
||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
|
||||
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
@@ -90,7 +136,7 @@ class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
beta = 0.9999,
|
||||
update_after_step = 1000,
|
||||
update_every = 10,
|
||||
):
|
||||
@@ -105,6 +151,10 @@ class EMA(nn.Module):
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
@@ -143,6 +193,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
@@ -169,6 +220,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -202,13 +254,22 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*args,
|
||||
divisor = 1,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*args, **kwargs)
|
||||
return self.scaler.scale(loss / divisor)
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
|
||||
loss = loss * chunk_size_frac
|
||||
total_loss += loss.item()
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
return total_loss
|
||||
|
||||
# decoder trainer
|
||||
|
||||
@@ -217,8 +278,9 @@ class DecoderTrainer(nn.Module):
|
||||
self,
|
||||
decoder,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
lr = 2e-5,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
@@ -243,13 +305,14 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
||||
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -305,6 +368,11 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
if self.use_ema:
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
|
||||
# cast the ema_model unets back to original device
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
return output
|
||||
|
||||
def forward(
|
||||
@@ -312,9 +380,17 @@ class DecoderTrainer(nn.Module):
|
||||
x,
|
||||
*,
|
||||
unet_number,
|
||||
divisor = 1,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
|
||||
loss = loss * chunk_size_frac
|
||||
total_loss += loss.item()
|
||||
self.scale(loss, unet_number = unet_number).backward()
|
||||
|
||||
return total_loss
|
||||
|
||||
Reference in New Issue
Block a user