take care of DDPM decoder (DDPM for producing image embedding will have a separate objective, predicting directly the embedding rather than the noise [epsilon in paper])

This commit is contained in:
Phil Wang
2022-04-12 17:48:41 -07:00
parent 862e5ba50e
commit 33d69d3859
2 changed files with 172 additions and 12 deletions

View File

@@ -1,3 +1,4 @@
import tqdm
import torch
import torch.nn.functional as F
from torch import nn, einsum
@@ -52,6 +53,30 @@ def prob_mask_like(shape, prob, device):
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
# gaussian diffusion helper functions
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
# diffusion prior
class RMSNorm(nn.Module):
@@ -445,24 +470,159 @@ class Unet(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
net,
*,
clip,
prior
timesteps = 1000,
cond_prob_drop = 0.2,
loss_type = 'l1'
):
super().__init__()
assert isinstance(clip, CLIP)
assert isinstance(prior, DiffusionPrior)
freeze_model_and_make_eval_(clip)
def forward(
self,
*,
image,
image_embed,
cond_drop_prob = 0.2, # for the classifier free guidance
text_embed = None # in paper, text embedding was optional for conditioning decoder
):
return image
self.net = net
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_prob_drop = cond_prob_drop
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return image_embed
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
return img
@torch.no_grad()
def sample(self, image_embed):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, image_embed, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_recon = self.net(
x_noisy,
t,
image_embed = image_embed,
cond_prob_drop = self.cond_prob_drop
)
if self.loss_type == 'l1':
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
else:
raise NotImplementedError()
return loss
def forward(self, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(x, times, image_embed = image_embed, *args, **kwargs)
return loss
# main class

View File

@@ -29,7 +29,7 @@ setup(
'torch>=1.10',
'torchvision',
'tqdm',
'x-clip>=0.4.3',
'x-clip>=0.4.4',
'youtokentome'
],
classifiers=[