From 9f55c24db6805761e1907ee379a1f7036d07018b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 14 Apr 2022 11:46:45 -0700
Subject: [PATCH] allow for decoder conditioning with the text encodings from
 CLIP, if it is passed in. use lazy linear to avoid researchers having to
 worry about text encoding dimensions, but remove later if it does not work
 well

---
 README.md                        |  2 +-
 dalle2_pytorch/dalle2_pytorch.py | 65 +++++++++++++++++++++++---------
 setup.py                         |  2 +-
 3 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index adae23f..9022190 100644
--- a/README.md
+++ b/README.md
@@ -276,7 +276,7 @@ decoder = Decoder(
     cond_drop_prob = 0.2
 ).cuda()
 
-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
 loss.backward()
 
 # do above for many steps
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index eb03cd7..f717d1e 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -722,7 +722,7 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
-        # time and image embeddings
+        # time, image embeddings, and optional text encoding
 
         cond_dim = default(cond_dim, dim)
 
@@ -739,9 +739,12 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()
 
+        self.text_to_cond = nn.LazyLinear(cond_dim)
+
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
 
         # layers
 
@@ -806,6 +809,7 @@ class Unet(nn.Module):
         time_tokens = self.time_mlp(time)
 
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -813,12 +817,31 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)
 
         image_tokens = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            cond_prob_mask,
             image_tokens,
             self.null_image_embed
         )
 
-        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
+        # take care of text encodings (optional)
+
+        if exists(text_encodings):
+            text_tokens = self.text_to_cond(text_encodings)
+            text_tokens = torch.where(
+                cond_prob_mask,
+                text_tokens,
+                self.null_text_embed
+            )
+
+        # main conditioning tokens (c)
+
+        c = torch.cat((time_tokens, image_tokens), dim = -2)
+
+        # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
+
+        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+
+        # go through the layers of the unet, down and up
 
         hiddens = []
 
@@ -828,9 +851,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block1(x, c)
+        x = self.mid_block1(x, mid_c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, c)
+        x = self.mid_block2(x, mid_c)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
@@ -896,6 +919,10 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    def get_text_encodings(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        return text_encodings[:, 1:]
+
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
@@ -923,8 +950,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
+    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
 
         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -933,31 +960,32 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
+    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device=device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
         return img
 
     @torch.no_grad()
-    def sample(self, image_embed, cond_scale = 1.):
+    def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -967,7 +995,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
-    def p_losses(self, x_start, t, *, image_embed, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -976,6 +1004,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
+            text_encodings = text_encodings,
             cond_drop_prob = self.cond_drop_prob
         )
 
@@ -988,14 +1017,16 @@ class Decoder(nn.Module):
 
         return loss
 
-    def forward(self, image):
+    def forward(self, image, text = None):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
 
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
-        image_embed = self.get_image_embed(image)
-        loss = self.p_losses(image, times, image_embed = image_embed)
+        image_embed = self.get_image_embed(image)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+
+        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
         return loss
 
 # main class
diff --git a/setup.py b/setup.py
index 0f15d48..9abd021 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.12',
+  version = '0.0.14',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
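
A usage sketch for the new optional text conditioning. It assumes `decoder` and its `CLIP` are constructed as in the README snippet the patch touches; `images` and `text` are mock tensors, and the vocab size and sequence length below are placeholders, not values the patch fixes.

import torch

images = torch.randn(4, 3, 256, 256).cuda()        # mock training images
text = torch.randint(0, 49408, (4, 256)).cuda()    # mock tokenized text (placeholder vocab / seq len)

# training: text is optional. when passed, the CLIP text encodings
# (first token dropped, per get_text_encodings) are projected by
# text_to_cond and cross attended to in the innermost unet blocks

loss = decoder(images, text)
loss.backward()

# sampling: text threads through sample -> p_sample_loop -> p_sample

image_embed = decoder.get_image_embed(images)
sampled = decoder.sample(image_embed, text = text, cond_scale = 1.)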
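For reference, the classifier free guidance mechanic the @@ -813,12 +817,31 @@ hunk now shares between image and text tokens: per sample, the conditioning tokens are randomly swapped for a learned null embedding, so the network also learns an unconditional path that forward_with_cond_scale can blend against. A self-contained sketch; keep_mask_like is a stand-in with explicit keep semantics, not necessarily the sign convention of the repo's prob_mask_like helper.

import torch
from einops import rearrange

def keep_mask_like(shape, drop_prob, device):
    # True with probability (1 - drop_prob): keep the real conditioning
    return torch.zeros(shape, device = device).float().uniform_(0, 1) >= drop_prob

batch, num_tokens, cond_dim = 4, 5, 128
tokens = torch.randn(batch, num_tokens, cond_dim)
null_embed = torch.randn(1, 1, cond_dim)              # a learned nn.Parameter in the Unet

mask = keep_mask_like((batch,), 0.2, tokens.device)   # drop conditioning ~20% of the time
mask = rearrange(mask, 'b -> b 1 1')                  # broadcast over tokens and feature dim

tokens = torch.where(mask, tokens, null_embed)        # dropped rows become the null embedding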
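On the lazy linear choice flagged in the commit message: nn.LazyLinear defers in_features until the first forward pass, so text_to_cond does not need to know the CLIP text encoding width when the Unet is built. A quick illustration:

import torch
import torch.nn as nn

to_cond = nn.LazyLinear(128)      # out_features fixed; in_features inferred on first use
enc = torch.randn(2, 255, 512)    # any trailing text-encoding dim works
out = to_cond(enc)                # first call materializes a 512 -> 128 weight
assert out.shape == (2, 255, 128)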