From 6f941a219a456eed993cf5b1514f3c27ac42cf01 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 20 Apr 2022 10:04:47 -0700 Subject: [PATCH] give time tokens a surface area of 2 tokens as default, make it so researcher can customize which unet actually is conditioned on image embeddings and/or text encodings --- README.md | 4 ++-- dalle2_pytorch/cli.py | 2 +- dalle2_pytorch/dalle2_pytorch.py | 39 +++++++++++++++++++++++--------- setup.py | 2 +- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 183b942..afae8af 100644 --- a/README.md +++ b/README.md @@ -410,9 +410,9 @@ Offer training wrappers - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper) - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions - [x] add efficient attention in unet -- [ ] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning) -- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately) +- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning) - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting) +- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately) - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet - [ ] train on a toy task, offer in colab diff --git a/dalle2_pytorch/cli.py b/dalle2_pytorch/cli.py index ed87e8b..9e89f0d 100644 --- a/dalle2_pytorch/cli.py +++ b/dalle2_pytorch/cli.py @@ -6,4 +6,4 @@ def main(): @click.command() @click.argument('text') def dream(text): - return image + return 'not ready yet' diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 8326f9a..00a3032 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -820,6 +820,7 @@ class Unet(nn.Module): image_embed_dim, cond_dim = None, num_image_tokens = 4, + num_time_tokens = 2, out_dim = None, dim_mults=(1, 2, 4, 8), channels = 3, @@ -830,6 +831,8 @@ class Unet(nn.Module): sparse_attn = False, sparse_attn_window = 8, # window size for sparse attention attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) + cond_on_text_encodings = False, + cond_on_image_embeds = False, ): super().__init__() # save locals to take care of some hyperparameters for cascading DDPM @@ -862,8 +865,8 @@ class Unet(nn.Module): SinusoidalPosEmb(dim), nn.Linear(dim, dim * 4), nn.GELU(), - nn.Linear(dim * 4, cond_dim), - Rearrange('b d -> b 1 d') + nn.Linear(dim * 4, cond_dim * num_time_tokens), + Rearrange('b (r d) -> b r d', r = num_time_tokens) ) self.image_to_cond = nn.Sequential( @@ -873,6 +876,12 @@ class Unet(nn.Module): self.text_to_cond = nn.LazyLinear(cond_dim) + # finer control over whether to condition on image embeddings and text encodings + # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting + + self.cond_on_text_encodings = cond_on_text_encodings + self.cond_on_image_embeds = cond_on_image_embeds + # for classifier free guidance self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) @@ -982,17 +991,22 @@ class Unet(nn.Module): # mask out image embedding depending on condition dropout # for classifier free guidance - image_tokens = self.image_to_cond(image_embed) + image_tokens = None - image_tokens = torch.where( - cond_prob_mask, - image_tokens, - self.null_image_embed - ) + if self.cond_on_image_embeds: + image_tokens = self.image_to_cond(image_embed) + + image_tokens = torch.where( + cond_prob_mask, + image_tokens, + self.null_image_embed + ) # take care of text encodings (optional) - if exists(text_encodings): + text_tokens = None + + if exists(text_encodings) and self.cond_on_text_encodings: text_tokens = self.text_to_cond(text_encodings) text_tokens = torch.where( cond_prob_mask, @@ -1002,12 +1016,15 @@ class Unet(nn.Module): # main conditioning tokens (c) - c = torch.cat((time_tokens, image_tokens), dim = -2) + c = time_tokens + + if exists(image_tokens): + c = torch.cat((c, image_tokens), dim = -2) # text and image conditioning tokens (mid_c) # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet - mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2) + mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2) # go through the layers of the unet, down and up diff --git a/setup.py b/setup.py index c1e4f9a..41e80b0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.26', + version = '0.0.27', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',