take care of DDPM decoder (the DDPM for producing the image embedding, i.e. the diffusion prior, will have a separate objective: predicting the embedding directly rather than the noise [epsilon in the paper])
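To make the distinction concrete: the decoder below keeps the standard DDPM noise-prediction loss, while the prior (handled separately) will regress the clean embedding itself. A minimal sketch of the two objectives, with an illustrative `net` and a precomputed `alphas_cumprod` schedule standing in for this repo's actual classes:

import torch
import torch.nn.functional as F

# illustrative only: `net` is any denoising network, `alphas_cumprod`
# a precomputed noise schedule of shape (timesteps,)

def decoder_objective(net, x_start, t, alphas_cumprod):
    # standard DDPM loss: corrupt x_start, train net to predict the added noise (epsilon)
    noise = torch.randn_like(x_start)
    alpha_bar = alphas_cumprod[t].view(-1, *((1,) * (x_start.ndim - 1)))
    x_noisy = alpha_bar.sqrt() * x_start + (1. - alpha_bar).sqrt() * noise
    return F.mse_loss(net(x_noisy, t), noise)

def prior_objective(net, embed, t, alphas_cumprod):
    # the objective the commit message describes for the prior:
    # train net to predict the clean embedding itself, not epsilon
    noise = torch.randn_like(embed)
    alpha_bar = alphas_cumprod[t].view(-1, 1)
    embed_noisy = alpha_bar.sqrt() * embed + (1. - alpha_bar).sqrt() * noise
    return F.mse_loss(net(embed_noisy, t), embed)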
@@ -1,3 +1,4 @@
+from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
@@ -52,6 +53,30 @@ def prob_mask_like(shape, prob, device):
     else:
         return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 
+# gaussian diffusion helper functions
+
+def extract(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def noise_like(shape, device, repeat = False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device = device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device = device)
+    return repeat_noise() if repeat else noise()
+
+def cosine_beta_schedule(timesteps, s = 0.008):
+    """
+    cosine schedule
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+    """
+    steps = timesteps + 1
+    x = torch.linspace(0, steps, steps)
+    alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return torch.clip(betas, 0, 0.999)
+
 # diffusion prior
 
 class RMSNorm(nn.Module):
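A quick sanity check of the helpers above (a sketch, assuming they are in scope): `extract` gathers one schedule value per batch element and reshapes it so it broadcasts over the image dimensions.

import torch

betas = cosine_beta_schedule(1000)                 # (1000,) betas from the cosine schedule, clipped at 0.999
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

x = torch.randn(3, 3, 64, 64)                      # (b, c, h, w)
t = torch.tensor([0, 500, 999])                    # one timestep per batch element

coef = extract(alphas_cumprod.sqrt(), t, x.shape)
print(coef.shape)                                  # torch.Size([3, 1, 1, 1])
noisy = coef * x                                   # broadcasts cleanly over c, h, w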
@@ -445,24 +470,159 @@ class Unet(nn.Module):
 class Decoder(nn.Module):
     def __init__(
         self,
+        net,
         *,
         clip,
-        prior
+        timesteps = 1000,
+        cond_prob_drop = 0.2,
+        loss_type = 'l1'
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
-        assert isinstance(prior, DiffusionPrior)
         freeze_model_and_make_eval_(clip)
+        self.clip = clip  # referenced by get_image_embed below
 
-    def forward(
-        self,
-        *,
-        image,
-        image_embed,
-        cond_drop_prob = 0.2, # for the classifier free guidance
-        text_embed = None # in paper, text embedding was optional for conditioning decoder
-    ):
-        return image
+        self.net = net
+        self.channels = clip.image_channels
+        self.image_size = clip.image_size
+        self.cond_prob_drop = cond_prob_drop
+
+        betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        self.register_buffer('betas', betas)
+        self.register_buffer('alphas_cumprod', alphas_cumprod)
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+        self.register_buffer('posterior_variance', posterior_variance)
+
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+    def get_image_embed(self, image):
+        image_encoding = self.clip.visual_transformer(image)
+        image_cls = image_encoding[:, 0]
+        image_embed = self.clip.to_visual_latent(image_cls)
+        return image_embed
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, image_embed):
+        device = self.betas.device
+
+        b = shape[0]
+        img = torch.randn(shape, device = device)
+
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
+        return img
+
+    @torch.no_grad()
+    def sample(self, image_embed):
+        batch_size = image_embed.shape[0]
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
+
+    def q_sample(self, x_start, t, noise = None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def p_losses(self, x_start, image_embed, t, noise = None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
+
+        x_recon = self.net(
+            x_noisy,
+            t,
+            image_embed = image_embed,
+            cond_prob_drop = self.cond_prob_drop
+        )
+
+        if self.loss_type == 'l1':
+            loss = F.l1_loss(noise, x_recon)
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, image, *args, **kwargs):
+        b, device, img_size = image.shape[0], image.device, self.image_size
+        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
+
+        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
+        image_embed = self.get_image_embed(image)
+
+        loss = self.p_losses(image, image_embed, times, *args, **kwargs)
+        return loss
 
 # main class
 
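For orientation, a hedged usage sketch of the new class. The CLIP and Unet constructor arguments are not part of this diff, so they are elided with `...`, and the embedding width and image size below are assumptions:

import torch

clip = CLIP(...)   # constructor args not shown in this diff
unet = Unet(...)   # the denoising network passed in as `net`

decoder = Decoder(
    unet,
    clip = clip,
    timesteps = 1000,
    cond_prob_drop = 0.2,   # conditioning dropout, for classifier-free guidance
    loss_type = 'l1'
)

images = torch.randn(4, 3, 256, 256)   # must match clip.image_size / image_channels

loss = decoder(images)   # CLIP-embeds the images, then computes the epsilon loss
loss.backward()

# at inference, decode (e.g. prior-produced) image embeddings back to pixels
image_embed = torch.randn(4, 512)         # width must match CLIP's visual latent dim
images_out = decoder.sample(image_embed)  # (4, 3, 256, 256)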
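One more sketch, checking the posterior coefficients registered above against DDPM's closed form (assuming the schedule helper from this diff is in scope). If x_t carries no noise, i.e. x_t = sqrt(alpha_bar_t) * x_0, the posterior mean must collapse to sqrt(alpha_bar_{t-1}) * x_0, which pins down the two coefficients; note the (1, 0) pad, which prepends the alpha_bar_0 = 1 boundary value:

import torch
import torch.nn.functional as F

betas = cosine_beta_schedule(1000)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)  # prepend 1 for t = 0

coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

# identity from DDPM eq. 7: coef1 + coef2 * sqrt(alpha_bar_t) == sqrt(alpha_bar_{t-1})
lhs = coef1 + coef2 * torch.sqrt(alphas_cumprod)
assert torch.allclose(lhs, torch.sqrt(alphas_cumprod_prev), atol = 1e-5)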