Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
support CoCa, which seems to be better than CLIP (has an autoregressive text encoder) https://arxiv.org/abs/2205.01917
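The CoCa model referenced throughout comes from the companion `coca-pytorch` package (pinned to `>=0.0.5` in setup.py below, since the adapter relies on its `embed_text` / `embed_image` methods). A minimal construction sketch, assuming the `coca-pytorch` README's constructor; the ViT wrapper and all hyperparameter values here are illustrative, not part of this commit:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor
from coca_pytorch import CoCa

# any image encoder returning patch embeddings works; Extractor
# strips the classification head off a plain ViT
vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

vit = Extractor(vit, return_embeddings_only = True, detach = False)

coca = CoCa(
    dim = 512,              # model dimension - surfaced to the adapter as clip.dim
    img_encoder = vit,
    image_dim = 1024,       # dimension of the image encoder output
    num_tokens = 20000,     # text vocabulary size
    unimodal_depth = 6,     # autoregressive text decoder depth
    multimodal_depth = 6,   # cross-attending multimodal layers
    dim_head = 64,
    heads = 8
)
```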
README.md (10 changed lines)
@@ -1047,4 +1047,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```

+```bibtex
+@article{Yu2022CoCaCC,
+    title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
+    author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2205.01917}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
dalle2_pytorch/dalle2_pytorch.py

@@ -26,6 +26,7 @@ from resize_right import resize
 # use x-clip

 from x_clip import CLIP
+from coca_pytorch import CoCa

 # helper functions
@@ -113,9 +114,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
 EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])

 class BaseClipAdapter(nn.Module):
-    def __init__(self, clip):
+    def __init__(self, clip, **kwargs):
         super().__init__()
         self.clip = clip
+        self.overrides = kwargs

     @property
     def dim_latent(self):
@@ -173,6 +175,39 @@ class XClipAdapter(BaseClipAdapter):
         image_embed = self.clip.to_visual_latent(image_cls)
         return EmbeddedImage(l2norm(image_embed), image_encodings)

+class CoCaAdapter(BaseClipAdapter):
+    @property
+    def dim_latent(self):
+        return self.clip.dim
+
+    @property
+    def image_size(self):
+        assert 'image_size' in self.overrides
+        return self.overrides['image_size']
+
+    @property
+    def image_channels(self):
+        assert 'image_channels' in self.overrides
+        return self.overrides['image_channels']
+
+    @property
+    def max_text_len(self):
+        assert 'max_text_len' in self.overrides
+        return self.overrides['max_text_len']
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+        text_mask = text != 0
+        text_embed, text_encodings = self.clip.embed_text(text)
+        return EmbeddedText(text_embed, text_encodings, text_mask)
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        image = resize_image_to(image, self.image_size)
+        image_embed, image_encodings = self.clip.embed_image(image)
+        return EmbeddedImage(image_embed, image_encodings)
+
 class OpenAIClipAdapter(BaseClipAdapter):
     def __init__(
         self,
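Note that CoCa, unlike x-clip's `CLIP`, does not expose `image_size`, `image_channels` or `max_text_len` as attributes, so the adapter asserts them out of the `self.overrides` dict populated by `BaseClipAdapter` above. A quick sketch of the adapter in isolation; the override values and tensor shapes are illustrative:

```python
import torch

# `coca` as constructed in the sketch near the top of this page
adapter = CoCaAdapter(coca, image_size = 256, image_channels = 3, max_text_len = 256)

text = torch.randint(0, 20000, (4, 256))  # batch of token ids, 0 is treated as padding
images = torch.randn(4, 3, 256, 256)

text_embed, text_encodings, text_mask = adapter.embed_text(text)  # mask is simply text != 0
image_embed, image_encodings = adapter.embed_image(images)        # resized to image_size first
```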
@@ -755,6 +790,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         condition_on_text_encodings = True,  # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False,
         image_embed_scale = None,            # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        clip_adapter_overrides = dict()
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -764,7 +800,9 @@ class DiffusionPrior(BaseGaussianDiffusion):

         if exists(clip):
             if isinstance(clip, CLIP):
-                clip = XClipAdapter(clip)
+                clip = XClipAdapter(clip, **clip_adapter_overrides)
+            elif isinstance(clip, CoCa):
+                clip = CoCaAdapter(clip, **clip_adapter_overrides)

             assert isinstance(clip, BaseClipAdapter)
             freeze_model_and_make_eval_(clip)
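With this dispatch, a raw `CoCa` instance can be handed straight to `DiffusionPrior` and is wrapped in a `CoCaAdapter` automatically, with the overrides dict threaded through. A hedged sketch following the repository README's prior setup; the network hyperparameters, timesteps and cond_drop_prob values are illustrative:

```python
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = coca,                    # raw CoCa - the isinstance check above wraps it
    timesteps = 100,
    cond_drop_prob = 0.2,
    clip_adapter_overrides = dict(  # forwarded to CoCaAdapter, which asserts these keys
        image_size = 256,
        image_channels = 3,
        max_text_len = 256
    )
)
```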
@@ -1487,7 +1525,8 @@ class Decoder(BaseGaussianDiffusion):
         blur_kernel_size = 3,                # cascading ddpm - blur kernel size
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
-        clip_x_start = True
+        clip_x_start = True,
+        clip_adapter_overrides = dict()
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1500,7 +1539,9 @@ class Decoder(BaseGaussianDiffusion):
         self.clip = None
         if exists(clip):
             if isinstance(clip, CLIP):
-                clip = XClipAdapter(clip)
+                clip = XClipAdapter(clip, **clip_adapter_overrides)
+            elif isinstance(clip, CoCa):
+                clip = CoCaAdapter(clip, **clip_adapter_overrides)

             freeze_model_and_make_eval_(clip)
             assert isinstance(clip, BaseClipAdapter)
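The `Decoder` receives the same treatment, so the same overrides dict works on that side as well. A minimal sketch in the style of the repository README; the Unet hyperparameters are illustrative:

```python
from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,          # should match the CoCa dim, i.e. the adapter's dim_latent
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

decoder = Decoder(
    unet = unet,
    clip = coca,                    # wrapped into a CoCaAdapter internally, as above
    timesteps = 100,
    clip_adapter_overrides = dict(
        image_size = 256,
        image_channels = 3,
        max_text_len = 256
    )
)
```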
setup.py (3 changed lines)
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.108',
+  version = '0.0.109',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -24,6 +24,7 @@ setup(
   install_requires=[
     'click',
     'clip-anytorch',
+    'coca-pytorch>=0.0.5',
     'einops>=0.4',
     'einops-exts>=0.0.3',
     'embedding-reader',