diff --git a/README.md b/README.md
index 85ada0e..277fef0 100644
--- a/README.md
+++ b/README.md
@@ -647,7 +647,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
 - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
 - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
-- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
+- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 927d38b..59fc151 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -89,6 +89,59 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
     return F.interpolate(t, size = shape, mode = mode, align_corners = False)
 
+# clip related adapters
+
+class BaseClipAdapter(nn.Module):
+    def __init__(self, clip):
+        super().__init__()
+        self.clip = clip
+
+    @property
+    def dim_latent(self):
+        raise NotImplementedError
+
+    @property
+    def image_size(self):
+        raise NotImplementedError
+
+    @property
+    def image_channels(self):
+        raise NotImplementedError
+
+    def embed_text(self, text):
+        raise NotImplementedError
+
+    def embed_image(self, image):
+        raise NotImplementedError
+
+class XClipAdapter(BaseClipAdapter):
+    @property
+    def dim_latent(self):
+        return self.clip.dim_latent
+
+    @property
+    def image_size(self):
+        return self.clip.image_size
+
+    @property
+    def image_channels(self):
+        return self.clip.image_channels
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        encoder_output = self.clip.text_transformer(text)
+        text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
+        text_embed = self.clip.to_text_latent(text_cls)
+        return l2norm(text_embed), text_encodings
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        image = resize_image_to(image, self.image_size)
+        encoder_output = self.clip.visual_transformer(image)
+        image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
+        image_embed = self.clip.to_visual_latent(image_cls)
+        return l2norm(image_embed), image_encodings
+
 # classifier free guidance functions
 
 def prob_mask_like(shape, prob, device):
@@ -595,7 +648,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if exists(clip):
             assert isinstance(clip, CLIP)
             freeze_model_and_make_eval_(clip)
-            self.clip = clip
+            self.clip = XClipAdapter(clip)
         else:
             assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
             self.clip = None
@@ -610,29 +663,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.predict_x_start = predict_x_start # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
-    @torch.no_grad()
-    def get_image_embed(self, image):
-        assert exists(self.clip)
-
-        image_encoding = self.clip.visual_transformer(image)
-        image_cls = image_encoding[:, 0]
-        image_embed = self.clip.to_visual_latent(image_cls)
-        return l2norm(image_embed)
-
-    @torch.no_grad()
-    def get_text_cond(self, text):
-        assert exists(self.clip)
-
-        text_encodings = self.clip.text_transformer(text)
-        text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
-        text_embed = self.clip.to_text_latent(text_cls)
-        text_embed = l2norm(text_embed)
-
-        if not self.condition_on_text_encodings:
-            return dict(text_embed = text_embed)
-
-        return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
-
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)
@@ -704,7 +734,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
         batch_size = text.shape[0]
         image_embed_dim = self.image_embed_dim
 
-        text_cond = self.get_text_cond(text)
+        text_embed, text_encodings = self.clip.embed_text(text)
+
+        text_cond = dict(
+            text_embed = text_embed,
+            text_encodings = text_encodings,
+            mask = text != 0
+        )
 
         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
         text_embeds = text_cond['text_embed']
@@ -736,18 +772,19 @@ class DiffusionPrior(BaseGaussianDiffusion):
         assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
 
         if exists(image):
-            image_embed = self.get_image_embed(image)
+            image_embed, _ = self.clip.embed_image(image)
 
         # calculate text conditionings, based on what is passed in
 
         if exists(text):
-            text_cond = self.get_text_cond(text)
-        else:
-            text_cond = dict(
-                text_embed = text_embed,
-                text_encodings = text_encodings,
-                mask = text_mask
-            )
+            text_embed, text_encodings = self.clip.embed_text(text)
+            text_mask = text != 0
+
+        text_cond = dict(
+            text_embed = text_embed,
+            text_encodings = text_encodings,
+            mask = text_mask
+        )
 
         # timestep conditioning from ddpm
@@ -1208,7 +1245,9 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
-        assert isinstance(clip, CLIP)
+        if isinstance(clip, CLIP):
+            clip = XClipAdapter(clip)
+
         freeze_model_and_make_eval_(clip)
         self.clip = clip
         self.clip_image_size = clip.image_size
@@ -1290,10 +1329,6 @@ class Decoder(BaseGaussianDiffusion):
 
         yield unet.cpu()
 
-    @torch.no_grad()
-    def get_text_encodings(self, text):
-        text_encodings = self.clip.text_transformer(text)
-        return text_encodings[:, 1:]
 
     @torch.no_grad()
     def get_image_embed(self, image):
@@ -1379,7 +1414,9 @@ class Decoder(BaseGaussianDiffusion):
     def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
 
-        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        text_encodings = None
+        if exists(text):
+            _, text_encodings = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -1442,9 +1479,11 @@ class Decoder(BaseGaussianDiffusion):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
         if not exists(image_embed):
-            image_embed = self.get_image_embed(image)
+            image_embed, _ = self.clip.embed_image(image)
 
-        text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
+        text_encodings = None
+        if exists(text) and not exists(text_encodings):
+            _, text_encodings = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
diff --git a/setup.py b/setup.py
index 5c1dbd3..b479399 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
  },
-  version = '0.0.56',
+  version = '0.0.57',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
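
The point of the new `BaseClipAdapter` interface is that other CLIP implementations can be wired into the prior and decoder by subclassing it and returning `(embedding, encodings)` tuples from `embed_text` / `embed_image`, the way `XClipAdapter` does for the in-repo `CLIP`. Below is a minimal sketch of what a custom adapter could look like; the wrapped external model and its `encode_text` / `encode_image` methods (assumed here to return `(batch, seq, dim)` features with the CLS token at index 0) are hypothetical placeholders, not part of this repository.

```python
import torch
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter, l2norm

class MyClipAdapter(BaseClipAdapter):
    # hypothetical adapter around some external CLIP whose encode_text / encode_image
    # are assumed to return (batch, seq, dim) features with the CLS token at index 0

    @property
    def dim_latent(self):
        return 512   # assumed latent dimension of the wrapped model

    @property
    def image_size(self):
        return 256   # assumed input resolution of the wrapped model

    @property
    def image_channels(self):
        return 3

    @torch.no_grad()
    def embed_text(self, text):
        encodings = self.clip.encode_text(text)    # hypothetical API of the wrapped model
        text_embed, text_encodings = encodings[:, 0], encodings[:, 1:]
        return l2norm(text_embed), text_encodings

    @torch.no_grad()
    def embed_image(self, image):
        encodings = self.clip.encode_image(image)  # hypothetical API of the wrapped model
        image_embed, image_encodings = encodings[:, 0], encodings[:, 1:]
        return l2norm(image_embed), image_encodings
```

Note that at this version `DiffusionPrior` still asserts `isinstance(clip, CLIP)` and wraps it in `XClipAdapter` itself, while `Decoder` only converts raw `CLIP` instances and stores anything else as-is, so an adapter like the sketch above reflects where the abstraction is headed rather than something both constructors accept yet.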