@@ -84,7 +84,7 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
    if orig_image_size == shape:
        return t

    return F.interpolate(t, size = shape, mode = mode)
    return F.interpolate(t, size = shape, mode = mode, align_corners = False)

# classifier free guidance functions
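# illustrative sketch of resize_image_to (the tensor shapes are assumptions, not from the diff):
# it is a no-op at the original resolution and otherwise falls back to F.interpolate
#
#   images = torch.randn(4, 3, 256, 256)
#   resize_image_to(images, 256).shape   # (4, 3, 256, 256) - returned unchanged
#   resize_image_to(images, 64).shape    # (4, 3, 64, 64)   - bilinear resample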
@@ -143,6 +143,92 @@ def sigmoid_beta_schedule(timesteps):
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

class BaseGaussianDiffusion(nn.Module):
    def __init__(self, *, beta_schedule, timesteps, loss_type):
        super().__init__()

        if beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "quadratic":
            betas = quadratic_beta_schedule(timesteps)
        elif beta_schedule == "jsd":
            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
        elif beta_schedule == "sigmoid":
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise NotImplementedError()

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        self.register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def sample(self, *args, **kwargs):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError
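# a small, self-contained sketch of what the buffers above encode, assuming the standard DDPM
# closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, which is what
# q_sample computes and what predict_start_from_noise inverts; the timestep count, embedding
# size and the chosen t below are arbitrary illustration values

def ddpm_closed_form_sketch():
    betas = cosine_beta_schedule(1000)
    alphas_cumprod = torch.cumprod(1. - betas, axis = 0)

    x_start = torch.randn(2, 512)              # e.g. a batch of CLIP image embeddings
    noise   = torch.randn_like(x_start)
    t       = 500

    x_t = alphas_cumprod[t].sqrt() * x_start + (1. - alphas_cumprod[t]).sqrt() * noise

    # inverting with the reciprocal buffers recovers x_start up to numerical precision
    recovered = (1. / alphas_cumprod[t]).sqrt() * x_t - (1. / alphas_cumprod[t] - 1).sqrt() * noise
    return torch.allclose(recovered, x_start, atol = 1e-4)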

# diffusion prior

class LayerNorm(nn.Module):
@@ -421,25 +507,41 @@ class DiffusionPriorNetwork(nn.Module):
        image_embed,
        diffusion_timesteps,
        *,
        text_encodings,
        text_embed,
        text_encodings = None,
        mask = None,
        cond_drop_prob = 0.2
    ):
        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
        batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype

        # in section 2.2, last paragraph
        # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"

        text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')

        # make text encodings optional
        # although the paper seems to suggest it is present <--

        if not exists(text_encodings):
            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)

        if not exists(mask):
            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)

        # classifier free guidance

        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')

        mask &= cond_prob_mask

        # whether text embedding is masked or not depends on the classifier free guidance conditional masking

        mask = torch.cat((mask, cond_prob_mask), dim = 1)

        # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
        # but let's just do it right

        if exists(mask):
            not_all_masked_out = mask.any(dim = -1)
            mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)

        if exists(mask):
            mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
@@ -455,16 +557,6 @@ class DiffusionPriorNetwork(nn.Module):
            learned_queries
        ), dim = -2)

        # mask if it doesn't exist

        if not exists(mask):
            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)

        # classifier free guidance

        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
        mask &= rearrange(cond_prob_mask, 'b -> b 1')

        # attend

        tokens = self.causal_transformer(tokens, mask = mask)
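# a shape-only sketch of the token sequence assembled above, following the section 2.2 comment:
# (encoded text, CLIP text embedding, timestep embedding, noised CLIP image embedding, learned
# query), attended causally so the final learned query can read everything before it; batch size,
# text length and dim below are assumed illustration values

def prior_token_sequence_sketch(batch = 2, text_enc_len = 77, dim = 512):
    text_encodings = torch.randn(batch, text_enc_len, dim)
    text_embed     = torch.randn(batch, 1, dim)
    time_embed     = torch.randn(batch, 1, dim)
    image_embed    = torch.randn(batch, 1, dim)    # the noised CLIP image embedding
    learned_query  = torch.randn(batch, 1, dim)

    tokens = torch.cat((text_encodings, text_embed, time_embed, image_embed, learned_query), dim = -2)
    return tokens.shape    # (batch, text_enc_len + 4, dim)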
@@ -475,81 +567,50 @@ class DiffusionPriorNetwork(nn.Module):

        return pred_image_embed

class DiffusionPrior(nn.Module):
class DiffusionPrior(BaseGaussianDiffusion):
    def __init__(
        self,
        net,
        *,
        clip,
        clip = None,
        image_embed_dim = None,
        image_size = None,
        image_channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = "l1",
        predict_x_start = True,
        beta_schedule = "cosine",
        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
    ):
        super().__init__()
        assert isinstance(clip, CLIP)
        freeze_model_and_make_eval_(clip)
        self.clip = clip
        super().__init__(
            beta_schedule = beta_schedule,
            timesteps = timesteps,
            loss_type = loss_type
        )

        if exists(clip):
            assert isinstance(clip, CLIP)
            freeze_model_and_make_eval_(clip)
            self.clip = clip
        else:
            assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
            self.clip = None

        self.net = net
        self.image_embed_dim = clip.dim_latent
        self.channels = clip.image_channels
        self.image_size = clip.image_size
        self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
        self.channels = default(image_channels, lambda: clip.image_channels)

        self.cond_drop_prob = cond_drop_prob
        self.condition_on_text_encodings = condition_on_text_encodings

        self.predict_x_start = predict_x_start
        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

        if beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "quadratic":
            betas = quadratic_beta_schedule(timesteps)
        elif beta_schedule == "jsd":
            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
        elif beta_schedule == "sigmoid":
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise NotImplementedError()

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        self.register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    @torch.no_grad()
    def get_image_embed(self, image):
        assert exists(self.clip)

        image_encoding = self.clip.visual_transformer(image)
        image_cls = image_encoding[:, 0]
        image_embed = self.clip.to_visual_latent(image_cls)
@@ -557,33 +618,18 @@ class DiffusionPrior(nn.Module):

    @torch.no_grad()
    def get_text_cond(self, text):
        assert exists(self.clip)

        text_encodings = self.clip.text_transformer(text)
        text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
        text_embed = self.clip.to_text_latent(text_cls)
        text_embed = l2norm(text_embed)

        if not self.condition_on_text_encodings:
            return dict(text_embed = text_embed)

        return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
        pred = self.net(x, t, **text_cond)
@@ -620,14 +666,6 @@ class DiffusionPrior(nn.Module):
            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, image_embed, t, text_cond, noise = None):
        noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -679,13 +717,41 @@ class DiffusionPrior(nn.Module):
        top_image_embeds = image_embeds.gather(1, top_sim_indices)
        return rearrange(top_image_embeds, 'b 1 d -> b d')

    def forward(self, text, image, *args, **kwargs):
        b, device, img_size, = image.shape[0], image.device, self.image_size
        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
    def forward(
        self,
        text = None,
        image = None,
        text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
        image_embed = None,
        text_encodings = None, # as well as CLIP text encodings
        text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
        *args,
        **kwargs
    ):
        assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'

        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
        image_embed = self.get_image_embed(image)
        text_cond = self.get_text_cond(text)
        if exists(image):
            image_embed = self.get_image_embed(image)

        # calculate text conditionings, based on what is passed in

        if exists(text):
            text_cond = self.get_text_cond(text)
        else:
            text_cond = dict(
                text_embed = text_embed,
                text_encodings = text_encodings,
                mask = text_mask
            )

        # timestep conditioning from ddpm

        batch, device = image_embed.shape[0], image_embed.device
        times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)

        # calculate forward loss

        loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
        return loss
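# usage sketch for the reworked forward above: the prior can now be trained directly on
# preprocessed CLIP embeddings rather than raw text and images; `prior_network`,
# `clip_text_embeds` and `clip_image_embeds` are hypothetical names, and 512 is an assumed
# image_embed_dim
#
#   diffusion_prior = DiffusionPrior(net = prior_network, clip = None, image_embed_dim = 512)
#   loss = diffusion_prior(text_embed = clip_text_embeds, image_embed = clip_image_embeds)
#   loss.backward()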
@@ -693,7 +759,7 @@ class DiffusionPrior(nn.Module):
# decoder

def Upsample(dim):
    return QueryAttnUpsample(dim)
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
@@ -840,6 +906,7 @@ class Unet(nn.Module):
        dim,
        *,
        image_embed_dim,
        text_embed_dim = None,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
@@ -853,6 +920,7 @@ class Unet(nn.Module):
        sparse_attn_window = 8, # window size for sparse attention
        attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        cond_on_text_encodings = False,
        max_text_len = 256,
        cond_on_image_embeds = False,
    ):
        super().__init__()
@@ -892,7 +960,7 @@ class Unet(nn.Module):
            Rearrange('b (n d) -> b n d', n = num_image_tokens)
        ) if image_embed_dim != cond_dim else nn.Identity()

        self.text_to_cond = nn.LazyLinear(cond_dim)
        self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on image embeddings and text encodings
        # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -903,7 +971,7 @@ class Unet(nn.Module):
        # for classifier free guidance

        self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))

        # attention related params
@@ -1031,7 +1099,7 @@ class Unet(nn.Module):
        text_tokens = torch.where(
            cond_prob_mask,
            text_tokens,
            self.null_text_embed
            self.null_text_embed[:, :text_tokens.shape[1]]
        )

        # main conditioning tokens (c)
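# a minimal sketch of the classifier free guidance combination that the null embeddings above
# make possible, assuming the usual formulation applied by forward_with_cond_scale elsewhere in
# this file: interpolate away from the prediction made with conditioning fully masked out

def guided_prediction_sketch(cond_pred, null_pred, cond_scale = 3.):
    # cond_scale = 1. recovers the conditional prediction; larger values push further from the null one
    return null_pred + (cond_pred - null_pred) * cond_scale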
@@ -1111,13 +1179,13 @@ class LowresConditioner(nn.Module):

        return cond_fmap

class Decoder(nn.Module):
class Decoder(BaseGaussianDiffusion):
    def __init__(
        self,
        unet,
        *,
        clip,
        vae = None,
        vae = tuple(),
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = 'l1',
@@ -1129,14 +1197,22 @@ class Decoder(nn.Module):
        lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
        blur_sigma = 0.1, # cascading ddpm - blur sigma
        blur_kernel_size = 3, # cascading ddpm - blur kernel size
        condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
    ):
        super().__init__()
        super().__init__(
            beta_schedule = beta_schedule,
            timesteps = timesteps,
            loss_type = loss_type
        )

        assert isinstance(clip, CLIP)
        freeze_model_and_make_eval_(clip)
        self.clip = clip
        self.clip_image_size = clip.image_size
        self.channels = clip.image_channels

        self.condition_on_text_encodings = condition_on_text_encodings

        # automatically take care of ensuring that first unet is unconditional
        # while the rest of the unets are conditioned on the low resolution image produced by previous unet
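# a minimal sketch of the cascading ddpm conditioning the arguments above describe, assuming the
# usual recipe: downsample to the previous unet's resolution, blur, then resize up to the current
# resolution to act as lowres_cond_img; torchvision's gaussian_blur is used purely for illustration

def lowres_cond_sketch(image, prev_image_size, target_image_size, blur_kernel_size = 3, blur_sigma = 0.1):
    import torchvision.transforms.functional as T
    cond = resize_image_to(image, prev_image_size)
    cond = T.gaussian_blur(cond, kernel_size = blur_kernel_size, sigma = blur_sigma)
    return resize_image_to(cond, target_image_size)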
@@ -1192,55 +1268,6 @@ class Decoder(nn.Module):

        self.cond_drop_prob = cond_drop_prob

        # noise schedule

        if beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "quadratic":
            betas = quadratic_beta_schedule(timesteps)
        elif beta_schedule == "jsd":
            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
        elif beta_schedule == "sigmoid":
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise NotImplementedError()

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis = 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        self.register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def get_unet(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1
@@ -1273,27 +1300,6 @@ class Decoder(nn.Module):
        image_embed = self.clip.to_visual_latent(image_cls)
        return l2norm(image_embed)

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
        pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
@@ -1338,14 +1344,6 @@ class Decoder(nn.Module):

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))
@@ -1380,6 +1378,8 @@ class Decoder(nn.Module):

        text_encodings = self.get_text_encodings(text) if exists(text) else None

        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'

        img = None

        for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
@@ -1440,6 +1440,8 @@ class Decoder(nn.Module):

        text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None

        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'

        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
        image = resize_image_to(image, target_image_size)
@@ -1467,7 +1469,9 @@ class DALLE2(nn.Module):
        assert isinstance(decoder, Decoder)
        self.prior = prior
        self.decoder = decoder

        self.prior_num_samples = prior_num_samples
        self.decoder_need_text_cond = self.decoder.condition_on_text_encodings

    @torch.no_grad()
    @eval_decorator
@@ -1484,7 +1488,9 @@ class DALLE2(nn.Module):
        text = tokenizer.tokenize(text).to(device)

        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
        images = self.decoder.sample(image_embed, cond_scale = cond_scale)

        text_cond = text if self.decoder_need_text_cond else None
        images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)

        if one_text:
            return images[0]
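# end to end usage sketch for the DALLE2 wrapper above, assuming an already trained DiffusionPrior
# and Decoder; cond_scale is the classifier free guidance scale threaded through decoder.sample
#
#   dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)
#   images = dalle2(['a butterfly resting on a flower'], cond_scale = 2.)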