@@ -1,8 +1,6 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager

import torch
import torch.nn.functional as F

@@ -13,7 +11,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom

from kornia.filters import gaussian_blur2d
from kornia.filters import filter2d

from dalle2_pytorch.tokenizer import tokenizer

@@ -31,9 +29,6 @@ def default(val, d):
        return val
    return d() if isfunction(d) else d

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training

@@ -69,15 +64,6 @@ def freeze_model_and_make_eval_(model):
def l2norm(t):
    return F.normalize(t, dim = -1)

def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
    shape = cast_tuple(image_size, 2)
    orig_image_size = t.shape[-2:]

    if orig_image_size == shape:
        return t

    return F.interpolate(t, size = shape, mode = mode)

# classifier free guidance functions

def prob_mask_like(shape, prob, device):

@@ -106,35 +92,12 @@ def cosine_beta_schedule(timesteps, s = 0.008):
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    x = torch.linspace(0, steps, steps)
    alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    # interpolate sqrt(beta) linearly, then square, so betas span [beta_start, beta_end]
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
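
# illustrative sketch: each schedule above returns a length-`timesteps` tensor of betas; the
# classes below precompute the cumulative products that q_sample / p_sample rely on, roughly
#
#   betas = cosine_beta_schedule(1000)
#   alphas_cumprod = torch.cumprod(1. - betas, dim = 0)   # \bar{alpha}_t for each timestep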

# diffusion prior

class RMSNorm(nn.Module):

@@ -464,11 +427,10 @@ class DiffusionPrior(nn.Module):
        net,
        *,
        clip,
        timesteps=1000,
        cond_drop_prob=0.2,
        loss_type="l1",
        predict_x0=True,
        beta_schedule="cosine",
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = 'l1',
        predict_x0 = True
    ):
        super().__init__()
        assert isinstance(clip, CLIP)

@@ -484,18 +446,7 @@ class DiffusionPrior(nn.Module):
        self.predict_x0 = predict_x0
        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

        if beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "quadratic":
            betas = quadratic_beta_schedule(timesteps)
        elif beta_schedule == "jsd":
            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
        elif beta_schedule == "sigmoid":
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise NotImplementedError()
        betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)

@@ -599,6 +550,31 @@ class DiffusionPrior(nn.Module):
            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
        return img

    @torch.no_grad()
    def sample(self, text, num_samples_per_batch = 2):
        # in the paper, what they did was
        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)

        batch_size = text.shape[0]
        image_embed_dim = self.image_embed_dim

        text_cond = self.get_text_cond(text)

        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
        text_embeds = text_cond['text_embed']

        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

        text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
        top_sim_indices = text_image_sims.topk(k = 1).indices

        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)

        top_image_embeds = image_embeds.gather(1, top_sim_indices)
        return rearrange(top_image_embeds, 'b 1 d -> b d')
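
    # illustrative usage, assuming `text` is a batch of tokenized captions of shape (b, n):
    # calling `.sample(text)` draws `num_samples_per_batch` candidate image embeddings per caption
    # and keeps the one with the highest CLIP similarity to its text, returning a (b, image_embed_dim) tensor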

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

@@ -625,39 +601,11 @@ class DiffusionPrior(nn.Module):
            loss = F.l1_loss(to_predict, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(to_predict, x_recon)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(to_predict, x_recon)
        else:
            raise NotImplementedError()

        return loss

    @torch.no_grad()
    @eval_decorator
    def sample(self, text, num_samples_per_batch = 2):
        # in the paper, what they did was
        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)

        batch_size = text.shape[0]
        image_embed_dim = self.image_embed_dim

        text_cond = self.get_text_cond(text)

        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
        text_embeds = text_cond['text_embed']

        text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
        image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

        text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
        top_sim_indices = text_image_sims.topk(k = 1).indices

        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)

        top_image_embeds = image_embeds.gather(1, top_sim_indices)
        return rearrange(top_image_embeds, 'b 1 d -> b d')

    def forward(self, text, image, *args, **kwargs):
        b, device, img_size, = image.shape[0], image.device, self.image_size
        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

@@ -677,6 +625,17 @@ def Upsample(dim):
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        filt = torch.Tensor([1, 2, 1])
        self.register_buffer('filt', filt)

    def forward(self, x):
        filt = self.filt
        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
        return filter2d(x, filt, normalized = True)
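
    # note: the outer product of the 1D binomial filter [1, 2, 1] with itself yields a 3x3
    # smoothing kernel, which kornia's filter2d then applies to the input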

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()

@@ -799,20 +758,6 @@ class CrossAttention(nn.Module):
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class GridAttention(nn.Module):
    def __init__(self, *args, window_size = 8, **kwargs):
        super().__init__()
        self.window_size = window_size
        self.attn = Attention(*args, **kwargs)

    def forward(self, x):
        h, w = x.shape[-2:]
        wsz = self.window_size
        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
        out = self.attn(x)
        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
        return out
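
    # note: the rearranges above group the spatial positions into wsz * wsz sets of tokens strided
    # across the feature map, run full attention within each set, then fold the result back to 'b c h w'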

class Unet(nn.Module):
    def __init__(
        self,

@@ -821,41 +766,14 @@ class Unet(nn.Module):
        image_embed_dim,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
        lowres_cond_upsample_mode = 'bilinear',
        blur_sigma = 0.1,
        blur_kernel_size = 3,
        sparse_attn = False,
        sparse_attn_window = 8, # window size for sparse attention
        attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        cond_on_text_encodings = False,
        cond_on_image_embeds = False,
    ):
        super().__init__()
        # save locals to take care of some hyperparameters for cascading DDPM

        self._locals = locals()
        del self._locals['self']
        del self._locals['__class__']

        # for eventual cascading diffusion

        self.lowres_cond = lowres_cond
        self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
        self.lowres_blur_kernel_size = blur_kernel_size
        self.lowres_blur_sigma = blur_sigma

        # determine dimensions

        self.channels = channels

        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis

        dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time, image embeddings, and optional text encoding

@@ -866,8 +784,8 @@ class Unet(nn.Module):
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
            nn.Linear(dim * 4, cond_dim),
            Rearrange('b d -> b 1 d')
        )

        self.image_to_cond = nn.Sequential(

@@ -877,12 +795,6 @@ class Unet(nn.Module):

        self.text_to_cond = nn.LazyLinear(cond_dim)

        # finer control over whether to condition on image embeddings and text encodings
        # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting

        self.cond_on_text_encodings = cond_on_text_encodings
        self.cond_on_image_embeds = cond_on_image_embeds

        # for classifier free guidance

        self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))

@@ -895,32 +807,27 @@ class Unet(nn.Module):
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_first = ind == 0
            is_last = ind >= (num_resolutions - 1)
            layer_cond_dim = cond_dim if not is_first else None

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]

        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 2)
            layer_cond_dim = cond_dim if not is_last else None
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                Upsample(dim_in)
                ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
                ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)

@@ -929,15 +836,6 @@ class Unet(nn.Module):
            nn.Conv2d(dim, out_dim, 1)
        )

    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings
    def force_lowres_cond(self, lowres_cond):
        if lowres_cond == self.lowres_cond:
            return self

        updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
        return self.__class__(**updated_kwargs)

    def forward_with_cond_scale(
        self,
        *args,

@@ -958,56 +856,29 @@ class Unet(nn.Module):
        time,
        *,
        image_embed,
        lowres_cond_img = None,
        text_encodings = None,
        cond_drop_prob = 0.,
        blur_sigma = None,
        blur_kernel_size = None
        cond_drop_prob = 0.
    ):
        batch_size, device = x.shape[0], x.device

        # add low resolution conditioning, if present

        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'

        if exists(lowres_cond_img):
            if self.training:
                # when training, blur the low resolution conditional image
                blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
                blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
                lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))

            lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
            x = torch.cat((x, lowres_cond_img), dim = 1)

        # time conditioning

        time_tokens = self.time_mlp(time)

        # conditional dropout

        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')

        # mask out image embedding depending on condition dropout
        # for classifier free guidance

        image_tokens = None
        image_tokens = self.image_to_cond(image_embed)

        if self.cond_on_image_embeds:
            image_tokens = self.image_to_cond(image_embed)

            image_tokens = torch.where(
                cond_prob_mask,
                image_tokens,
                self.null_image_embed
            )
        image_tokens = torch.where(
            cond_prob_mask,
            image_tokens,
            self.null_image_embed
        )
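
        # note: this masking is the conditioning dropout for classifier free guidance - depending on
        # `cond_prob_mask`, a sample's image conditioning tokens are swapped for the learned
        # `null_image_embed`, so the network also learns an unconditional prediction to guide against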

        # take care of text encodings (optional)

        text_tokens = None

        if exists(text_encodings) and self.cond_on_text_encodings:
        if exists(text_encodings):
            text_tokens = self.text_to_cond(text_encodings)
            text_tokens = torch.where(
                cond_prob_mask,

@@ -1017,38 +888,30 @@ class Unet(nn.Module):

        # main conditioning tokens (c)

        c = time_tokens

        if exists(image_tokens):
            c = torch.cat((c, image_tokens), dim = -2)
        c = torch.cat((time_tokens, image_tokens), dim = -2)

        # text and image conditioning tokens (mid_c)
        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet

        mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)

        # go through the layers of the unet, down and up

        hiddens = []

        for convnext, sparse_attn, convnext2, downsample in self.downs:
        for convnext, convnext2, downsample in self.downs:
            x = convnext(x, c)
            x = sparse_attn(x)
            x = convnext2(x, c)
            hiddens.append(x)
            x = downsample(x)

        x = self.mid_block1(x, mid_c)

        if exists(self.mid_attn):
            x = self.mid_attn(x)

        x = self.mid_attn(x)
        x = self.mid_block2(x, mid_c)

        for convnext, sparse_attn, convnext2, upsample in self.ups:
        for convnext, convnext2, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = convnext(x, c)
            x = sparse_attn(x)
            x = convnext2(x, c)
            x = upsample(x)
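
        # note: `hiddens` holds the activations saved before each downsample on the way down,
        # and they are concatenated back in before each upsample - the usual U-Net skip connections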

@@ -1057,56 +920,24 @@ class Unet(nn.Module):
class Decoder(nn.Module):
    def __init__(
        self,
        unet,
        net,
        *,
        clip,
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = 'l1',
        beta_schedule = 'cosine',
        image_sizes = None # for cascading ddpm, image size at each stage
        loss_type = 'l1'
    ):
        super().__init__()
        assert isinstance(clip, CLIP)
        freeze_model_and_make_eval_(clip)
        self.clip = clip
        self.clip_image_size = clip.image_size

        self.net = net
        self.channels = clip.image_channels

        # automatically take care of ensuring that first unet is unconditional
        # while the rest of the unets are conditioned on the low resolution image produced by previous unet

        self.unets = nn.ModuleList([])
        for ind, one_unet in enumerate(cast_tuple(unet)):
            is_first = ind == 0
            one_unet = one_unet.force_lowres_cond(not is_first)
            self.unets.append(one_unet)

        # unet image sizes

        image_sizes = default(image_sizes, (clip.image_size,))
        image_sizes = tuple(sorted(set(image_sizes)))

        assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
        self.image_sizes = image_sizes

        lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
        assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

        self.image_size = clip.image_size
        self.cond_drop_prob = cond_drop_prob

        if beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "quadratic":
            betas = quadratic_beta_schedule(timesteps)
        elif beta_schedule == "jsd":
            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
        elif beta_schedule == "sigmoid":
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise NotImplementedError()
        betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)

@@ -1142,26 +973,11 @@ class Decoder(nn.Module):
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    @contextmanager
    def one_unet_in_gpu(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1
        self.cuda()
        self.unets.cpu()

        unet = self.unets[index]
        unet.cuda()

        yield

        self.unets.cpu()
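
    # note: during sampling of a cascade, this context manager keeps just the selected unet on the
    # GPU (the other unets are moved to CPU), then moves all unets back to CPU once the block exits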

    def get_text_encodings(self, text):
        text_encodings = self.clip.text_transformer(text)
        return text_encodings[:, 1:]

    def get_image_embed(self, image):
        image = resize_image_to(image, self.clip_image_size)
        image_encoding = self.clip.visual_transformer(image)
        image_cls = image_encoding[:, 0]
        image_embed = self.clip.to_visual_latent(image_cls)

@@ -1188,9 +1004,8 @@ class Decoder(nn.Module):
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
        pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
        x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

@@ -1199,25 +1014,33 @@ class Decoder(nn.Module):
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device = device)
        img = torch.randn(shape, device=device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
        return img

    @torch.no_grad()
    def sample(self, image_embed, text = None, cond_scale = 1.):
        batch_size = image_embed.shape[0]
        image_size = self.image_size
        channels = self.channels
        text_encodings = self.get_text_encodings(text) if exists(text) else None
        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

@@ -1226,17 +1049,16 @@ class Decoder(nn.Module):
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)

        x_recon = unet(
        x_recon = self.net(
            x_noisy,
            t,
            image_embed = image_embed,
            text_encodings = text_encodings,
            lowres_cond_img = lowres_cond_img,
            cond_drop_prob = self.cond_drop_prob
        )

@@ -1244,54 +1066,22 @@ class Decoder(nn.Module):
            loss = F.l1_loss(noise, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    @torch.no_grad()
    @eval_decorator
    def sample(self, image_embed, text = None, cond_scale = 1.):
        batch_size = image_embed.shape[0]
        channels = self.channels

        text_encodings = self.get_text_encodings(text) if exists(text) else None

        img = None

        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
            with self.one_unet_in_gpu(ind + 1):
                shape = (batch_size, channels, image_size, image_size)
                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)

        return img
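
    # note: in the cascading sample loop above, each stage's output `img` is fed to the next unet
    # as `lowres_cond_img`, so every later unet super-resolves the previous stage's result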

    def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
        assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
        unet_number = default(unet_number, 1)
        assert 1 <= unet_number <= len(self.unets)

        index = unet_number - 1
        unet = self.unets[index]
        target_image_size = self.image_sizes[index]

        b, c, h, w, device, = *image.shape, image.device

        check_shape(image, 'b c h w', c = self.channels)
        assert h >= target_image_size and w >= target_image_size
    def forward(self, image, text = None):
        b, device, img_size, = image.shape[0], image.device, self.image_size
        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)

        if not exists(image_embed):
            image_embed = self.get_image_embed(image)
        image_embed = self.get_image_embed(image)
        text_encodings = self.get_text_encodings(text) if exists(text) else None

        text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None

        lowres_cond_img = image if index > 0 else None
        ddpm_image = resize_image_to(image, target_image_size)
        return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
        return loss

# main class