Compare commits

..

11 Commits

Author SHA1 Message Date
z
cd5f2c1de4 simulate unrelated captions as a training metric (#66)
* add unrelated embedding metric

* change to torch.roll

Co-authored-by: nousr <z@localhost.com>
Co-authored-by: nousr <>
2022-05-07 05:34:59 -07:00
Phil Wang
85ed77d512 fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71 2022-05-07 05:05:54 -07:00
Piero Rolando
fd53fa17db Fix a typo in README (#70)
Change "pyhon" for "python" (correct)
2022-05-06 16:53:36 -07:00
Phil Wang
3676ef4d49 make sure vqgan-vae trainer supports mixed precision 2022-05-06 10:44:16 -07:00
Phil Wang
28e944f328 make sure openai clip adapter outputs l2normed embeddings 2022-05-06 10:12:03 -07:00
Phil Wang
14e63a3f67 also offer l2norm clamping in diffusion prior during training, if one were using predict x0 objective 2022-05-06 10:05:14 -07:00
Phil Wang
09e9eaa5a6 project management 2022-05-06 09:00:22 -07:00
Phil Wang
e6d752cf4a reprioritize 2022-05-06 08:55:26 -07:00
Phil Wang
ad20a14a4d bring in rotary embeddings for diffusion prior causal transformer (the most powerful relative positional encoding, used in PaLM) - 0.1.0 because of breaking change 2022-05-06 08:45:30 -07:00
Phil Wang
0be1e0d64c support CoCa, which seems to be better than CLIP (has an autoregressive text encoder) https://arxiv.org/abs/2205.01917 2022-05-06 08:27:12 -07:00
Phil Wang
98df1ba51e add diffusion prior trainer, which automatically takes care of the exponential moving average (training and sampling), as well as mixed precision, gradient clipping 2022-05-06 08:11:09 -07:00
6 changed files with 161 additions and 47 deletions

View File

@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
### Usage
```bash
$ pyhon train_diffusion_prior.py
$ python train_diffusion_prior.py
```
The most significant parameters for the script are as follows:
@@ -967,7 +967,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab
@@ -980,6 +980,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
## Citations
@@ -1047,4 +1048,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Yu2022CoCaCC,
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.01917}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -23,9 +23,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from resize_right import resize
# rotary embeddings
from rotary_embedding_torch import RotaryEmbedding
# use x-clip
from x_clip import CLIP
from coca_pytorch import CoCa
# helper functions
@@ -113,9 +118,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
def __init__(self, clip, **kwargs):
super().__init__()
self.clip = clip
self.overrides = kwargs
@property
def dim_latent(self):
@@ -173,6 +179,39 @@ class XClipAdapter(BaseClipAdapter):
image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings)
class CoCaAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.clip.dim
@property
def image_size(self):
assert 'image_size' in self.overrides
return self.overrides['image_size']
@property
def image_channels(self):
assert 'image_channels' in self.overrides
return self.overrides['image_channels']
@property
def max_text_len(self):
assert 'max_text_len' in self.overrides
return self.overrides['max_text_len']
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
@@ -225,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -233,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)
return EmbeddedImage(l2norm(image_embed.float()), None)
# classifier free guidance functions
@@ -531,7 +570,8 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False
post_norm = False,
rotary_emb = None
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -547,6 +587,8 @@ class Attention(nn.Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.rotary_emb = rotary_emb
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
@@ -559,6 +601,12 @@ class Attention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
q = q * self.scale
# rotary embeddings
if exists(self.rotary_emb):
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
# add null key / value for classifier free guidance in prior net
@@ -566,7 +614,7 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
# calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -616,15 +664,18 @@ class CausalTransformer(nn.Module):
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True,
normformer = False
normformer = False,
rotary_emb = True
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -714,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -725,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings,
text_embed,
time_embed,
image_embed,
learned_queries
), dim = -2)
@@ -754,7 +806,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
training_clamp_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
beta_schedule = beta_schedule,
@@ -764,7 +818,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip)
@@ -788,6 +844,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -840,6 +897,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
**text_cond
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
@@ -1487,7 +1547,8 @@ class Decoder(BaseGaussianDiffusion):
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True
clip_x_start = True,
clip_adapter_overrides = dict()
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1500,7 +1561,9 @@ class Decoder(BaseGaussianDiffusion):
self.clip = None
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)

View File

@@ -111,11 +111,6 @@ class DiffusionPriorTrainer(nn.Module):
# exponential moving average
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in diffusion_prior.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)

View File

@@ -3,14 +3,15 @@ import copy
from random import choice
from pathlib import Path
from shutil import rmtree
from PIL import Image
import torch
from torch import nn
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.discr_scaler = GradScaler(enabled = amp)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
with autocast(enabled = self.amp):
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
self.scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.scaler.step(self.optim)
self.scaler.update()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
with autocast(enabled = self.amp):
loss = self.vae(img, return_discr_loss = True)
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
self.discr_scaler.step(self.discr_optim)
self.discr_scaler.update()
self.discr_optim.zero_grad()
# log

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.108',
version = '0.1.6',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -24,12 +24,14 @@ setup(
install_requires=[
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
'torchvision',
'tqdm',

View File

@@ -46,28 +46,60 @@ def save_model(save_path, state_dict):
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed = text_embed)
text_embed_shuffled = text_embed.clone()
# roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
# prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed)
# prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
wandb.log(
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
predicted_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
unrelated_similarity)})
return np.mean(predicted_similarity - original_similarity)