From e76e89f9eb8ed2eea881c533db934b62cb7d0567 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Jul 2022 15:40:31 -0700 Subject: [PATCH] remove text masking altogether in favor of deriving from text encodings (padded text encodings must be pad value of 0.) --- README.md | 2 +- dalle2_pytorch/dalle2_pytorch.py | 56 +++++++++++++------------------- dalle2_pytorch/version.py | 2 +- train_diffusion_prior.py | 9 ++--- 4 files changed, 28 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 42379ae..015a318 100644 --- a/README.md +++ b/README.md @@ -421,7 +421,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a ## Training on Preprocessed CLIP Embeddings -It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask` +It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` Working example below diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index f84794e..8571cb2 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -165,7 +165,7 @@ def unnormalize_zero_to_one(normed_img): # clip related adapters -EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask']) +EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings']) EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings']) class BaseClipAdapter(nn.Module): @@ -226,7 +226,7 @@ class XClipAdapter(BaseClipAdapter): text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:] text_embed = self.clip.to_text_latent(text_cls) text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) - return EmbeddedText(l2norm(text_embed), text_encodings, text_mask) + return EmbeddedText(l2norm(text_embed), text_encodings) @torch.no_grad() def embed_image(self, image): @@ -262,7 +262,7 @@ class CoCaAdapter(BaseClipAdapter): text_mask = text != 0 text_embed, text_encodings = self.clip.embed_text(text) text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) - return EmbeddedText(text_embed, text_encodings, text_mask) + return EmbeddedText(text_embed, text_encodings) @torch.no_grad() def embed_image(self, image): @@ -323,7 +323,7 @@ class OpenAIClipAdapter(BaseClipAdapter): text_encodings = self.text_encodings text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) del self.text_encodings - return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask) + return EmbeddedText(l2norm(text_embed.float()), text_encodings.float()) @torch.no_grad() def embed_image(self, image): @@ -871,8 +871,7 @@ class DiffusionPriorNetwork(nn.Module): if not exists(text_encodings): text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype) - if not exists(mask) or mask.numel() == 0: - mask = torch.any(text_encodings != 0., dim = -1) + mask = torch.any(text_encodings != 0., dim = -1) # classifier free guidance @@ -889,9 +888,8 @@ class DiffusionPriorNetwork(nn.Module): # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # but let's just do it right - if exists(mask): - attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds - mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query + attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds + mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query time_embed = self.to_time_embeds(diffusion_timesteps) @@ -1152,12 +1150,12 @@ class DiffusionPrior(nn.Module): batch_size = text.shape[0] image_embed_dim = self.image_embed_dim - text_embed, text_encodings, text_mask = self.clip.embed_text(text) + text_embed, text_encodings = self.clip.embed_text(text) text_cond = dict(text_embed = text_embed) if self.condition_on_text_encodings: - text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} + text_cond = {**text_cond, 'text_encodings': text_encodings} image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps) @@ -1185,7 +1183,6 @@ class DiffusionPrior(nn.Module): text_embed = None, # allow for training on preprocessed CLIP text and image embeddings image_embed = None, text_encodings = None, # as well as CLIP text encodings - text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity *args, **kwargs ): @@ -1199,13 +1196,13 @@ class DiffusionPrior(nn.Module): # calculate text conditionings, based on what is passed in if exists(text): - text_embed, text_encodings, text_mask = self.clip.embed_text(text) + text_embed, text_encodings = self.clip.embed_text(text) text_cond = dict(text_embed = text_embed) if self.condition_on_text_encodings: assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified' - text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} + text_cond = {**text_cond, 'text_encodings': text_encodings} # timestep conditioning from ddpm @@ -1744,7 +1741,6 @@ class Unet(nn.Module): image_embed, lowres_cond_img = None, text_encodings = None, - text_mask = None, image_cond_drop_prob = 0., text_cond_drop_prob = 0., blur_sigma = None, @@ -1818,8 +1814,7 @@ class Unet(nn.Module): if exists(text_encodings) and self.cond_on_text_encodings: assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.' - if not exists(text_mask) or text_mask.numel() == 0: - text_mask = torch.any(text_encodings != 0., dim = -1) + text_mask = torch.any(text_encodings != 0., dim = -1) text_tokens = self.text_to_cond(text_encodings) @@ -2218,10 +2213,10 @@ class Decoder(nn.Module): x = x.clamp(-s, s) / s return x - def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None): + def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None): assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' - pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)) + pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)) if learned_variance: pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1) @@ -2253,16 +2248,16 @@ class Decoder(nn.Module): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True): + def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance) + model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance) noise = torch.randn_like(x) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False): + def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False): device = self.device b = shape[0] @@ -2278,7 +2273,6 @@ class Decoder(nn.Module): torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, - text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, @@ -2291,7 +2285,7 @@ class Decoder(nn.Module): return unnormalize_img @torch.no_grad() - def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False): + def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False): batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1] @@ -2307,7 +2301,7 @@ class Decoder(nn.Module): time_cond = torch.full((batch,), time, device = device, dtype = torch.long) - pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) + pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) if learned_variance: pred, _ = pred.chunk(2, dim = 1) @@ -2346,7 +2340,7 @@ class Decoder(nn.Module): return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs) - def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False): + def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False): noise = default(noise, lambda: torch.randn_like(x_start)) # normalize to [-1, 1] @@ -2364,7 +2358,6 @@ class Decoder(nn.Module): times, image_embed = image_embed, text_encodings = text_encodings, - text_mask = text_mask, lowres_cond_img = lowres_cond_img, image_cond_drop_prob = self.image_cond_drop_prob, text_cond_drop_prob = self.text_cond_drop_prob, @@ -2424,7 +2417,6 @@ class Decoder(nn.Module): self, image_embed = None, text = None, - text_mask = None, text_encodings = None, batch_size = 1, cond_scale = 1., @@ -2438,7 +2430,7 @@ class Decoder(nn.Module): if exists(text) and not exists(text_encodings) and not self.unconditional: assert exists(self.clip) - _, text_encodings, text_mask = self.clip.embed_text(text) + _, text_encodings = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' @@ -2468,7 +2460,6 @@ class Decoder(nn.Module): shape, image_embed = image_embed, text_encodings = text_encodings, - text_mask = text_mask, cond_scale = cond_scale, predict_x_start = predict_x_start, learned_variance = learned_variance, @@ -2492,7 +2483,6 @@ class Decoder(nn.Module): text = None, image_embed = None, text_encodings = None, - text_mask = None, unet_number = None, return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes ): @@ -2521,7 +2511,7 @@ class Decoder(nn.Module): if exists(text) and not exists(text_encodings) and not self.unconditional: assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder' - _, text_encodings, text_mask = self.clip.embed_text(text) + _, text_encodings = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' @@ -2544,7 +2534,7 @@ class Decoder(nn.Module): image = vae.encode(image) lowres_cond_img = maybe(vae.encode)(lowres_cond_img) - losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler) + losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler) if not return_lowres_cond_image: return losses diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index ebcbb29..81edede 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.21.3' +__version__ = '0.22.0' diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 02f2ce2..454dc79 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -126,9 +126,9 @@ def report_cosine_sims( # we are text conditioned, we produce an embedding from the tokenized text if text_conditioned: - text_embedding, text_encodings, text_mask = trainer.embed_text(text_data) + text_embedding, text_encodings = trainer.embed_text(text_data) text_cond = dict( - text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask + text_embed=text_embedding, text_encodings=text_encodings ) else: text_embedding = text_data @@ -146,15 +146,12 @@ def report_cosine_sims( if text_conditioned: text_encodings_shuffled = text_encodings[rolled_idx] - text_mask_shuffled = text_mask[rolled_idx] else: text_encodings_shuffled = None - text_mask_shuffled = None text_cond_shuffled = dict( text_embed=text_embed_shuffled, - text_encodings=text_encodings_shuffled, - mask=text_mask_shuffled, + text_encodings=text_encodings_shuffled ) # prepare the text embedding