|
|
|
|
@@ -1,7 +1,6 @@
|
|
|
|
|
import math
|
|
|
|
|
import random
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
from inspect import isfunction
|
|
|
|
|
from functools import partial, wraps
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from collections import namedtuple
|
|
|
|
|
@@ -12,7 +11,7 @@ import torch.nn.functional as F
|
|
|
|
|
from torch import nn, einsum
|
|
|
|
|
import torchvision.transforms as T
|
|
|
|
|
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
|
from einops import rearrange, repeat, reduce
|
|
|
|
|
from einops.layers.torch import Rearrange
|
|
|
|
|
from einops_exts import rearrange_many, repeat_many, check_shape
|
|
|
|
|
from einops_exts.torch import EinopsToAndFrom
|
|
|
|
|
@@ -57,7 +56,7 @@ def maybe(fn):
|
|
|
|
|
def default(val, d):
    """Return `val` if `exists(val)` is truthy, otherwise the fallback `d`.

    `d` may be a plain value or a zero-argument callable. A callable fallback
    is invoked lazily, only when `val` does not exist, and its result is
    returned in place of `d` itself.
    """
    if exists(val):
        return val

    # `callable` instead of `inspect.isfunction`: partials, classes and objects
    # defining `__call__` are treated as default factories too, not only
    # plain functions and lambdas
    return d() if callable(d) else d
|
|
|
|
|
|
|
|
|
|
def cast_tuple(val, length = 1):
|
|
|
|
|
if isinstance(val, list):
|
|
|
|
|
@@ -314,11 +313,6 @@ def extract(a, t, x_shape):
|
|
|
|
|
out = a.gather(-1, t)
|
|
|
|
|
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
|
|
|
|
|
|
|
|
|
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of the requested shape on `device`.

    Args:
        shape: full shape of the returned tensor; `shape[0]` is the size of
            the leading dimension along which noise may be repeated.
        device: device passed straight to `torch.randn`.
        repeat: when True, a single sample of shape `(1, *shape[1:])` is drawn
            and tiled `shape[0]` times along the first dimension, so every
            leading index carries identical noise.

    Returns:
        A tensor of shape `shape`.
    """
    if not repeat:
        return torch.randn(shape, device=device)

    # draw one sample, then copy it across the leading dimension
    single = torch.randn((1, *shape[1:]), device=device)
    return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
|
|
|
|
|
|
|
|
|
|
def meanflat(x):
    """Average `x` over every dimension except the first.

    Returns a tensor with one mean value per index of the leading dimension.
    """
    trailing_dims = tuple(range(1, len(x.shape)))
    return x.mean(dim = trailing_dims)
|
|
|
|
|
|
|
|
|
|
@@ -373,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
|
|
|
|
|
scale = 1000 / timesteps
|
|
|
|
|
beta_start = scale * 0.0001
|
|
|
|
|
beta_end = scale * 0.02
|
|
|
|
|
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
|
|
|
|
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sigmoid_beta_schedule(timesteps):
|
|
|
|
|
@@ -385,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseGaussianDiffusion(nn.Module):
|
|
|
|
|
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
|
|
|
|
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
if beta_schedule == "cosine":
|
|
|
|
|
@@ -450,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
|
|
|
|
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
|
|
|
|
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
|
|
|
|
|
|
|
|
|
# p2 loss reweighting
|
|
|
|
|
|
|
|
|
|
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
|
|
|
|
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
|
|
|
|
|
|
|
|
|
def q_posterior(self, x_start, x_t, t):
|
|
|
|
|
posterior_mean = (
|
|
|
|
|
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
|
|
|
|
@@ -946,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
    """Take one sampling step from `x` at timestep(s) `t`.

    The mean and log-variance come from `self.p_mean_variance`; fresh Gaussian
    noise, scaled by the standard deviation `exp(0.5 * log_variance)`, is added
    to the mean. Batch entries whose timestep is 0 receive no noise.

    Args:
        x: current sample; its first dimension is the batch dimension.
        t: per-batch-entry timesteps, compared against 0 to build the mask.
        text_cond, clip_denoised, cond_scale: forwarded unchanged to
            `self.p_mean_variance`.

    Returns:
        A tensor with the same shape as `x`.
    """
    batch = x.shape[0]
    model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
    noise = torch.randn_like(x)

    # no noise when t == 0: mask is 0. for those entries and 1. elsewhere,
    # reshaped to (batch, 1, ..., 1) so it broadcasts against `x`
    nonzero_mask = (1 - (t == 0).float()).reshape(batch, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
|
|
|
|
@@ -1428,6 +1427,7 @@ class Unet(nn.Module):
|
|
|
|
|
# for classifier free guidance
|
|
|
|
|
|
|
|
|
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
|
|
|
|
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
|
|
|
|
|
|
|
|
|
self.max_text_len = max_text_len
|
|
|
|
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
|
|
|
|
@@ -1565,19 +1565,28 @@ class Unet(nn.Module):
|
|
|
|
|
time_tokens = self.to_time_tokens(time_hiddens)
|
|
|
|
|
t = self.to_time_cond(time_hiddens)
|
|
|
|
|
|
|
|
|
|
# image embedding to be summed to time embedding
|
|
|
|
|
# discovered by @mhh0318 in the paper
|
|
|
|
|
|
|
|
|
|
if exists(image_embed) and exists(self.to_image_hiddens):
|
|
|
|
|
image_hiddens = self.to_image_hiddens(image_embed)
|
|
|
|
|
t = t + image_hiddens
|
|
|
|
|
|
|
|
|
|
# conditional dropout
|
|
|
|
|
|
|
|
|
|
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
|
|
|
|
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
|
|
|
|
|
|
|
|
|
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
|
|
|
|
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
|
|
|
|
|
|
|
|
|
# image embedding to be summed to time embedding
|
|
|
|
|
# discovered by @mhh0318 in the paper
|
|
|
|
|
|
|
|
|
|
if exists(image_embed) and exists(self.to_image_hiddens):
|
|
|
|
|
image_hiddens = self.to_image_hiddens(image_embed)
|
|
|
|
|
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
|
|
|
|
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
|
|
|
|
|
|
|
|
|
image_hiddens = torch.where(
|
|
|
|
|
image_keep_mask_hidden,
|
|
|
|
|
image_hiddens,
|
|
|
|
|
null_image_hiddens
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
t = t + image_hiddens
|
|
|
|
|
|
|
|
|
|
# mask out image embedding depending on condition dropout
|
|
|
|
|
# for classifier free guidance
|
|
|
|
|
@@ -1585,11 +1594,12 @@ class Unet(nn.Module):
|
|
|
|
|
image_tokens = None
|
|
|
|
|
|
|
|
|
|
if self.cond_on_image_embeds:
|
|
|
|
|
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
|
|
|
|
image_tokens = self.image_to_tokens(image_embed)
|
|
|
|
|
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
|
|
|
|
|
|
|
|
|
image_tokens = torch.where(
|
|
|
|
|
image_keep_mask,
|
|
|
|
|
image_keep_mask_embed,
|
|
|
|
|
image_tokens,
|
|
|
|
|
null_image_embed
|
|
|
|
|
)
|
|
|
|
|
@@ -1750,12 +1760,16 @@ class Decoder(BaseGaussianDiffusion):
|
|
|
|
|
unconditional = False,
|
|
|
|
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
|
|
|
|
use_dynamic_thres = False, # from the Imagen paper
|
|
|
|
|
dynamic_thres_percentile = 0.9
|
|
|
|
|
dynamic_thres_percentile = 0.9,
|
|
|
|
|
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
|
|
|
|
p2_loss_weight_k = 1
|
|
|
|
|
):
|
|
|
|
|
super().__init__(
|
|
|
|
|
beta_schedule = beta_schedule,
|
|
|
|
|
timesteps = timesteps,
|
|
|
|
|
loss_type = loss_type
|
|
|
|
|
loss_type = loss_type,
|
|
|
|
|
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
|
|
|
|
p2_loss_weight_k = p2_loss_weight_k
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.unconditional = unconditional
|
|
|
|
|
@@ -1956,10 +1970,10 @@ class Decoder(BaseGaussianDiffusion):
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
    """Take one sampling step from `x` at timestep(s) `t` using `unet`.

    The mean and log-variance come from `self.p_mean_variance`; fresh Gaussian
    noise, scaled by the standard deviation `exp(0.5 * log_variance)`, is added
    to the mean. Batch entries whose timestep is 0 receive no noise.

    Args:
        unet: passed as the first positional argument to `self.p_mean_variance`.
        x: current sample; its first dimension is the batch dimension.
        t: per-batch-entry timesteps, compared against 0 to build the mask.
        image_embed, text_encodings, text_mask, cond_scale, lowres_cond_img,
        predict_x_start, learned_variance, clip_denoised: forwarded unchanged
            to `self.p_mean_variance`.

    Returns:
        A tensor with the same shape as `x`.
    """
    batch = x.shape[0]
    model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
    noise = torch.randn_like(x)

    # no noise when t == 0: mask is 0. for those entries and 1. elsewhere,
    # reshaped to (batch, 1, ..., 1) so it broadcasts against `x`
    nonzero_mask = (1 - (t == 0).float()).reshape(batch, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
|
|
|
|
@@ -2023,7 +2037,13 @@ class Decoder(BaseGaussianDiffusion):
|
|
|
|
|
|
|
|
|
|
target = noise if not predict_x_start else x_start
|
|
|
|
|
|
|
|
|
|
loss = self.loss_fn(pred, target)
|
|
|
|
|
loss = self.loss_fn(pred, target, reduction = 'none')
|
|
|
|
|
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
|
|
|
|
|
|
|
|
|
if self.has_p2_loss_reweighting:
|
|
|
|
|
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
|
|
|
|
|
|
|
|
|
|
loss = loss.mean()
|
|
|
|
|
|
|
|
|
|
if not learned_variance:
|
|
|
|
|
# return simple loss if not using learned variance
|
|
|
|
|
|