Compare commits

..

10 Commits

Author SHA1 Message Date
Phil Wang
58d9b422f3 0.0.94 2022-05-04 07:42:33 -07:00
Ray Bell
44b319cb57 add missing import (#56) 2022-05-04 07:42:20 -07:00
Phil Wang
c30f380689 final reminder 2022-05-03 08:18:53 -07:00
Phil Wang
e4e884bb8b keep all doors open 2022-05-03 08:17:02 -07:00
Phil Wang
803ad9c17d product management again 2022-05-03 08:15:25 -07:00
Phil Wang
a88dd6a9c0 todo 2022-05-03 08:09:02 -07:00
Kumar R
72c16b496e Update train_diffusion_prior.py (#53) 2022-05-02 22:44:57 -07:00
z
81d83dd7f2 defaults align with paper (#52)
Co-authored-by: nousr <>
2022-05-02 13:52:11 -07:00
Phil Wang
fa66f7e1e9 todo 2022-05-02 12:57:15 -07:00
Phil Wang
aa8d135245 allow laion to experiment with normformer in diffusion prior 2022-05-02 11:35:00 -07:00
4 changed files with 18 additions and 6 deletions

View File

@@ -821,7 +821,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [x] bring in tools to train vqgan-vae
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
@@ -832,6 +832,10 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
## Citations

View File

@@ -1,6 +1,7 @@
import click
import torch
import torchvision.transforms as T
from functools import reduce
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.93',
version = '0.0.94',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -7,6 +7,9 @@ from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time
from tqdm import tqdm
@@ -54,6 +57,7 @@ def train(image_embed_dim,
dp_condition_on_text_encodings,
dp_timesteps,
dp_l2norm_output,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
@@ -72,6 +76,7 @@ def train(image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
normformer = dp_normformer,
l2norm_output = dp_l2norm_output).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -134,7 +139,7 @@ def train(image_embed_dim,
"Samples per second": samples_per_sec})
scaler.unscale_(optimizer)
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
@@ -163,8 +168,8 @@ def main():
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=0.001)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
@@ -183,7 +188,8 @@ def main():
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
@@ -227,6 +233,7 @@ def main():
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_l2norm_output,
args.dp_normformer,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,