From 0be1e0d64c2139989a777ba7c1be78d92e1120a0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 6 May 2022 08:27:12 -0700
Subject: [PATCH] support CoCa, which seems to be better than CLIP (has an
 autoregressive text encoder) https://arxiv.org/abs/2205.01917

---
 README.md                        | 10 +++++++
 dalle2_pytorch/dalle2_pytorch.py | 49 +++++++++++++++++++++++++++++---
 setup.py                         |  3 +-
 3 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 090a255..1531e9d 100644
--- a/README.md
+++ b/README.md
@@ -1047,4 +1047,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Yu2022CoCaCC,
+    title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
+    author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2205.01917}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 197c855..8999f28 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -26,6 +26,7 @@ from resize_right import resize
 
 # use x-clip
 
 from x_clip import CLIP
+from coca_pytorch import CoCa
 
 # helper functions
@@ -113,9 +114,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
 EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
 
 class BaseClipAdapter(nn.Module):
-    def __init__(self, clip):
+    def __init__(self, clip, **kwargs):
         super().__init__()
         self.clip = clip
+        self.overrides = kwargs
 
     @property
     def dim_latent(self):
@@ -173,6 +175,39 @@ class XClipAdapter(BaseClipAdapter):
         image_embed = self.clip.to_visual_latent(image_cls)
         return EmbeddedImage(l2norm(image_embed), image_encodings)
 
+class CoCaAdapter(BaseClipAdapter):
+    @property
+    def dim_latent(self):
+        return self.clip.dim
+
+    @property
+    def image_size(self):
+        assert 'image_size' in self.overrides
+        return self.overrides['image_size']
+
+    @property
+    def image_channels(self):
+        assert 'image_channels' in self.overrides
+        return self.overrides['image_channels']
+
+    @property
+    def max_text_len(self):
+        assert 'max_text_len' in self.overrides
+        return self.overrides['max_text_len']
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+        text_mask = text != 0
+        text_embed, text_encodings = self.clip.embed_text(text)
+        return EmbeddedText(text_embed, text_encodings, text_mask)
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        image = resize_image_to(image, self.image_size)
+        image_embed, image_encodings = self.clip.embed_image(image)
+        return EmbeddedImage(image_embed, image_encodings)
+
 class OpenAIClipAdapter(BaseClipAdapter):
     def __init__(
         self,
@@ -755,6 +790,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         condition_on_text_encodings = True,  # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False,
         image_embed_scale = None,            # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        clip_adapter_overrides = dict()
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -764,7 +800,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         if exists(clip):
             if isinstance(clip, CLIP):
-                clip = XClipAdapter(clip)
+                clip = XClipAdapter(clip, **clip_adapter_overrides)
+            elif isinstance(clip, CoCa):
+                clip = CoCaAdapter(clip, **clip_adapter_overrides)
 
             assert isinstance(clip, BaseClipAdapter)
             freeze_model_and_make_eval_(clip)
@@ -1487,7 +1525,8 @@ class Decoder(BaseGaussianDiffusion):
         blur_kernel_size = 3,                 # cascading ddpm - blur kernel size
         condition_on_text_encodings = False,  # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
-        clip_x_start = True
+        clip_x_start = True,
+        clip_adapter_overrides = dict()
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1500,7 +1539,9 @@ class Decoder(BaseGaussianDiffusion):
         self.clip = None
         if exists(clip):
             if isinstance(clip, CLIP):
-                clip = XClipAdapter(clip)
+                clip = XClipAdapter(clip, **clip_adapter_overrides)
+            elif isinstance(clip, CoCa):
+                clip = CoCaAdapter(clip, **clip_adapter_overrides)
 
             freeze_model_and_make_eval_(clip)
             assert isinstance(clip, BaseClipAdapter)
diff --git a/setup.py b/setup.py
index a22aec5..cfe6fcc 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.108',
+  version = '0.0.109',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -24,6 +24,7 @@ setup(
   install_requires=[
     'click',
     'clip-anytorch',
+    'coca-pytorch>=0.0.5',
     'einops>=0.4',
     'einops-exts>=0.0.3',
     'embedding-reader',
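
A minimal usage sketch of the code path this patch adds: a `CoCa` model is passed to `DiffusionPrior` in place of `CLIP`, with the new `clip_adapter_overrides` kwarg supplying the values `CoCaAdapter` cannot infer (`image_size`, `image_channels`, `max_text_len`). The `CoCa` constructor arguments follow the coca-pytorch README and the `DiffusionPrior` / `DiffusionPriorNetwork` arguments follow this repo's README; exact names and defaults may differ across versions, and the toy image encoder is a placeholder for a real ViT.

```python
# hypothetical sketch, not part of the patch; constructor arguments are assumptions
import torch
from torch import nn
from coca_pytorch import CoCa
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

class ToyImageEncoder(nn.Module):
    # stand-in for a real image encoder (e.g. a ViT wrapped to return patch embeddings)
    def forward(self, images):
        return torch.randn(images.shape[0], 64, 1024, device = images.device)

coca = CoCa(
    dim = 512,                       # model dimension
    img_encoder = ToyImageEncoder(), # returns (batch, seq, image_dim) embeddings
    image_dim = 1024,                # dimension of the image encoder outputs
    num_tokens = 20000,              # text vocabulary size
    unimodal_depth = 6,
    multimodal_depth = 6
)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = coca,                     # detected via isinstance and wrapped in CoCaAdapter
    timesteps = 100,
    cond_drop_prob = 0.2,
    clip_adapter_overrides = dict(   # required by CoCaAdapter's asserted properties
        image_size = 256,
        image_channels = 3,
        max_text_len = 256
    )
)
```

The overrides exist because, unlike x-clip's `CLIP`, the `CoCa` module does not expose the image size, channel count, or text length it was trained with, so the adapter has to be told these explicitly.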