|
|
|
|
@@ -2,6 +2,7 @@ import math
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
from inspect import isfunction
|
|
|
|
|
from functools import partial
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
@@ -820,6 +821,7 @@ class Unet(nn.Module):
|
|
|
|
|
image_embed_dim,
|
|
|
|
|
cond_dim = None,
|
|
|
|
|
num_image_tokens = 4,
|
|
|
|
|
num_time_tokens = 2,
|
|
|
|
|
out_dim = None,
|
|
|
|
|
dim_mults=(1, 2, 4, 8),
|
|
|
|
|
channels = 3,
|
|
|
|
|
@@ -830,6 +832,8 @@ class Unet(nn.Module):
|
|
|
|
|
sparse_attn = False,
|
|
|
|
|
sparse_attn_window = 8, # window size for sparse attention
|
|
|
|
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
|
|
|
|
cond_on_text_encodings = False,
|
|
|
|
|
cond_on_image_embeds = False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
# save locals to take care of some hyperparameters for cascading DDPM
|
|
|
|
|
@@ -862,8 +866,8 @@ class Unet(nn.Module):
|
|
|
|
|
SinusoidalPosEmb(dim),
|
|
|
|
|
nn.Linear(dim, dim * 4),
|
|
|
|
|
nn.GELU(),
|
|
|
|
|
nn.Linear(dim * 4, cond_dim),
|
|
|
|
|
Rearrange('b d -> b 1 d')
|
|
|
|
|
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
|
|
|
|
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.image_to_cond = nn.Sequential(
|
|
|
|
|
@@ -873,6 +877,12 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.text_to_cond = nn.LazyLinear(cond_dim)
|
|
|
|
|
|
|
|
|
|
# finer control over whether to condition on image embeddings and text encodings
|
|
|
|
|
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
|
|
|
|
|
|
|
|
|
self.cond_on_text_encodings = cond_on_text_encodings
|
|
|
|
|
self.cond_on_image_embeds = cond_on_image_embeds
|
|
|
|
|
|
|
|
|
|
# for classifier free guidance
|
|
|
|
|
|
|
|
|
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
|
|
|
|
@@ -982,17 +992,22 @@ class Unet(nn.Module):
|
|
|
|
|
# mask out image embedding depending on condition dropout
|
|
|
|
|
# for classifier free guidance
|
|
|
|
|
|
|
|
|
|
image_tokens = self.image_to_cond(image_embed)
|
|
|
|
|
image_tokens = None
|
|
|
|
|
|
|
|
|
|
image_tokens = torch.where(
|
|
|
|
|
cond_prob_mask,
|
|
|
|
|
image_tokens,
|
|
|
|
|
self.null_image_embed
|
|
|
|
|
)
|
|
|
|
|
if self.cond_on_image_embeds:
|
|
|
|
|
image_tokens = self.image_to_cond(image_embed)
|
|
|
|
|
|
|
|
|
|
image_tokens = torch.where(
|
|
|
|
|
cond_prob_mask,
|
|
|
|
|
image_tokens,
|
|
|
|
|
self.null_image_embed
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# take care of text encodings (optional)
|
|
|
|
|
|
|
|
|
|
if exists(text_encodings):
|
|
|
|
|
text_tokens = None
|
|
|
|
|
|
|
|
|
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
|
|
|
|
text_tokens = self.text_to_cond(text_encodings)
|
|
|
|
|
text_tokens = torch.where(
|
|
|
|
|
cond_prob_mask,
|
|
|
|
|
@@ -1002,12 +1017,15 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
# main conditioning tokens (c)
|
|
|
|
|
|
|
|
|
|
c = torch.cat((time_tokens, image_tokens), dim = -2)
|
|
|
|
|
c = time_tokens
|
|
|
|
|
|
|
|
|
|
if exists(image_tokens):
|
|
|
|
|
c = torch.cat((c, image_tokens), dim = -2)
|
|
|
|
|
|
|
|
|
|
# text and image conditioning tokens (mid_c)
|
|
|
|
|
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
|
|
|
|
|
|
|
|
|
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
|
|
|
|
|
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
|
|
|
|
|
|
|
|
|
# go through the layers of the unet, down and up
|
|
|
|
|
|
|
|
|
|
@@ -1124,6 +1142,20 @@ class Decoder(nn.Module):
|
|
|
|
|
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
|
|
|
|
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
def one_unet_in_gpu(self, unet_number):
|
|
|
|
|
assert 0 < unet_number <= len(self.unets)
|
|
|
|
|
index = unet_number - 1
|
|
|
|
|
self.cuda()
|
|
|
|
|
self.unets.cpu()
|
|
|
|
|
|
|
|
|
|
unet = self.unets[index]
|
|
|
|
|
unet.cuda()
|
|
|
|
|
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
self.unets.cpu()
|
|
|
|
|
|
|
|
|
|
def get_text_encodings(self, text):
|
|
|
|
|
text_encodings = self.clip.text_transformer(text)
|
|
|
|
|
return text_encodings[:, 1:]
|
|
|
|
|
@@ -1228,9 +1260,11 @@ class Decoder(nn.Module):
|
|
|
|
|
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
|
|
|
|
|
|
|
|
|
img = None
|
|
|
|
|
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
|
|
|
|
shape = (batch_size, channels, image_size, image_size)
|
|
|
|
|
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
|
|
|
|
|
|
|
|
|
for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
|
|
|
|
|
with self.one_unet_in_gpu(ind + 1):
|
|
|
|
|
shape = (batch_size, channels, image_size, image_size)
|
|
|
|
|
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
|
|
|
|
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
|