import math

from tqdm import tqdm
from inspect import isfunction

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom

from kornia.filters import filter2d

from dalle2_pytorch.tokenizer import tokenizer

# use x-clip

from x_clip import CLIP

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def is_list_str(x):
    if not isinstance(x, (list, tuple)):
        return False
    return all(isinstance(el, str) for el in x)

# for controlling freezing of CLIP

def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
    set_module_requires_grad_(module, True)

def freeze_model_and_make_eval_(model):
    model.eval()
    freeze_all_layers_(model)

# tensor helpers

def l2norm(t):
    return F.normalize(t, dim = -1)

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
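
# returns a boolean mask that is True with probability `prob` per batch element;
# callers below use it as a keep-mask when dropping conditioning for classifier free guidance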

# gaussian diffusion helper functions

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
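
# e.g. with t of shape (4,) and x_shape = (4, 3, 64, 64), `extract` gathers one
# scalar per batch element from the schedule buffer and reshapes the result to
# (4, 1, 1, 1) so it broadcasts across the remaining dimensions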

def noise_like(shape, device, repeat = False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device = device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device = device)
    return repeat_noise() if repeat else noise()

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, steps, steps)
    alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
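
# the cosine schedule degrades signal more gradually than a linear one,
# particularly towards the end of the chain; betas are clipped at 0.999 to
# avoid a singularity at the final step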

# diffusion prior

class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * self.gamma * self.scale
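
# note: normalizing by the squared sum and multiplying by scale = sqrt(dim) is
# algebraically the same as dividing by the root mean square over the last
# dimension, i.e. x * gamma / sqrt(mean(x ** 2)), up to where epsilon is added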

class ChanRMSNorm(RMSNorm):
    def forward(self, x):
        squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# mlp

class MLP(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        *,
        expansion_factor = 2.,
        depth = 2,
        norm = False,
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim_out)
        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()

        layers = [nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.SiLU(),
            norm_fn()
        )]

        for _ in range(depth - 1):
            layers.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                norm_fn()
            ))

        layers.append(nn.Linear(hidden_dim, dim_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.float())

# feedforward

class SwiGLU(nn.Module):
    """ used successfully in https://arxiv.org/abs/2204.02311 """
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.silu(gate)

def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
    """ post-activation norm https://arxiv.org/abs/2110.09456 """

    inner_dim = int(mult * dim)
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        SwiGLU(),
        RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.causal = causal
        self.norm = RMSNorm(dim)
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        sim = einsum('b h i d, b j d -> b h i j', q, k)
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        sim = sim - sim.amax(dim = -1, keepdim = True)
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
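
# note: `to_kv` above projects to a single key / value head of dimension
# dim_head, shared across all query heads (the einsums carry no head axis on
# k and v) - a multi-query style attention that shrinks the kv projection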

class CausalTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        norm_out = False,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        # todo - bring in rotary embeddings or alibi

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options

    def forward(
        self,
        x,
        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
    ):
        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x

        return self.norm(x)

class DiffusionPriorNetwork(nn.Module):
    def __init__(
        self,
        dim,
        num_timesteps = None,
        **kwargs
    ):
        super().__init__()
        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
        self.learned_query = nn.Parameter(torch.randn(dim))
        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
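
    # forward_with_cond_scale implements classifier free guidance at sampling
    # time: one conditioned pass, one pass with the conditioning fully dropped,
    # then extrapolation away from the unconditional prediction;
    # cond_scale = 1 recovers the purely conditional output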

    def forward(
        self,
        image_embed,
        diffusion_timesteps,
        *,
        text_encodings,
        text_embed,
        mask = None,
        cond_drop_prob = 0.2
    ):
        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device

        # in section 2.2, last paragraph
        # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"

        text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')

        # whether the text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
        # but let's just do it right

        # default to attending to all text encodings if no mask is given

        if not exists(mask):
            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)

        # classifier free guidance - drop the text conditioning for a fraction of the batch, while keeping the non-text tokens attendable

        cond_keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
        mask &= rearrange(cond_keep_mask, 'b -> b 1')

        # the text embedding token is attendable only if at least one text encoding survived the mask

        not_all_masked_out = mask.any(dim = -1)
        mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)

        mask = F.pad(mask, (0, 3), value = True) # extend mask to cover the time step embedding, noised image embedding, and learned query

        time_embed = self.time_embeddings(diffusion_timesteps)
        time_embed = rearrange(time_embed, 'b d -> b 1 d')

        learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

        tokens = torch.cat((
            text_encodings,
            text_embed,
            time_embed,
            image_embed,
            learned_queries
        ), dim = -2)

        # attend

        tokens = self.causal_transformer(tokens, mask = mask)

        # get learned query, which should predict the image embedding (per DDPM timestep)

        pred_image_embed = tokens[..., -1, :]

        return pred_image_embed

class DiffusionPrior(nn.Module):
    def __init__(
        self,
        net,
        *,
        clip,
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = 'l1',
        predict_x0 = True
    ):
        super().__init__()
        assert isinstance(clip, CLIP)
        freeze_model_and_make_eval_(clip)
        self.clip = clip

        self.net = net
        self.image_embed_dim = clip.dim_latent
        self.channels = clip.image_channels
        self.image_size = clip.image_size
        self.cond_drop_prob = cond_drop_prob

        self.predict_x0 = predict_x0
        # in the paper, they do not predict the noise, but predict x0 directly for the image embedding, claiming empirically better results. I'll just offer both.

        betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        self.register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def get_image_embed(self, image):
        image_encoding = self.clip.visual_transformer(image)
        image_cls = image_encoding[:, 0]
        image_embed = self.clip.to_visual_latent(image_cls)
        return l2norm(image_embed)

    def get_text_cond(self, text):
        text_encodings = self.clip.text_transformer(text)
        text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
        text_embed = self.clip.to_text_latent(text_cls)
        text_embed = l2norm(text_embed)
        return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
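
    # note: `text != 0` above assumes token id 0 is the padding token, which
    # holds for the bundled tokenizer; a different tokenizer may need this adjusted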

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
        if self.predict_x0:
            x_recon = self.net(x, t, **text_cond)
            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
        else:
            x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, text_cond):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device = device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
        return img
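
    # ancestral sampling: start from pure gaussian noise and denoise one step
    # at a time from t = T - 1 down to t = 0, conditioning on the text throughout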

    @torch.no_grad()
    def sample(self, text, num_samples_per_batch = 2):
        # in the paper, what they did was
        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)

        batch_size = text.shape[0]
        image_embed_dim = self.image_embed_dim

        text_cond = self.get_text_cond(text)

        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
        text_embeds = text_cond['text_embed']

        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

        text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
        top_sim_indices = text_image_sims.topk(k = 1).indices

        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)

        top_image_embeds = image_embeds.gather(1, top_sim_indices)
        return rearrange(top_image_embeds, 'b 1 d -> b d')

    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
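
    # q_sample draws from the closed form q(x_t | x_0) = N(sqrt(alphas_cumprod_t) * x_0, (1 - alphas_cumprod_t) * I),
    # so any timestep can be reached in a single step during training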

    def p_losses(self, image_embed, t, text_cond, noise = None):
        noise = default(noise, lambda: torch.randn_like(image_embed))

        image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)

        x_recon = self.net(
            image_embed_noisy,
            t,
            cond_drop_prob = self.cond_drop_prob,
            **text_cond
        )

        to_predict = noise if not self.predict_x0 else image_embed

        if self.loss_type == 'l1':
            loss = F.l1_loss(to_predict, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(to_predict, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def forward(self, text, image, *args, **kwargs):
        b, device, img_size = image.shape[0], image.device, self.image_size
        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
        image_embed = self.get_image_embed(image)
        text_cond = self.get_text_cond(text)

        loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
        return loss

# decoder

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        filt = torch.Tensor([1, 2, 1])
        self.register_buffer('filt', filt)

    def forward(self, x):
        filt = self.filt
        # outer product of the 1d binomial filter [1, 2, 1] gives the 2d blur kernel, shaped (1, 3, 3) as filter2d expects
        filt = rearrange(filt, 'j -> 1 1 j') * rearrange(filt, 'i -> 1 i 1')
        return filter2d(x, filt, normalized = True)

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
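
# the standard transformer sinusoidal embedding: frequencies spaced
# geometrically from 1 down to 1 / 10000, with sin and cos halves concatenated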

class ConvNextBlock(nn.Module):
    """ https://arxiv.org/abs/2201.03545 """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        mult = 2,
        norm = True
    ):
        super().__init__()
        need_projection = dim != dim_out

        self.cross_attn = None

        if exists(cond_dim):
            self.cross_attn = EinopsToAndFrom(
                'b c h w',
                'b (h w) c',
                CrossAttention(
                    dim = dim,
                    context_dim = cond_dim
                )
            )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        inner_dim = int(dim_out * mult)
        self.net = nn.Sequential(
            ChanRMSNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, inner_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()

    def forward(self, x, cond = None):
        h = self.ds_conv(x)

        if exists(self.cross_attn):
            assert exists(cond)
            h = self.cross_attn(h, context = cond) + h

        h = self.net(h)

        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = RMSNorm(dim)
        self.norm_context = RMSNorm(context_dim)
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        sim = sim - sim.amax(dim = -1, keepdim = True)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        *,
        image_embed_dim,
        cond_dim = None,
        num_image_tokens = 4,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        channels = 3,
    ):
        super().__init__()
        self.channels = channels

        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time and image embeddings

        cond_dim = default(cond_dim, dim)

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, cond_dim),
            Rearrange('b d -> b 1 d')
        )

        self.image_to_cond = nn.Sequential(
            nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
            Rearrange('b (n d) -> b n d', n = num_image_tokens)
        ) if image_embed_dim != cond_dim else nn.Identity()

        # for classifier free guidance

        self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]

        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
                ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        time,
        *,
        image_embed,
        text_encodings = None,
        cond_drop_prob = 0.
    ):
        batch_size, device = x.shape[0], x.device
        time_tokens = self.time_mlp(time)

        # mask out the image embedding depending on condition dropout,
        # for classifier free guidance

        cond_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

        image_tokens = self.image_to_cond(image_embed)

        # where conditioning is dropped, substitute the learned null image embedding

        image_tokens = torch.where(
            rearrange(cond_keep_mask, 'b -> b 1 1'),
            image_tokens,
            self.null_image_embed
        )

        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition

        hiddens = []

        for convnext, convnext2, downsample in self.downs:
            x = convnext(x, c)
            x = convnext2(x, c)
            hiddens.append(x)
            x = downsample(x)

        x = self.mid_block1(x, c)
        x = self.mid_attn(x)
        x = self.mid_block2(x, c)

        for convnext, convnext2, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim = 1)
            x = convnext(x, c)
            x = convnext2(x, c)
            x = upsample(x)

        return self.final_conv(x)

class Decoder(nn.Module):
    def __init__(
        self,
        net,
        *,
        clip,
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = 'l1'
    ):
        super().__init__()
        assert isinstance(clip, CLIP)
        freeze_model_and_make_eval_(clip)
        self.clip = clip

        self.net = net
        self.channels = clip.image_channels
        self.image_size = clip.image_size
        self.cond_drop_prob = cond_drop_prob

        betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        self.register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def get_image_embed(self, image):
        image_encoding = self.clip.visual_transformer(image)
        image_cls = image_encoding[:, 0]
        image_embed = self.clip.to_visual_latent(image_cls)
        return l2norm(image_embed)

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device = device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
        return img

    @torch.no_grad()
    def sample(self, image_embed, cond_scale = 1.):
        batch_size = image_embed.shape[0]
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)

    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, *, image_embed, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)

        x_recon = self.net(
            x_noisy,
            t,
            image_embed = image_embed,
            cond_drop_prob = self.cond_drop_prob
        )

        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def forward(self, image):
        b, device, img_size = image.shape[0], image.device, self.image_size
        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
        image_embed = self.get_image_embed(image)

        loss = self.p_losses(image, times, image_embed = image_embed)
        return loss

# main class

class DALLE2(nn.Module):
    def __init__(
        self,
        *,
        prior,
        decoder,
        prior_num_samples = 2
    ):
        super().__init__()
        assert isinstance(prior, DiffusionPrior)
        assert isinstance(decoder, Decoder)
        self.prior = prior.eval()
        self.decoder = decoder.eval()
        self.prior_num_samples = prior_num_samples

    @torch.no_grad()
    def forward(
        self,
        text,
        cond_scale = 1.
    ):
        device = next(self.parameters()).device

        if isinstance(text, str) or is_list_str(text):
            text = [text] if not isinstance(text, (list, tuple)) else text
            text = tokenizer.tokenize(text).to(device)

        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
        return images
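
# a minimal usage sketch (hyperparameters purely illustrative; in practice the
# prior and decoder are each trained separately before being composed, and the
# CLIP kwargs shown follow x-clip's constructor):
#
#   clip = CLIP(
#       dim_text = 512, dim_image = 512, dim_latent = 512,
#       num_text_tokens = 49408, text_enc_depth = 6, text_seq_len = 256, text_heads = 8,
#       visual_enc_depth = 6, visual_image_size = 256, visual_patch_size = 32, visual_heads = 8
#   )
#
#   prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)
#   diffusion_prior = DiffusionPrior(net = prior_network, clip = clip, timesteps = 100)
#
#   unet = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, dim_mults = (1, 2, 4, 8))
#   decoder = Decoder(net = unet, clip = clip, timesteps = 100)
#
#   dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)
#   images = dalle2(['a butterfly trying to escape a tornado'])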