From c30544b73aa0e4224a8d81388037ef57fc3cc7eb Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 26 Apr 2022 10:23:34 -0700
Subject: [PATCH] no CLIP altogether for training DiffusionPrior

---
 README.md                        | 49 ++++++++++++++++++++++++++++++++
 dalle2_pytorch/dalle2_pytorch.py | 25 +++++++++++-----
 setup.py                         |  2 +-
 3 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 5a90f41..806ee40 100644
--- a/README.md
+++ b/README.md
@@ -446,6 +446,55 @@ loss.backward()
 # now the diffusion prior can generate image embeddings from the text embeddings
 ```
 
+You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+    dim = 512,
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+# diffusion prior network, which contains the CLIP and network (with transformer) above
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    image_embed_dim = 512,                 # this needs to be set
+    timesteps = 100,
+    cond_drop_prob = 0.2,
+    condition_on_text_encodings = False    # this probably should be true, but just to get Laion started
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# precompute the text and image embeddings
+# here using the diffusion prior class, but could be done with CLIP alone
+
+clip_image_embeds = torch.randn(4, 512).cuda()
+clip_text_embeds = torch.randn(4, 512).cuda()
+
+# feed text and images into diffusion prior network
+
+loss = diffusion_prior(
+    text_embed = clip_text_embeds,
+    image_embed = clip_image_embeds
+)
+
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
 ## Experimental
 
 ### DALL-E2 with Latent Diffusion
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 7ac2794..8dd338b 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -486,7 +486,10 @@ class DiffusionPrior(nn.Module):
         self,
         net,
         *,
-        clip,
+        clip = None,
+        image_embed_dim = None,
+        image_size = None,
+        image_channels = 3,
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = "l1",
@@ -495,14 +498,18 @@
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
     ):
         super().__init__()
-        assert isinstance(clip, CLIP)
-        freeze_model_and_make_eval_(clip)
-        self.clip = clip
+
+        if exists(clip):
+            assert isinstance(clip, CLIP)
+            freeze_model_and_make_eval_(clip)
+            self.clip = clip
+        else:
+            assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
+            self.clip = None
 
         self.net = net
-        self.image_embed_dim = clip.dim_latent
-        self.channels = clip.image_channels
-        self.image_size = clip.image_size
+        self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
+        self.channels = default(image_channels, lambda: clip.image_channels)
 
         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
@@ -559,6 +566,8 @@
 
     @torch.no_grad()
     def get_image_embed(self, image):
+        assert exists(self.clip)
+
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
@@ -566,6+575,8 @@
 
     @torch.no_grad()
     def get_text_cond(self, text):
+        assert exists(self.clip)
+
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
diff --git a/setup.py b/setup.py
index b7ab3bf..10da6b3 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.49',
+  version = '0.0.50',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
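
Note for reviewers: the patched `__init__` calls two small helpers, `exists` and `default`, which live near the top of `dalle2_pytorch.py` and do not appear in the hunks above. The sketch below shows their conventional definitions in this codebase, as a reading aid rather than part of the diff. The callable branch of `default` is what makes `default(image_embed_dim, lambda: clip.dim_latent)` safe in the CLIP-less path: the lambda is only invoked when no explicit value was passed, so `clip` is never dereferenced while it is `None` (the earlier assert guarantees `image_embed_dim` is set in that case).

```python
# reference sketch of the helpers used by the patched __init__ above;
# these follow the conventions used throughout dalle2_pytorch.py

def exists(val):
    # a value counts as "given" when it is not None
    return val is not None

def default(val, d):
    # return val when given; otherwise fall back to d,
    # calling d first when it is callable (a lazy default) -
    # this laziness is why the patch wraps clip lookups in lambdas:
    # they are only evaluated when a CLIP model was actually supplied
    if exists(val):
        return val
    return d() if callable(d) else d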