mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 20:04:20 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a65a86d03 | ||
|
|
0be1e0d64c | ||
|
|
98df1ba51e |
10
README.md
10
README.md
@@ -1047,4 +1047,14 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{Yu2022CoCaCC,
|
||||||
|
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
|
||||||
|
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
|
||||||
|
journal = {ArXiv},
|
||||||
|
year = {2022},
|
||||||
|
volume = {abs/2205.01917}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -23,9 +23,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
|||||||
|
|
||||||
from resize_right import resize
|
from resize_right import resize
|
||||||
|
|
||||||
|
# rotary embeddings
|
||||||
|
|
||||||
|
from rotary_embedding_torch import RotaryEmbedding
|
||||||
|
|
||||||
# use x-clip
|
# use x-clip
|
||||||
|
|
||||||
from x_clip import CLIP
|
from x_clip import CLIP
|
||||||
|
from coca_pytorch import CoCa
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
@@ -113,9 +118,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
|
|||||||
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||||
|
|
||||||
class BaseClipAdapter(nn.Module):
|
class BaseClipAdapter(nn.Module):
|
||||||
def __init__(self, clip):
|
def __init__(self, clip, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
|
self.overrides = kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dim_latent(self):
|
def dim_latent(self):
|
||||||
@@ -173,6 +179,39 @@ class XClipAdapter(BaseClipAdapter):
|
|||||||
image_embed = self.clip.to_visual_latent(image_cls)
|
image_embed = self.clip.to_visual_latent(image_cls)
|
||||||
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
||||||
|
|
||||||
|
class CoCaAdapter(BaseClipAdapter):
|
||||||
|
@property
|
||||||
|
def dim_latent(self):
|
||||||
|
return self.clip.dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_size(self):
|
||||||
|
assert 'image_size' in self.overrides
|
||||||
|
return self.overrides['image_size']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_channels(self):
|
||||||
|
assert 'image_channels' in self.overrides
|
||||||
|
return self.overrides['image_channels']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_text_len(self):
|
||||||
|
assert 'max_text_len' in self.overrides
|
||||||
|
return self.overrides['max_text_len']
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def embed_text(self, text):
|
||||||
|
text = text[..., :self.max_text_len]
|
||||||
|
text_mask = text != 0
|
||||||
|
text_embed, text_encodings = self.clip.embed_text(text)
|
||||||
|
return EmbeddedText(text_embed, text_encodings, text_mask)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def embed_image(self, image):
|
||||||
|
image = resize_image_to(image, self.image_size)
|
||||||
|
image_embed, image_encodings = self.clip.embed_image(image)
|
||||||
|
return EmbeddedImage(image_embed, image_encodings)
|
||||||
|
|
||||||
class OpenAIClipAdapter(BaseClipAdapter):
|
class OpenAIClipAdapter(BaseClipAdapter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -531,7 +570,8 @@ class Attention(nn.Module):
|
|||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
causal = False,
|
causal = False,
|
||||||
post_norm = False
|
post_norm = False,
|
||||||
|
rotary_emb = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
@@ -547,6 +587,8 @@ class Attention(nn.Module):
|
|||||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||||
|
|
||||||
|
self.rotary_emb = rotary_emb
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, dim, bias = False),
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
LayerNorm(dim) if post_norm else nn.Identity()
|
LayerNorm(dim) if post_norm else nn.Identity()
|
||||||
@@ -559,6 +601,12 @@ class Attention(nn.Module):
|
|||||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
|
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
|
||||||
|
|
||||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# rotary embeddings
|
||||||
|
|
||||||
|
if exists(self.rotary_emb):
|
||||||
|
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
|
||||||
|
|
||||||
# add null key / value for classifier free guidance in prior net
|
# add null key / value for classifier free guidance in prior net
|
||||||
|
|
||||||
@@ -566,7 +614,7 @@ class Attention(nn.Module):
|
|||||||
k = torch.cat((nk, k), dim = -2)
|
k = torch.cat((nk, k), dim = -2)
|
||||||
v = torch.cat((nv, v), dim = -2)
|
v = torch.cat((nv, v), dim = -2)
|
||||||
|
|
||||||
q = q * self.scale
|
# calculate query / key similarities
|
||||||
|
|
||||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||||
|
|
||||||
@@ -616,15 +664,18 @@ class CausalTransformer(nn.Module):
|
|||||||
attn_dropout = 0.,
|
attn_dropout = 0.,
|
||||||
ff_dropout = 0.,
|
ff_dropout = 0.,
|
||||||
final_proj = True,
|
final_proj = True,
|
||||||
normformer = False
|
normformer = False,
|
||||||
|
rotary_emb = True
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||||
|
|
||||||
|
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
@@ -755,6 +806,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
sampling_clamp_l2norm = False,
|
sampling_clamp_l2norm = False,
|
||||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||||
|
clip_adapter_overrides = dict()
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -764,7 +816,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
|
elif isinstance(clip, CoCa):
|
||||||
|
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||||
|
|
||||||
assert isinstance(clip, BaseClipAdapter)
|
assert isinstance(clip, BaseClipAdapter)
|
||||||
freeze_model_and_make_eval_(clip)
|
freeze_model_and_make_eval_(clip)
|
||||||
@@ -1487,7 +1541,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||||
clip_denoised = True,
|
clip_denoised = True,
|
||||||
clip_x_start = True
|
clip_x_start = True,
|
||||||
|
clip_adapter_overrides = dict()
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -1500,7 +1555,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
|
elif isinstance(clip, CoCa):
|
||||||
|
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||||
|
|
||||||
freeze_model_and_make_eval_(clip)
|
freeze_model_and_make_eval_(clip)
|
||||||
assert isinstance(clip, BaseClipAdapter)
|
assert isinstance(clip, BaseClipAdapter)
|
||||||
|
|||||||
@@ -111,11 +111,6 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
# exponential moving average
|
# exponential moving average
|
||||||
|
|
||||||
self.use_ema = use_ema
|
self.use_ema = use_ema
|
||||||
|
|
||||||
if use_ema:
|
|
||||||
has_lazy_linear = any([type(module) == nn.LazyLinear for module in diffusion_prior.modules()])
|
|
||||||
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
|
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||||
|
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.0.108',
|
version = '0.1.0',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -24,6 +24,7 @@ setup(
|
|||||||
install_requires=[
|
install_requires=[
|
||||||
'click',
|
'click',
|
||||||
'clip-anytorch',
|
'clip-anytorch',
|
||||||
|
'coca-pytorch>=0.0.5',
|
||||||
'einops>=0.4',
|
'einops>=0.4',
|
||||||
'einops-exts>=0.0.3',
|
'einops-exts>=0.0.3',
|
||||||
'embedding-reader',
|
'embedding-reader',
|
||||||
|
|||||||
Reference in New Issue
Block a user