diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index c2a7928..45622a0 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -221,8 +221,8 @@ class DiffusionPriorNetwork(nn.Module):
     def forward(
         self,
         image_embed,
-        *,
         diffusion_timesteps,
+        *,
         text_encodings,
         text_embed,
         mask = None,
@@ -272,21 +272,170 @@ class DiffusionPriorNetwork(nn.Module):
 class DiffusionPrior(nn.Module):
     def __init__(
         self,
+        net,
         *,
-        clip
+        clip,
+        timesteps = 1000,
+        cond_prob_drop = 0.2,
+        loss_type = 'l1'
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
 
-    def forward(
-        self,
-        *,
-        text,
-        image = None
-    ):
+        self.clip = clip
+        self.net = net
+        self.image_embed_dim = clip.dim_latent
+        self.channels = clip.image_channels
+        self.image_size = clip.image_size
+        self.cond_prob_drop = cond_prob_drop
+
+        betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        self.register_buffer('betas', betas)
+        self.register_buffer('alphas_cumprod', alphas_cumprod)
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+        self.register_buffer('posterior_variance', posterior_variance)
+
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+    def get_image_embed(self, image):
+        image_encoding = self.clip.visual_transformer(image)
+        image_cls = image_encoding[:, 0]
+        image_embed = self.clip.to_visual_latent(image_cls)
         return image_embed
 
+    def get_text_cond(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
+        text_embed = self.clip.to_text_latent(text_cls)
+        return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, text_cond):
+        device = self.betas.device
+
+        b = shape[0]
+        img = torch.randn(shape, device = device)
+
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
+        return img
+
+    @torch.no_grad()
+    def sample(self, text):
+        batch_size = text.shape[0]
+        image_embed_dim = self.image_embed_dim
+
+        text_cond = self.get_text_cond(text)
+        return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+
+    def q_sample(self, x_start, t, noise = None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def p_losses(self, image_embed, t, text_cond, noise = None):
+        noise = default(noise, lambda: torch.randn_like(image_embed))
+
+        image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
+
+        x_recon = self.net(
+            image_embed_noisy,
+            t,
+            cond_prob_drop = self.cond_prob_drop,
+            **text_cond
+        )
+
+        if self.loss_type == 'l1':
+            loss = F.l1_loss(noise, x_recon)
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, text, image, *args, **kwargs):
+        b, device, img_size = image.shape[0], image.device, self.image_size
+        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
+
+        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
+        image_embed = self.get_image_embed(image)
+        text_cond = self.get_text_cond(text)
+
+        loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
+        return loss
+
 # decoder
 
 def Upsample(dim):
@@ -428,9 +577,9 @@ class Unet(nn.Module):
     def forward(
         self,
         x,
+        time,
         *,
         image_embed,
-        time,
         text_encodings = None,
         cond_prob_drop = 0.
     ):
diff --git a/setup.py b/setup.py
index cbc5c2c..d0b5dd2 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
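
Note (not part of the patch): a minimal usage sketch of the reworked DiffusionPrior API above. Only the DiffusionPrior keyword arguments (net, clip, timesteps, cond_prob_drop, loss_type) and the forward(text, image) / sample(text) entry points come from this diff; the CLIP and DiffusionPriorNetwork constructor kwargs below are illustrative assumptions and may not match the real signatures elsewhere in the file.

    import torch
    from dalle2_pytorch import CLIP, DiffusionPriorNetwork, DiffusionPrior

    # hypothetical hyperparameters - adjust to the actual constructor signatures
    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 49408,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        visual_enc_depth = 6,
        visual_image_size = 256,
        visual_patch_size = 32,
        visual_heads = 8
    )

    prior_network = DiffusionPriorNetwork(
        dim = 512,
        depth = 6,
        dim_head = 64,
        heads = 8
    )

    # the kwargs below are the ones introduced by this patch
    diffusion_prior = DiffusionPrior(
        net = prior_network,
        clip = clip,
        timesteps = 1000,
        cond_prob_drop = 0.2,
        loss_type = 'l1'
    )

    text = torch.randint(0, 49408, (4, 256))      # tokenized text, 0 = padding
    images = torch.randn(4, 3, 256, 256)

    loss = diffusion_prior(text, images)          # new forward(text, image) signature
    loss.backward()

    image_embeds = diffusion_prior.sample(text)   # (4, 512) denoised image embeddings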
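The new constructor relies on cosine_beta_schedule, which is defined elsewhere in dalle2_pytorch.py and not shown in this diff. For reference, a sketch of the squared-cosine schedule from Nichol & Dhariwal (2021) that such a helper conventionally implements; the s offset and clipping bound are the paper's values and may differ from the module's actual definition:

    import math
    import torch

    def cosine_beta_schedule(timesteps, s = 0.008):
        # alpha-bar follows a squared cosine; each beta_t is recovered from the
        # ratio of consecutive alpha-bar values: beta_t = 1 - abar_t / abar_{t-1}
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1. - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0., 0.999)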
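Similarly, extract and noise_like are pre-existing helpers the new methods assume. Sketches of the conventional DDPM implementations, for readers following along (the module's own versions may differ in detail):

    import torch

    def extract(a, t, x_shape):
        # pick out the schedule value for each sample's timestep t and reshape
        # it to (batch, 1, 1, ...) so it broadcasts against x
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def noise_like(shape, device, repeat = False):
        if repeat:
            # one gaussian draw, repeated across the batch dimension
            return torch.randn((1, *shape[1:]), device = device).repeat(shape[0], *((1,) * (len(shape) - 1)))
        return torch.randn(shape, device = device)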