From 9878be760bf69fdc6747dce9fc38d9588e474ec5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 26 Apr 2022 09:47:09 -0700 Subject: [PATCH] have researcher explicitly state upfront whether to condition with text encodings in cascading ddpm decoder, have DALLE-2 class take care of passing in text if feature turned on --- README.md | 3 ++- dalle2_pytorch/dalle2_pytorch.py | 18 +++++++++++++++--- setup.py | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 3979503..4500708 100644 --- a/README.md +++ b/README.md @@ -348,7 +348,8 @@ decoder = Decoder( image_sizes = (128, 256), clip = clip, timesteps = 100, - cond_drop_prob = 0.2 + cond_drop_prob = 0.2, + condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling ).cuda() for unet_number in (1, 2): diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index e1cedcc..7ac2794 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -894,6 +894,7 @@ class Unet(nn.Module): sparse_attn_window = 8, # window size for sparse attention attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) cond_on_text_encodings = False, + max_text_len = 256, cond_on_image_embeds = False, ): super().__init__() @@ -944,7 +945,7 @@ class Unet(nn.Module): # for classifier free guidance self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) - self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim)) + self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) # attention related params @@ -1072,7 +1073,7 @@ class Unet(nn.Module): text_tokens = torch.where( cond_prob_mask, text_tokens, - self.null_text_embed + self.null_text_embed[:, :text_tokens.shape[1]] ) # main conditioning tokens (c) @@ -1170,6 +1171,7 @@ class Decoder(nn.Module): lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur blur_sigma = 0.1, # cascading ddpm - blur sigma blur_kernel_size = 3, # cascading ddpm - blur kernel size + condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation ): super().__init__() assert isinstance(clip, CLIP) @@ -1178,6 +1180,8 @@ class Decoder(nn.Module): self.clip_image_size = clip.image_size self.channels = clip.image_channels + self.condition_on_text_encodings = condition_on_text_encodings + # automatically take care of ensuring that first unet is unconditional # while the rest of the unets are conditioned on the low resolution image produced by previous unet @@ -1421,6 +1425,8 @@ class Decoder(nn.Module): text_encodings = self.get_text_encodings(text) if exists(text) else None + assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' + img = None for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): @@ -1481,6 +1487,8 @@ class Decoder(nn.Module): text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None + assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' + lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None image = resize_image_to(image, target_image_size) @@ -1508,7 +1516,9 @@ class DALLE2(nn.Module): assert isinstance(decoder, Decoder) self.prior = prior self.decoder = decoder + self.prior_num_samples = prior_num_samples + self.decoder_need_text_cond = self.decoder.condition_on_text_encodings @torch.no_grad() @eval_decorator @@ -1525,7 +1535,9 @@ class DALLE2(nn.Module): text = tokenizer.tokenize(text).to(device) image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples) - images = self.decoder.sample(image_embed, cond_scale = cond_scale) + + text_cond = text if self.decoder_need_text_cond else None + images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) if one_text: return images[0] diff --git a/setup.py b/setup.py index dbbea31..b7ab3bf 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.48', + version = '0.0.49', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',