Compare commits

..

15 Commits

Author SHA1 Message Date
Phil Wang
1992d25cad project management 2022-05-04 11:18:54 -07:00
Phil Wang
5b619c2fd5 make sure some hyperparameters for unet block is configurable 2022-05-04 11:18:32 -07:00
Phil Wang
9359ad2e91 0.0.95 2022-05-04 10:53:05 -07:00
Phil Wang
9ff228188b offer old resnet blocks, from the original DDPM paper, just in case convnexts are unsuitable for generative work 2022-05-04 10:52:58 -07:00
Kumar R
2d9963d30e Reporting metrics - Cosine similarity. (#55)
* Update train_diffusion_prior.py

* Delete train_diffusion_prior.py

* Cosine similarity logging.

* Update train_diffusion_prior.py

* Report Cosine metrics every N steps.
2022-05-04 08:04:36 -07:00
Phil Wang
58d9b422f3 0.0.94 2022-05-04 07:42:33 -07:00
Ray Bell
44b319cb57 add missing import (#56) 2022-05-04 07:42:20 -07:00
Phil Wang
c30f380689 final reminder 2022-05-03 08:18:53 -07:00
Phil Wang
e4e884bb8b keep all doors open 2022-05-03 08:17:02 -07:00
Phil Wang
803ad9c17d product management again 2022-05-03 08:15:25 -07:00
Phil Wang
a88dd6a9c0 todo 2022-05-03 08:09:02 -07:00
Kumar R
72c16b496e Update train_diffusion_prior.py (#53) 2022-05-02 22:44:57 -07:00
z
81d83dd7f2 defaults align with paper (#52)
Co-authored-by: nousr <>
2022-05-02 13:52:11 -07:00
Phil Wang
fa66f7e1e9 todo 2022-05-02 12:57:15 -07:00
Phil Wang
aa8d135245 allow laion to experiment with normformer in diffusion prior 2022-05-02 11:35:00 -07:00
5 changed files with 160 additions and 25 deletions

View File

@@ -821,7 +821,8 @@ Once built, images will be saved to the same directory the command is invoked
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [x] bring in tools to train vqgan-vae
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
@@ -832,6 +833,9 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
## Citations
@@ -907,4 +911,4 @@ Once built, images will be saved to the same directory the command is invoked
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,7 @@
import click
import torch
import torchvision.transforms as T
from functools import reduce
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

View File

@@ -930,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
class Block(nn.Module):
def __init__(
self,
dim,
dim_out,
groups = 8
):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
class ResnetBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
groups = 8
):
super().__init__()
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out)
)
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim
)
)
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
h = self.block2(h)
return h + self.res_conv(x)
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
@@ -940,8 +1006,7 @@ class ConvNextBlock(nn.Module):
*,
cond_dim = None,
time_cond_dim = None,
mult = 2,
norm = True
mult = 2
):
super().__init__()
need_projection = dim != dim_out
@@ -970,7 +1035,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim) if norm else nn.Identity(),
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -1082,7 +1147,11 @@ class LinearAttention(nn.Module):
self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1, bias = False),
ChanLayerNorm(dim)
)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
@@ -1125,7 +1194,11 @@ class Unet(nn.Module):
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7
init_conv_kernel_size = 7,
block_type = 'resnet',
block_resnet_groups = 8,
block_convnext_mult = 2,
**kwargs
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -1200,6 +1273,15 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# whether to use resnet or the (improved?) convnext blocks
if block_type == 'resnet':
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
elif block_type == 'convnext':
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
else:
raise ValueError(f'unimplemented block type {block_type}')
# layers
self.downs = nn.ModuleList([])
@@ -1212,32 +1294,32 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
ConvNextBlock(dim, dim),
block_klass(dim, dim),
nn.Conv2d(dim, out_dim, 1)
)
@@ -1368,10 +1450,10 @@ class Unet(nn.Module):
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c, t)
for block1, sparse_attn, block2, downsample in self.downs:
x = block1(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c, t)
x = block2(x, c, t)
hiddens.append(x)
x = downsample(x)
@@ -1382,11 +1464,11 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for convnext, sparse_attn, convnext2, upsample in self.ups:
for block1, sparse_attn, block2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c, t)
x = block1(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c, t)
x = block2(x, c, t)
x = upsample(x)
return self.final_conv(x)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.93',
version = '0.0.96',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -1,18 +1,23 @@
import os
import math
import argparse
import numpy as np
import torch
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time
from tqdm import tqdm
import wandb
os.environ["WANDB_SILENT"] = "true"
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval()
@@ -41,6 +46,33 @@ def save_model(save_path, state_dict):
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
text_embed = torch.tensor(embt[0]).to(device)
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed = text_embed)
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
return np.mean(predicted_similarity - original_similarity)
def train(image_embed_dim,
image_embed_url,
text_embed_url,
@@ -54,6 +86,7 @@ def train(image_embed_dim,
dp_condition_on_text_encodings,
dp_timesteps,
dp_l2norm_output,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
@@ -72,6 +105,7 @@ def train(image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
normformer = dp_normformer,
l2norm_output = dp_l2norm_output).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -132,9 +166,21 @@ def train(image_embed_dim,
wandb.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS,
device)
wandb.log({"Cosine similarity difference": diff_cosine_sim})
scaler.unscale_(optimizer)
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
@@ -163,8 +209,8 @@ def main():
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=0.001)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
@@ -183,7 +229,8 @@ def main():
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
@@ -227,6 +274,7 @@ def main():
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_l2norm_output,
args.dp_normformer,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,