From 1f1557c614700cf38e6666a984083e45865ae8e8 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Jul 2022 10:56:11 -0700 Subject: [PATCH] make it so even if text mask is omitted, it will be derived based on whether text encodings are all 0s or not, simplify dataloading --- dalle2_pytorch/dalle2_pytorch.py | 10 ++++++++++ dalle2_pytorch/version.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index bb1da34..24cd6e9 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -220,6 +220,7 @@ class XClipAdapter(BaseClipAdapter): encoder_output = self.clip.text_transformer(text) text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:] text_embed = self.clip.to_text_latent(text_cls) + text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) return EmbeddedText(l2norm(text_embed), text_encodings, text_mask) @torch.no_grad() @@ -255,6 +256,7 @@ class CoCaAdapter(BaseClipAdapter): text = text[..., :self.max_text_len] text_mask = text != 0 text_embed, text_encodings = self.clip.embed_text(text) + text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) return EmbeddedText(text_embed, text_encodings, text_mask) @torch.no_grad() @@ -314,6 +316,7 @@ class OpenAIClipAdapter(BaseClipAdapter): text_embed = self.clip.encode_text(text) text_encodings = self.text_encodings + text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) del self.text_encodings return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask) @@ -1197,6 +1200,7 @@ class DiffusionPrior(nn.Module): if self.condition_on_text_encodings: assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified' + text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1)) text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} # timestep conditioning from ddpm @@ -2410,6 +2414,9 @@ class Decoder(nn.Module): assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' + if self.condition_on_text_encodings: + text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1)) + img = None is_cuda = next(self.parameters()).is_cuda @@ -2493,6 +2500,9 @@ class Decoder(nn.Module): assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' + if self.condition_on_text_encodings: + text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1)) + lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None image = resize_image_to(image, target_image_size) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index bd5657d..2f15b8c 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.19.6' +__version__ = '0.20.0'