@@ -1,7 +1,7 @@
 import math
 from tqdm import tqdm
 from inspect import isfunction
-from functools import partial
+from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def maybe(fn):
+    @wraps(fn)
+    def inner(x):
+        if not exists(x):
+            return x
+        return fn(x)
+    return inner
+
 def default(val, d):
     if exists(val):
         return val
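
Note: the new maybe(fn) helper wraps a one-argument function so that a None input passes through untouched instead of raising; it is used further down to normalize the optional low-resolution conditioning image only when one is present. A quick illustration, assuming the exists and maybe helpers defined above:

normalize = maybe(lambda img: img * 2 - 1)
normalize(0.5)    # 0.0
normalize(None)   # None, silently skipped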
@@ -114,10 +122,10 @@ def resize_image_to(image, target_image_size):
 # ddpms expect images to be in the range of -1 to 1
 # but CLIP may otherwise
 
-def normalize_img(img):
+def normalize_neg_one_to_one(img):
     return img * 2 - 1
 
-def unnormalize_img(normed_img):
+def unnormalize_zero_to_one(normed_img):
     return (normed_img + 1) * 0.5
 
 # clip related adapters
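
Note: the renames make the direction of each mapping explicit: images arrive in [0, 1], are pushed into the [-1, 1] range the ddpm works in, and are mapped back on the way out. A quick round-trip check using only the two helpers above:

img = 0.25                              # pixel value in [0, 1]
normed = normalize_neg_one_to_one(img)  # -0.5, in the ddpm's [-1, 1] range
unnormalize_zero_to_one(normed)         # 0.25 again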
@@ -606,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -616,7 +623,6 @@ class Attention(nn.Module):
 
         self.causal = causal
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)
 
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -627,7 +633,7 @@ class Attention(nn.Module):
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )
 
     def forward(self, x, mask = None, attn_bias = None):
@@ -684,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)
 
 class CausalTransformer(nn.Module):
     def __init__(
@@ -711,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
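Note: with post_norm removed, the attention output projection now always ends in a LayerNorm, and the same Linear + LayerNorm pattern is applied to CrossAttention further down. A standalone shape check of the new projection, using nn.LayerNorm as a stand-in for the repo's own LayerNorm class:

import torch
from torch import nn

dim, heads, dim_head = 512, 8, 64
inner_dim = heads * dim_head

to_out = nn.Sequential(
    nn.Linear(inner_dim, dim, bias = False),
    nn.LayerNorm(dim)                   # stand-in for the repo's LayerNorm
)

out = torch.randn(2, 10, inner_dim)     # (batch, seq, heads * dim_head) after the rearrange
print(to_out(out).shape)                # torch.Size([2, 10, 512])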
@@ -1037,7 +1042,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
 
         if exists(image):
-            image_embed, _ = self.clip.embed_image(unnormalize_img(image))
+            image_embed, _ = self.clip.embed_image(image)
 
         # calculate text conditionings, based on what is passed in
 
@@ -1173,7 +1178,11 @@ class CrossAttention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim)
+        )
 
     def forward(self, x, context, mask = None):
         b, n, device = *x.shape[:2], x.device
@@ -1821,7 +1830,7 @@ class Decoder(BaseGaussianDiffusion):
         # eq 15 - https://arxiv.org/abs/2102.09672
         min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
         max_log = extract(torch.log(self.betas), t, x.shape)
-        var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
+        var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
         posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
         posterior_variance = posterior_log_variance.exp()
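
Note: this feeds eq. 15 of https://arxiv.org/abs/2102.09672,

    Σ_θ(x_t, t) = exp(v · log β_t + (1 − v) · log β̃_t)

where max_log = log β_t, min_log = log β̃_t (the clipped posterior log-variance), and the interpolation fraction v must lie in [0, 1]. The rename makes explicit that the model's raw output, presumably in [-1, 1], is being rescaled into that [0, 1] fraction, not that an image is being un-normalized.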
@@ -1844,6 +1853,8 @@ class Decoder(BaseGaussianDiffusion):
         b = shape[0]
         img = torch.randn(shape, device = device)
 
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
                 unet,
@@ -1859,11 +1870,19 @@ class Decoder(BaseGaussianDiffusion):
                 clip_denoised = clip_denoised
             )
 
-        return img
+        unnormalize_img = unnormalize_zero_to_one(img)
+        return unnormalize_img
 
     def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
+        # normalize to [-1, 1]
+
+        x_start = normalize_neg_one_to_one(x_start)
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
+        # get x_t
+
         x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
 
         model_output = unet(
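
Note: x_noisy here is the standard DDPM forward-process draw, as q_sample is conventionally implemented,

    x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise

which is why x_start and the optional low-res conditioning image are first pushed into [-1, 1]: the image then lives on a scale comparable to the unit-variance Gaussian noise it is mixed with, and the sampling loop above maps the result back to [0, 1] before returning it.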
@@ -2011,7 +2030,7 @@ class Decoder(BaseGaussianDiffusion):
 
         if not exists(image_embed):
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
-            image_embed, _ = self.clip.embed_image(unnormalize_img(image))
+            image_embed, _ = self.clip.embed_image(image)
 
         text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings):