Compare commits


20 Commits

Author SHA1 Message Date
Phil Wang
70282de23b add ability to turn on normformer settings, given @borisdayma reported good results and some personal anecdata 2022-05-02 11:33:15 -07:00
Phil Wang
83f761847e todo 2022-05-02 10:52:39 -07:00
Phil Wang
11469dc0c6 makes more sense to keep this as True as default, for stability 2022-05-02 10:50:55 -07:00
Romain Beaumont
2d25c89f35 Fix passing of l2norm_output to DiffusionPriorNetwork (#51) 2022-05-02 10:48:16 -07:00
Phil Wang
3fe96c208a add ability to train diffusion prior with l2norm on output image embed 2022-05-02 09:53:20 -07:00
Phil Wang
0fc6c9cdf3 provide option to l2norm the output of the diffusion prior 2022-05-02 09:41:03 -07:00
Phil Wang
7ee0ecc388 mixed precision for training diffusion prior + save optimizer and scaler states 2022-05-02 09:31:04 -07:00
Phil Wang
1924c7cc3d fix issue with mixed precision and gradient clipping 2022-05-02 09:20:19 -07:00
Phil Wang
f7df3caaf3 address not calculating average eval / test loss when training diffusion prior https://github.com/lucidrains/DALLE2-pytorch/issues/49 2022-05-02 08:51:41 -07:00
Phil Wang
fc954ee788 fix calculation of adaptive weight for vit-vqgan, thanks to @CiaoHe 2022-05-02 07:58:14 -07:00
Phil Wang
c1db2753f5 todo 2022-05-01 18:02:30 -07:00
Phil Wang
ad87bfe28f switch to using linear attention for the sparse attention layers within unet, given success in GAN projects 2022-05-01 17:59:03 -07:00
Phil Wang
76c767b1ce update deps, commit to using webdatasets, per @rom1504 consultation 2022-05-01 12:22:15 -07:00
Phil Wang
d991b8c39c just clip the diffusion prior network parameters 2022-05-01 12:01:08 -07:00
Phil Wang
902693e271 todo 2022-05-01 11:57:08 -07:00
Phil Wang
35cd63982d add gradient clipping, make sure weight decay is configurable, make sure learning rate is actually passed into get_optimizer, make sure model is set to training mode at beginning of each epoch 2022-05-01 11:55:38 -07:00
Kumar R
53ce6dfdf6 All changes implemented, current run happening. Link to wandb run in comments. (#43)
* Train DiffusionPrior with pre-computed embeddings

This is in response to https://github.com/lucidrains/DALLE2-pytorch/issues/29 - more metrics will get added.
2022-05-01 11:46:59 -07:00
Phil Wang
ad8d7a368b product management 2022-05-01 11:26:21 -07:00
Phil Wang
b8cf1e5c20 more attention 2022-05-01 11:00:33 -07:00
Phil Wang
94aaa08d97 product management 2022-05-01 09:43:10 -07:00
6 changed files with 353 additions and 18 deletions


@@ -824,9 +824,14 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
## Citations
@@ -858,12 +863,22 @@ Once built, images will be saved to the same directory the command is invoked
```bibtex
@inproceedings{Liu2022ACF,
- title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
+ title = {A ConvNet for the 2020s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@article{shen2019efficient,
author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
title = {Efficient Attention: Attention with Linear Complexities},
journal = {CoRR},
year = {2018},
url = {http://arxiv.org/abs/1812.01243},
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
@@ -882,4 +897,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Shleifer2021NormFormerIT,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>


@@ -29,6 +29,9 @@ from x_clip import CLIP
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def default(val, d):
if exists(val):
return val
@@ -496,7 +499,12 @@ class SwiGLU(nn.Module):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
- def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
+ def FeedForward(
+     dim,
+     mult = 4,
+     dropout = 0.,
+     post_activation_norm = False
+ ):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim)
@@ -519,7 +527,8 @@ class Attention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
- causal = False
+ causal = False,
+ post_norm = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -534,7 +543,11 @@ class Attention(nn.Module):
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
- self.to_out = nn.Linear(inner_dim, dim, bias = False)
+ self.to_out = nn.Sequential(
+     nn.Linear(inner_dim, dim, bias = False),
+     LayerNorm(dim) if post_norm else nn.Identity()
+ )
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device
@@ -596,10 +609,11 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
- norm_out = False,
+ norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
- final_proj = True
+ final_proj = True,
+ normformer = False
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -607,8 +621,8 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
- Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
- FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+ Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+ FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
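The new `normformer` flag threads through both sublayers: `post_norm` wraps the attention output projection in a LayerNorm, and `post_activation_norm` adds one inside the feedforward, per Shleifer et al. A minimal sketch of turning it on, assuming `CausalTransformer` is importable from the main module (the dim/depth values here are illustrative, not repo defaults):

```python
from dalle2_pytorch.dalle2_pytorch import CausalTransformer  # assumed import path

transformer = CausalTransformer(
    dim = 512,         # model width (hypothetical)
    depth = 6,         # number of attention/feedforward layers (hypothetical)
    normformer = True  # post_norm in Attention + post_activation_norm in FeedForward
)
```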
@@ -635,12 +649,14 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
self.l2norm_output = l2norm_output
def forward_with_cond_scale(
self,
@@ -719,7 +735,8 @@ class DiffusionPriorNetwork(nn.Module):
pred_image_embed = tokens[..., -1, :]
- return pred_image_embed
+ output_fn = l2norm if self.l2norm_output else identity
+ return output_fn(pred_image_embed)
class DiffusionPrior(BaseGaussianDiffusion):
def __init__(
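For reference, the `l2norm` optionally applied to the predicted embedding in the hunk above is the usual unit-normalization along the feature dimension; a minimal equivalent sketch, not necessarily the repo's exact helper:

```python
import torch.nn.functional as F

def l2norm(t):
    # rescale the last dimension to unit length, so the predicted image
    # embedding lies on the same hypersphere as CLIP's normalized embeddings
    return F.normalize(t, dim = -1)
```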
@@ -1050,6 +1067,42 @@ class GridAttention(nn.Module):
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class LinearAttention(nn.Module):
def __init__(
self,
dim,
dim_head = 32,
heads = 8
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
q = q * self.scale
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
out = self.nonlin(out)
return self.to_out(out)
class Unet(nn.Module):
def __init__(
self,
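The `LinearAttention` layer added in the hunk above follows the efficient-attention factorization: keys are softmaxed over the sequence, keys and values are aggregated into a small context matrix, and queries read it out, so cost grows linearly with the number of pixels rather than quadratically. A standalone shape sketch of the two einsums (all dimensions are made up):

```python
import torch

bh, n, c = 8, 64 * 64, 32                     # batch*heads, pixels, dim_head
q = torch.randn(bh, n, c).softmax(dim = -1) * c ** -0.5
k = torch.randn(bh, n, c).softmax(dim = -2)   # normalized over the sequence
v = torch.randn(bh, n, c)

context = torch.einsum('b n d, b n e -> b d e', k, v)    # (bh, c, c): size independent of n
out = torch.einsum('b n d, b d e -> b n e', q, context)  # (bh, n, c): linear in n
```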
@@ -1064,10 +1117,9 @@ class Unet(nn.Module):
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
- attn_heads = 8,
+ attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
sparse_attn = False,
- sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
@@ -1161,7 +1213,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
- Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+ Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -1178,7 +1230,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
- Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
+ Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))


@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
index = unet_number - 1
unet = self.decoder.unets[index]
- if exists(self.max_grad_norm):
-     nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
+ if exists(self.max_grad_norm):
+     scaler.unscale_(optimizer)
+     nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
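The reordering above follows PyTorch's documented AMP recipe: gradients must be unscaled before clipping so `max_grad_norm` is compared against true gradient magnitudes, and `scaler.step` then skips the optimizer step if inf/nan gradients are found. The generic pattern, with `model`, `optimizer`, and `loader` as placeholder names:

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for batch in loader:                                   # placeholder names throughout
    with autocast():
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                         # gradients back to true scale
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # clip now sees real norms
    scaler.step(optimizer)                             # skipped automatically on overflow
    scaler.update()
    optimizer.zero_grad()
```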


@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size
@property
def last_dec_layer(self):
return self.decoder[-3][-1].weight
def encode(self, x):
return self.encoder(x)
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
# calculate adaptive weight
- last_dec_layer = self.decoders[-1].weight
+ last_dec_layer = self.enc_dec.last_dec_layer
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
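These two gradient norms feed the adaptive weight from VQGAN (Esser et al. 2021, "Taming Transformers"), which rescales the adversarial loss so it cannot overpower the perceptual loss at the last decoder layer; the new `last_dec_layer` property lets each encoder/decoder variant expose that layer. A sketch of the computation the norms are used for (the exact clamping in the repo may differ):

```python
# lambda = ||grad L_perceptual|| / (||grad L_gan|| + delta)
delta = 1e-8
adaptive_weight = norm_grad_wrt_perceptual_loss / (norm_grad_wrt_gen_loss + delta)
total_loss = perceptual_loss + adaptive_weight.clamp(max = 1e4) * gen_loss  # sketch
```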


@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.0.86',
+ version = '0.0.93',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -26,12 +26,14 @@ setup(
'clip-anytorch',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'pillow',
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'webdataset',
'x-clip>=0.5.1',
'youtokentome'
],

train_diffusion_prior.py (new file, 243 lines added)

@@ -0,0 +1,243 @@
import os
import math
import argparse
import time

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from tqdm import tqdm
import wandb

os.environ["WANDB_SILENT"] = "true"
def eval_model(model, device, image_reader, text_reader, start, end, batch_size, loss_type, phase="Validation"):
    model.eval()
    with torch.no_grad():
        total_loss = 0.
        total_samples = 0.

        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                                        text_reader(batch_size=batch_size, start=start, end=end)):
            emb_images_tensor = torch.tensor(emb_images[0]).to(device)
            emb_text_tensor = torch.tensor(emb_text[0]).to(device)

            batches = emb_images_tensor.shape[0]

            loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
            total_loss += loss.item() * batches
            total_samples += batches

        avg_loss = (total_loss / total_samples)
        wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict):
    # save the checkpoint under a timestamped filename
    print("====================================== Saving checkpoint ======================================")
    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')
def train(image_embed_dim,
          image_embed_url,
          text_embed_url,
          batch_size,
          train_percent,
          val_percent,
          test_percent,
          num_epochs,
          dp_loss_type,
          clip,
          dp_condition_on_text_encodings,
          dp_timesteps,
          dp_l2norm_output,
          dp_cond_drop_prob,
          dpn_depth,
          dpn_dim_head,
          dpn_heads,
          save_interval,
          save_path,
          device,
          learning_rate=0.001,
          max_grad_norm=0.5,
          weight_decay=0.01,
          amp=False):

    # DiffusionPriorNetwork
    prior_network = DiffusionPriorNetwork(
        dim = image_embed_dim,
        depth = dpn_depth,
        dim_head = dpn_dim_head,
        heads = dpn_heads,
        l2norm_output = dp_l2norm_output).to(device)

    # DiffusionPrior with text embeddings and image embeddings pre-computed
    diffusion_prior = DiffusionPrior(
        net = prior_network,
        clip = clip,
        image_embed_dim = image_embed_dim,
        timesteps = dp_timesteps,
        cond_drop_prob = dp_cond_drop_prob,
        loss_type = dp_loss_type,
        condition_on_text_encodings = dp_condition_on_text_encodings).to(device)

    # Get image and text embeddings from the servers
    print("==============Downloading embeddings - image and text====================")
    image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
    text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
    num_data_points = text_reader.count

    # Create save_path if it doesn't exist
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    ### Training code ###
    scaler = GradScaler(enabled=amp)
    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
    epochs = num_epochs

    step = 0
    t = time.time()

    train_set_size = int(train_percent * num_data_points)
    val_set_size = int(val_percent * num_data_points)

    for _ in range(epochs):
        diffusion_prior.train()

        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                                        text_reader(batch_size=batch_size, start=0, end=train_set_size)):
            emb_images_tensor = torch.tensor(emb_images[0]).to(device)
            emb_text_tensor = torch.tensor(emb_text[0]).to(device)

            with autocast(enabled=amp):
                loss = diffusion_prior(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
                scaler.scale(loss).backward()

            # Samples per second
            step += 1
            samples_per_sec = batch_size * step / (time.time() - t)

            # Save checkpoint every save_interval minutes
            if int(time.time() - t) >= 60 * save_interval:
                t = time.time()
                save_model(
                    save_path,
                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))

            # Log to wandb
            wandb.log({"Training loss": loss.item(),
                       "Steps": step,
                       "Samples per second": samples_per_sec})

            # unscale before clipping so max_grad_norm applies to true gradients
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        ### Evaluate model (validation run) ###
        start = train_set_size
        end = start + val_set_size
        eval_model(diffusion_prior, device, image_reader, text_reader, start, end, batch_size, dp_loss_type, phase="Validation")

    ### Test run ###
    test_set_size = int(test_percent * train_set_size)
    start = train_set_size + val_set_size
    end = num_data_points
    eval_model(diffusion_prior, device, image_reader, text_reader, start, end, batch_size, dp_loss_type, phase="Test")
def main():
    parser = argparse.ArgumentParser()

    # Logging
    parser.add_argument("--wandb-entity", type=str, default="laion")
    parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
    parser.add_argument("--wandb-name", type=str, default="laion-dprior")
    parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
    parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
    # URLs for embeddings
    parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
    parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
    # Hyperparameters
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=0.5)
    parser.add_argument("--batch-size", type=int, default=10**4)
    parser.add_argument("--num-epochs", type=int, default=5)
    # Image embed dimension
    parser.add_argument("--image-embed-dim", type=int, default=768)
    # Train-test split
    parser.add_argument("--train-percent", type=float, default=0.7)
    parser.add_argument("--val-percent", type=float, default=0.2)
    parser.add_argument("--test-percent", type=float, default=0.1)
    # LAION training (pre-computed embeddings)
    # DiffusionPriorNetwork (dpn) parameters
    parser.add_argument("--dpn-depth", type=int, default=6)
    parser.add_argument("--dpn-dim-head", type=int, default=64)
    parser.add_argument("--dpn-heads", type=int, default=8)
    # DiffusionPrior (dp) parameters
    # caveat: argparse's type=bool treats any non-empty string as True,
    # so "--amp False" still parses as True; pass an empty string for False
    parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
    parser.add_argument("--dp-timesteps", type=int, default=100)
    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
    parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
    parser.add_argument("--dp-loss-type", type=str, default="l2")
    parser.add_argument("--clip", type=str, default=None)
    parser.add_argument("--amp", type=bool, default=False)
    # Model checkpointing interval (minutes)
    parser.add_argument("--save-interval", type=int, default=30)
    parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
    args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
config={
"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
train(args.image_embed_dim,
args.image_embed_url,
args.text_embed_url,
args.batch_size,
args.train_percent,
args.val_percent,
args.test_percent,
args.num_epochs,
args.dp_loss_type,
args.clip,
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_l2norm_output,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,
args.dpn_heads,
args.save_interval,
args.save_path,
device,
args.learning_rate,
args.max_grad_norm,
args.weight_decay,
args.amp)
if __name__ == "__main__":
main()
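A hypothetical way to drive the new script directly from Python, mirroring `main()` and its argparse defaults (this assumes CUDA and a live wandb session, since `train()` logs through `wandb.log`; it is a sketch, not a documented entry point):

```python
import torch
import wandb
from train_diffusion_prior import train

wandb.init(project = "diffusion-prior")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train(768,                                     # image_embed_dim
      "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
      "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
      10 ** 4,                                 # batch_size
      0.7, 0.2, 0.1,                           # train / val / test percentages
      5,                                       # num_epochs
      "l2",                                    # dp_loss_type
      None,                                    # clip (embeddings are pre-computed)
      False, 100, False, 0.2,                  # dp condition-on-text / timesteps / l2norm-output / cond-drop-prob
      6, 64, 8,                                # dpn depth / dim_head / heads
      30, "./diffusion_prior_checkpoints",     # save_interval (minutes), save_path
      device)
```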