From c76a964fd655dd620eb737798d55ac8031a54641 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 5 May 2022 08:11:01 -0700
Subject: [PATCH] allow for CLIP to be optional in Decoder, and allow
 DecoderTrainer to work off training pre-encoded image embeddings

---
 dalle2_pytorch/dalle2_pytorch.py | 29 ++++++++++++++++++++---------
 setup.py                         |  2 +-
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 6b5c76a..bdb0951 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1457,7 +1457,9 @@ class Decoder(BaseGaussianDiffusion):
         self,
         unet,
         *,
-        clip,
+        clip = None,
+        image_size = None,
+        channels = 3,
         vae = tuple(),
         timesteps = 1000,
         image_cond_drop_prob = 0.1,
@@ -1481,15 +1483,22 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
-        if isinstance(clip, CLIP):
-            clip = XClipAdapter(clip)
+        assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
-        freeze_model_and_make_eval_(clip)
-        assert isinstance(clip, BaseClipAdapter)
+        self.clip = None
+        if exists(clip):
+            if isinstance(clip, CLIP):
+                clip = XClipAdapter(clip)
 
-        self.clip = clip
-        self.clip_image_size = clip.image_size
-        self.channels = clip.image_channels
+            freeze_model_and_make_eval_(clip)
+            assert isinstance(clip, BaseClipAdapter)
+
+            self.clip = clip
+            self.clip_image_size = clip.image_size
+            self.channels = clip.image_channels
+        else:
+            self.clip_image_size = image_size
+            self.channels = channels
 
         self.condition_on_text_encodings = condition_on_text_encodings
 
@@ -1522,7 +1531,7 @@ class Decoder(BaseGaussianDiffusion):
 
         # unet image sizes
 
-        image_sizes = default(image_sizes, (clip.image_size,))
+        image_sizes = default(image_sizes, (self.clip_image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
 
         assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1727,10 +1736,12 @@ class Decoder(BaseGaussianDiffusion):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
         if not exists(image_embed):
+            assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
         text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings):
+            assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
diff --git a/setup.py b/setup.py
index 2394dfc..be5520d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.105',
+  version = '0.0.106',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
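
A minimal sketch of the usage this patch enables: constructing a Decoder without CLIP by giving `image_size` and `channels` directly, then training off images paired with pre-encoded CLIP image embeddings. The Unet hyperparameters, the 512-dim embedding size, and the 256 image size below are illustrative assumptions, not values fixed by this patch; per the commit subject, DecoderTrainer can drive the same keyword arguments, but the sketch calls the decoder directly.

import torch
from dalle2_pytorch import Unet, Decoder

# a single u-net for the decoder; hyperparameters here are illustrative
unet = Unet(
    dim = 128,
    image_embed_dim = 512,   # must match the dimension of the pre-encoded image embeddings
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# no `clip` supplied, so `image_size` (and `channels`) must be given instead,
# satisfying the new assert: exists(clip) ^ exists(image_size)
decoder = Decoder(
    unet = unet,
    image_size = 256,                      # assumed resolution for this sketch
    channels = 3,
    timesteps = 1000,
    condition_on_text_encodings = False    # no CLIP, so no raw text can be embedded
)

# train directly off pre-encoded image embeddings - because `image_embed` is
# passed in, the decoder never needs CLIP to embed the images itself
images = torch.randn(4, 3, 256, 256)
image_embeds = torch.randn(4, 512)   # precomputed elsewhere, e.g. by a frozen CLIP

loss = decoder(images, image_embed = image_embeds)
loss.backward()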