diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c47ecc7..c2a7928 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1,3 +1,4 @@ +import tqdm import torch import torch.nn.functional as F from torch import nn, einsum @@ -52,6 +53,30 @@ def prob_mask_like(shape, prob, device): else: return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob +# gaussian diffusion helper functions + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def cosine_beta_schedule(timesteps, s = 0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, steps, steps) + alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + # diffusion prior class RMSNorm(nn.Module): @@ -445,24 +470,159 @@ class Unet(nn.Module): class Decoder(nn.Module): def __init__( self, + net, *, clip, - prior + timesteps = 1000, + cond_prob_drop = 0.2, + loss_type = 'l1' ): super().__init__() assert isinstance(clip, CLIP) - assert isinstance(prior, DiffusionPrior) freeze_model_and_make_eval_(clip) - def forward( - self, - *, - image, - image_embed, - cond_drop_prob = 0.2, # for the classifier free guidance - text_embed = None # in paper, text embedding was optional for conditioning decoder - ): - return image + self.net = net + self.channels = clip.image_channels + self.image_size = clip.image_size + self.cond_prob_drop = cond_prob_drop + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + self.register_buffer('betas', betas) + self.register_buffer('alphas_cumprod', alphas_cumprod) + self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) + self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) + self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) + self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + self.register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) + self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + def get_image_embed(self, image): + image_encoding = self.clip.visual_transformer(image) + image_cls = image_encoding[:, 0] + image_embed = self.clip.to_visual_latent(image_cls) + return image_embed + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, image_embed, clip_denoised: bool): + x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed)) + + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, image_embed): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed) + return img + + @torch.no_grad() + def sample(self, image_embed): + batch_size = image_embed.shape[0] + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def p_losses(self, x_start, image_embed, t, noise = None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) + + x_recon = self.net( + x_noisy, + t, + image_embed = image_embed, + cond_prob_drop = self.cond_prob_drop + ) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def forward(self, image, *args, **kwargs): + b, device, img_size, = image.shape[0], image.device, self.image_size + check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels) + + times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long) + image_embed = self.get_image_embed(image) + + loss = self.p_losses(x, times, image_embed = image_embed, *args, **kwargs) + return loss # main class diff --git a/setup.py b/setup.py index 18facc3..cbc5c2c 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ setup( 'torch>=1.10', 'torchvision', 'tqdm', - 'x-clip>=0.4.3', + 'x-clip>=0.4.4', 'youtokentome' ], classifiers=[