From 1c1e508369da34eb35741558d33203f42fea006e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 30 Apr 2022 09:13:34 -0700 Subject: [PATCH] fix all issues with text encodings conditioning in the decoder, using null padding tokens technique from dalle v1 --- dalle2_pytorch/dalle2_pytorch.py | 44 +++++++++++++++++++++++--------- setup.py | 2 +- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 48e8ef8..bf0d7fa 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1101,6 +1101,8 @@ class Unet(nn.Module): # for classifier free guidance self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) + + self.max_text_len = max_text_len self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) # attention related params @@ -1185,6 +1187,7 @@ class Unet(nn.Module): image_embed, lowres_cond_img = None, text_encodings = None, + text_mask = None, image_cond_drop_prob = 0., text_cond_drop_prob = 0., blur_sigma = None, @@ -1229,11 +1232,25 @@ class Unet(nn.Module): text_tokens = None if exists(text_encodings) and self.cond_on_text_encodings: + text_tokens = text_tokens[:, :self.max_text_len] + + text_tokens_len = text_tokens.shape[1] + remainder = self.max_text_len - text_tokens_len + + if remainder > 0: + text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) + + if exists(text_mask): + if remainder > 0: + text_mask = F.pad(text_mask, (0, remainder), value = False) + + text_keep_mask &= text_mask + text_tokens = self.text_to_cond(text_encodings) text_tokens = torch.where( text_keep_mask, text_tokens, - self.null_text_embed[:, :text_tokens.shape[1]] + self.null_text_embed ) # main conditioning tokens (c) @@ -1434,8 +1451,8 @@ class Decoder(BaseGaussianDiffusion): image_embed, _ = self.clip.embed_image(image) return image_embed - def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): - pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) + def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): + pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) if predict_x_start: x_recon = pred @@ -1449,16 +1466,16 @@ class Decoder(BaseGaussianDiffusion): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False): + def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start) + model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1): + def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1): device = self.betas.device b = shape[0] @@ -1471,6 +1488,7 @@ class Decoder(BaseGaussianDiffusion): torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, + text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start @@ -1478,7 +1496,7 @@ class Decoder(BaseGaussianDiffusion): return img - def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): + def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise) @@ -1488,6 +1506,7 @@ class Decoder(BaseGaussianDiffusion): times, image_embed = image_embed, text_encodings = text_encodings, + text_mask = text_mask, lowres_cond_img = lowres_cond_img, image_cond_drop_prob = self.image_cond_drop_prob, text_cond_drop_prob = self.text_cond_drop_prob, @@ -1503,9 +1522,9 @@ class Decoder(BaseGaussianDiffusion): def sample(self, image_embed, text = None, cond_scale = 1.): batch_size = image_embed.shape[0] - text_encodings = None + text_encodings = text_mask = None if exists(text): - _, text_encodings, _ = self.clip.embed_text(text) + _, text_encodings, text_mask = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' @@ -1534,6 +1553,7 @@ class Decoder(BaseGaussianDiffusion): shape, image_embed = image_embed, text_encodings = text_encodings, + text_mask = text_mask, cond_scale = cond_scale, predict_x_start = predict_x_start, lowres_cond_img = lowres_cond_img @@ -1571,9 +1591,9 @@ class Decoder(BaseGaussianDiffusion): if not exists(image_embed): image_embed, _ = self.clip.embed_image(image) - text_encodings = None + text_encodings = text_mask = None if exists(text) and not exists(text_encodings): - _, text_encodings, _ = self.clip.embed_text(text) + _, text_encodings, text_mask = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' @@ -1588,7 +1608,7 @@ class Decoder(BaseGaussianDiffusion): if exists(lowres_cond_img): lowres_cond_img = vae.encode(lowres_cond_img) - return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start) + return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start) # main class diff --git a/setup.py b/setup.py index 51c7ac3..6d46224 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.74', + version = '0.0.75', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',