Compare commits

...

9 Commits

Author SHA1 Message Date
Phil Wang
11d4e11f10 allow for training unconditional ddpm or cascading ddpms 2022-05-15 16:54:56 -07:00
Phil Wang
99778e12de trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices 2022-05-15 15:25:45 -07:00
Phil Wang
0f0011caf0 todo 2022-05-15 14:28:35 -07:00
Phil Wang
7b7a62044a use eval vs training mode to determine whether to call backprop on trainer forward 2022-05-15 14:20:59 -07:00
Phil Wang
156fe5ed9f final cleanup for the day 2022-05-15 12:38:41 -07:00
Phil Wang
5ec34bebe1 cleanup readme 2022-05-15 12:29:26 -07:00
Phil Wang
8eaacf1ac1 remove indirection 2022-05-15 12:05:45 -07:00
Phil Wang
e66c7b0249 incorrect naming 2022-05-15 11:23:52 -07:00
Phil Wang
f7cd4a0992 product management 2022-05-15 11:21:12 -07:00
5 changed files with 226 additions and 212 deletions

104
README.md
View File

@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
## Training wrapper
### Decoder Training
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
## Bonus
### Unconditional Training
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder
# unet for the cascading ddpm
unet1 = Unet(
dim = 128,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unets
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
timesteps = 1000,
unconditional = True
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder(images, unet_number = i)
loss.backward()
# do the above for many many many many steps
# then it will learn to generate images
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
```
## Dataloaders
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
@@ -895,14 +946,14 @@ dataset = ImageEmbeddingDataset(
)
```
## Scripts
### Scripts (wip)
### Using the `train_diffusion_prior.py` script
#### `train_diffusion_prior.py`
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
### Usage
#### Usage
```bash
$ python train_diffusion_prior.py
@@ -910,58 +961,49 @@ $ python train_diffusion_prior.py
The most significant parameters for the script are as follows:
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
--learning-rate, default=1.1e-4
- `learning-rate`, default = `1.1e-4`
--weight-decay, default=6.02e-2
- `weight-decay`, default = `6.02e-2`
--max-grad-norm, default=0.5
- `max-grad-norm`, default = `0.5`
--batch-size, default=10 ** 4
- `batch-size`, default = `10 ** 4`
--num-epochs, default=5
- `num-epochs`, default = `5`
--clip, default=None # Signals the prior to use pre-computed embeddings
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
### Sample wandb run log
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
### Loading and saving the Diffusion Prior model
#### Loading and Saving the DiffusionPrior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
@@ -1021,6 +1063,8 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
## Citations

View File

@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
self,
dim,
*,
image_embed_dim,
image_embed_dim = None,
text_embed_dim = None,
cond_dim = None,
num_image_tokens = 4,
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -1988,8 +1988,7 @@ class Decoder(BaseGaussianDiffusion):
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
img = self.p_sample_loop(
unet,
@@ -2037,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed):
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)
@@ -2063,9 +2062,7 @@ class Decoder(BaseGaussianDiffusion):
vae.eval()
with torch.no_grad():
image = vae.encode(image)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)

View File

@@ -1,7 +1,7 @@
import time
import copy
from math import ceil
from functools import partial
from functools import partial, wraps
from collections.abc import Iterable
import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
import numpy as np
# helper functions
def exists(val):
@@ -45,6 +47,29 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# decorators
def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
# gradient accumulation functions
def split_iterable(it, split_size):
@@ -80,13 +105,13 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
num_chunks = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
@@ -254,10 +279,12 @@ class DiffusionPriorTrainer(nn.Module):
self.step += 1
@torch.inference_mode()
@cast_torch_tensor
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
@cast_torch_tensor
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -265,6 +292,7 @@ class DiffusionPriorTrainer(nn.Module):
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
@cast_torch_tensor
def forward(
self,
*args,
@@ -279,7 +307,9 @@ class DiffusionPriorTrainer(nn.Module):
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scaler.scale(loss).backward()
if self.training:
self.scaler.scale(loss).backward()
return total_loss
@@ -375,6 +405,7 @@ class DecoderTrainer(nn.Module):
self.step += 1
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
@@ -391,6 +422,7 @@ class DecoderTrainer(nn.Module):
return output
@cast_torch_tensor
def forward(
self,
*args,
@@ -406,6 +438,8 @@ class DecoderTrainer(nn.Module):
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward()
if self.training:
self.scale(loss, unet_number = unet_number).backward()
return total_loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.31',
version = '0.2.35',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',

View File

@@ -111,37 +111,110 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
def train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
dropout=0.05,
amp=False):
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--batch-size", default=10**4)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.7)
@click.option("--val-percent", default=0.2)
@click.option("--test-percent", default=0.1)
@click.option("--dpn-depth", default=6)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=8)
@click.option("--dp-condition-on-text-encodings", default=False)
@click.option("--dp-timesteps", default=100)
@click.option("--dp-normformer", default=False)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default=None)
@click.option("--amp", default=False)
@click.option("--save-interval", default=30)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = args.pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
prior_network = DiffusionPriorNetwork(
@@ -269,140 +342,5 @@ def train(image_embed_dim,
end = num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--batch-size", default=10**4)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.7)
@click.option("--val-percent", default=0.2)
@click.option("--test-percent", default=0.1)
@click.option("--dpn-depth", default=6)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=8)
@click.option("--dp-condition-on-text-encodings", default=False)
@click.option("--dp-timesteps", default=100)
@click.option("--dp-normformer", default=False)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default=None)
@click.option("--amp", default=False)
@click.option("--save-interval", default=30)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
def main(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = args.pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate,
max_grad_norm,
weight_decay,
dropout,
amp)
if __name__ == "__main__":
main()
train()