mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
28e944f328 | ||
|
|
14e63a3f67 | ||
|
|
09e9eaa5a6 | ||
|
|
e6d752cf4a | ||
|
|
ad20a14a4d | ||
|
|
0be1e0d64c | ||
|
|
98df1ba51e |
13
README.md
13
README.md
@@ -967,7 +967,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [ ] train on a toy task, offer in colab
|
||||
@@ -980,6 +980,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1047,4 +1048,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2022CoCaCC,
|
||||
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
|
||||
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2205.01917}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -23,9 +23,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
from resize_right import resize
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
|
||||
# use x-clip
|
||||
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -113,9 +118,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
|
||||
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
def __init__(self, clip):
|
||||
def __init__(self, clip, **kwargs):
|
||||
super().__init__()
|
||||
self.clip = clip
|
||||
self.overrides = kwargs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
@@ -173,6 +179,39 @@ class XClipAdapter(BaseClipAdapter):
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
||||
|
||||
class CoCaAdapter(BaseClipAdapter):
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return self.clip.dim
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
assert 'image_size' in self.overrides
|
||||
return self.overrides['image_size']
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
assert 'image_channels' in self.overrides
|
||||
return self.overrides['image_channels']
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
assert 'max_text_len' in self.overrides
|
||||
return self.overrides['max_text_len']
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
return EmbeddedText(text_embed, text_encodings, text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image_embed, image_encodings = self.clip.embed_image(image)
|
||||
return EmbeddedImage(image_embed, image_encodings)
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -225,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
@@ -233,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(image_embed.float(), None)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
@@ -531,7 +570,8 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
post_norm = False
|
||||
post_norm = False,
|
||||
rotary_emb = None
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -547,6 +587,8 @@ class Attention(nn.Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
@@ -559,6 +601,12 @@ class Attention(nn.Module):
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
|
||||
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||
q = q * self.scale
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
if exists(self.rotary_emb):
|
||||
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
|
||||
|
||||
# add null key / value for classifier free guidance in prior net
|
||||
|
||||
@@ -566,7 +614,7 @@ class Attention(nn.Module):
|
||||
k = torch.cat((nk, k), dim = -2)
|
||||
v = torch.cat((nv, v), dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
# calculate query / key similarities
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
@@ -616,15 +664,18 @@ class CausalTransformer(nn.Module):
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
final_proj = True,
|
||||
normformer = False
|
||||
normformer = False,
|
||||
rotary_emb = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
@@ -754,7 +805,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False,
|
||||
training_clamp_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -764,7 +817,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
if exists(clip):
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
@@ -788,6 +843,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
@@ -840,6 +896,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
if self.predict_x_start and self.training_clamp_l2norm:
|
||||
pred = l2norm(pred) * self.image_embed_scale
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
@@ -1487,7 +1546,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1500,7 +1560,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||
|
||||
freeze_model_and_make_eval_(clip)
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
@@ -111,11 +111,6 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if use_ema:
|
||||
has_lazy_linear = any([type(module) == nn.LazyLinear for module in diffusion_prior.modules()])
|
||||
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
|
||||
4
setup.py
4
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.108',
|
||||
version = '0.1.4',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -24,12 +24,14 @@ setup(
|
||||
install_requires=[
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
|
||||
Reference in New Issue
Block a user