mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11d4e11f10 | ||
|
|
99778e12de | ||
|
|
0f0011caf0 | ||
|
|
7b7a62044a | ||
|
|
156fe5ed9f | ||
|
|
5ec34bebe1 | ||
|
|
8eaacf1ac1 | ||
|
|
e66c7b0249 | ||
|
|
f7cd4a0992 | ||
|
|
68e7d2f241 | ||
|
|
74f222596a | ||
|
|
aa6772dcff | ||
|
|
71d0c4edae | ||
|
|
f7eee09d8b | ||
|
|
89de5af63e |
112
README.md
112
README.md
@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## Training wrapper (wip)
|
||||
## Training wrapper
|
||||
|
||||
### Decoder Training
|
||||
|
||||
@@ -814,8 +814,8 @@ clip = CLIP(
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||
images = torch.randn(32, 3, 256, 256).cuda()
|
||||
|
||||
# prior networks (with transformer)
|
||||
|
||||
@@ -842,7 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
ema_update_every = 10,
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images)
|
||||
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
|
||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||
```
|
||||
|
||||
## Bonus
|
||||
|
||||
### Unconditional Training
|
||||
|
||||
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
|
||||
|
||||
ex.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder
|
||||
|
||||
# unet for the cascading ddpm
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unets
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
|
||||
timesteps = 1000,
|
||||
unconditional = True
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder
|
||||
|
||||
for i in (1, 2):
|
||||
loss = decoder(images, unet_number = i)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many many steps
|
||||
# then it will learn to generate images
|
||||
|
||||
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
|
||||
```
|
||||
|
||||
## Dataloaders
|
||||
|
||||
### Decoder Dataloaders
|
||||
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
@@ -895,14 +946,14 @@ dataset = ImageEmbeddingDataset(
|
||||
)
|
||||
```
|
||||
|
||||
## Scripts
|
||||
### Scripts (wip)
|
||||
|
||||
### Using the `train_diffusion_prior.py` script
|
||||
#### `train_diffusion_prior.py`
|
||||
|
||||
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
||||
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
||||
|
||||
### Usage
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
$ python train_diffusion_prior.py
|
||||
@@ -910,58 +961,49 @@ $ python train_diffusion_prior.py
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
|
||||
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
|
||||
|
||||
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
|
||||
|
||||
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
|
||||
--learning-rate, default=1.1e-4
|
||||
- `learning-rate`, default = `1.1e-4`
|
||||
|
||||
--weight-decay, default=6.02e-2
|
||||
- `weight-decay`, default = `6.02e-2`
|
||||
|
||||
--max-grad-norm, default=0.5
|
||||
- `max-grad-norm`, default = `0.5`
|
||||
|
||||
--batch-size, default=10 ** 4
|
||||
- `batch-size`, default = `10 ** 4`
|
||||
|
||||
--num-epochs, default=5
|
||||
- `num-epochs`, default = `5`
|
||||
|
||||
--clip, default=None # Signals the prior to use pre-computed embeddings
|
||||
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||
|
||||
### Sample wandb run log
|
||||
|
||||
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
|
||||
|
||||
### Loading and saving the Diffusion Prior model
|
||||
#### Loading and Saving the DiffusionPrior model
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```python
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```
|
||||
|
||||
##### Loading
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
|
||||
dprior_path : path to saved model(.pth)
|
||||
|
||||
device : the cuda device you're running on
|
||||
|
||||
##### Saving
|
||||
|
||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||
|
||||
save_path : path to save at
|
||||
|
||||
model : object of Diffusion_Prior
|
||||
|
||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||
|
||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
|
||||
scaler : a GradScaler object.
|
||||
|
||||
e.g: scaler = GradScaler(enabled=amp)
|
||||
|
||||
config : config object created in train_diffusion_prior.py - see file for example.
|
||||
|
||||
image_embed_dim - the dimension of the image_embedding
|
||||
|
||||
e.g: 768
|
||||
|
||||
## CLI (wip)
|
||||
@@ -1007,6 +1049,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
@@ -1014,13 +1057,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
image_embed_dim = None,
|
||||
text_embed_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
@@ -1988,8 +1988,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
@@ -2037,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed):
|
||||
if not exists(image_embed) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
@@ -2063,9 +2062,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
|
||||
|
||||
49
dalle2_pytorch/trackers.py
Normal file
49
dalle2_pytorch/trackers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
print(log)
|
||||
|
||||
# basic wandb class
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
import wandb
|
||||
except ImportError as e:
|
||||
print('`pip install wandb` to use the wandb experiment tracker')
|
||||
raise e
|
||||
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
self.wandb = wandb
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
self.wandb.log(log, **kwargs)
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
import copy
|
||||
from math import ceil
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
import numpy as np
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -45,6 +47,29 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# decorators
|
||||
|
||||
def cast_torch_tensor(fn):
|
||||
@wraps(fn)
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
||||
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
||||
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
out = fn(model, *args, **kwargs)
|
||||
return out
|
||||
return inner
|
||||
|
||||
# gradient accumulation functions
|
||||
|
||||
def split_iterable(it, split_size):
|
||||
@@ -66,18 +91,27 @@ def split(t, split_size = None):
|
||||
|
||||
return TypeError
|
||||
|
||||
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
||||
batch_size = len(x)
|
||||
def find_first(cond, arr):
|
||||
for el in arr:
|
||||
if cond(el):
|
||||
return el
|
||||
return None
|
||||
|
||||
def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
all_args = (*args, *kwargs.values())
|
||||
len_all_args = len(all_args)
|
||||
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
|
||||
assert exists(first_tensor)
|
||||
|
||||
batch_size = len(first_tensor)
|
||||
split_size = default(split_size, batch_size)
|
||||
chunk_size = ceil(batch_size / split_size)
|
||||
num_chunks = ceil(batch_size / split_size)
|
||||
|
||||
dict_len = len(kwargs)
|
||||
dict_keys = kwargs.keys()
|
||||
all_args = (x, *args, *kwargs.values())
|
||||
len_all_args = len(all_args)
|
||||
split_kwargs_index = len_all_args - dict_len
|
||||
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
|
||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||
|
||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||
@@ -117,7 +151,7 @@ def load_diffusion_model(dprior_path, device):
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
@@ -228,6 +262,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
@@ -240,11 +276,15 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior.update()
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.inference_mode()
|
||||
@cast_torch_tensor
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@cast_torch_tensor
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
|
||||
@@ -252,22 +292,24 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*args,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
if self.training:
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
return total_loss
|
||||
|
||||
@@ -328,6 +370,8 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
@@ -358,7 +402,10 @@ class DecoderTrainer(nn.Module):
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
trainable_unets = self.decoder.unets
|
||||
@@ -375,22 +422,24 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
unet_number,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
self.scale(loss, unet_number = unet_number).backward()
|
||||
|
||||
if self.training:
|
||||
self.scale(loss, unet_number = unet_number).backward()
|
||||
|
||||
return total_loss
|
||||
|
||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.2.29',
|
||||
version = '0.2.35',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -30,6 +30,7 @@ setup(
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'numpy',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
|
||||
@@ -1,24 +1,42 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import click
|
||||
import math
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
import time
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# constants
|
||||
|
||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||
|
||||
tracker = WandbTracker()
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
# functions
|
||||
|
||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||
model.eval()
|
||||
@@ -40,7 +58,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
diffusion_prior.eval()
|
||||
@@ -87,85 +105,177 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
|
||||
def train(image_embed_dim,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
batch_size,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
num_epochs,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
save_interval,
|
||||
save_path,
|
||||
device,
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
learning_rate=0.001,
|
||||
max_grad_norm=0.5,
|
||||
weight_decay=0.01,
|
||||
dropout=0.05,
|
||||
amp=False):
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@click.option("--wandb-dataset", default="LAION-5B")
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--batch-size", default=10**4)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.7)
|
||||
@click.option("--val-percent", default=0.2)
|
||||
@click.option("--test-percent", default=0.1)
|
||||
@click.option("--dpn-depth", default=6)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=8)
|
||||
@click.option("--dp-condition-on-text-encodings", default=False)
|
||||
@click.option("--dp-timesteps", default=100)
|
||||
@click.option("--dp-normformer", default=False)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default=None)
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=30)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
wandb_dataset,
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
"architecture": wandb_arch,
|
||||
"dataset": wandb_dataset,
|
||||
"weight_decay": weight_decay,
|
||||
"max_gradient_clipping_norm": max_grad_norm,
|
||||
"batch_size": batch_size,
|
||||
"epochs": num_epochs,
|
||||
"diffusion_prior_network": {
|
||||
"depth": dpn_depth,
|
||||
"dim_head": dpn_dim_head,
|
||||
"heads": dpn_heads,
|
||||
"normformer": dp_normformer
|
||||
},
|
||||
"diffusion_prior": {
|
||||
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||
"timesteps": dp_timesteps,
|
||||
"cond_drop_prob": dp_cond_drop_prob,
|
||||
"loss_type": dp_loss_type,
|
||||
"clip": clip
|
||||
}
|
||||
}
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = args.pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
config = config
|
||||
)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer).to(device)
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||
)
|
||||
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
|
||||
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior = diffusion_prior,
|
||||
lr = learning_rate,
|
||||
wd = weight_decay,
|
||||
max_grad_norm = max_grad_norm,
|
||||
amp = amp,
|
||||
).to(device)
|
||||
|
||||
# load optimizer and scaler
|
||||
|
||||
if RESUME:
|
||||
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
|
||||
print_ribbon("Downloading embeddings - image and text")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
|
||||
### Training code ###
|
||||
scaler = GradScaler(enabled=amp)
|
||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
epochs = num_epochs
|
||||
|
||||
step = 0
|
||||
t = time.time()
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
train_set_size = int(train_percent*num_data_points)
|
||||
val_set_size = int(val_percent*num_data_points)
|
||||
@@ -176,32 +286,31 @@ def train(image_embed_dim,
|
||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||
|
||||
diffusion_prior.train()
|
||||
trainer.train()
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
|
||||
with autocast(enabled=amp):
|
||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
||||
scaler.scale(loss).backward()
|
||||
loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||
|
||||
# Samples per second
|
||||
step+=1
|
||||
samples_per_sec = batch_size*step/(time.time()-t)
|
||||
|
||||
samples_per_sec = batch_size * step / timer.elapsed()
|
||||
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(time.time()-t) >= 60*save_interval):
|
||||
t = time.time()
|
||||
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||
timer.reset()
|
||||
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
diffusion_prior,
|
||||
optimizer,
|
||||
scaler,
|
||||
trainer.optimizer,
|
||||
trainer.scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
wandb.log({"Training loss": loss.item(),
|
||||
tracker.log({"Training loss": loss.item(),
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
@@ -225,130 +334,13 @@ def train(image_embed_dim,
|
||||
dp_loss_type,
|
||||
phase="Validation")
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
test_set_size = int(test_percent*train_set_size)
|
||||
start=train_set_size+val_set_size
|
||||
end=num_data_points
|
||||
start = train_set_size+val_set_size
|
||||
end = num_data_points
|
||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Logging
|
||||
parser.add_argument("--wandb-entity", type=str, default="laion")
|
||||
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
|
||||
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
|
||||
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
|
||||
# URLs for embeddings
|
||||
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
# Hyperparameters
|
||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
||||
parser.add_argument("--dropout", type=float, default=5e-2)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
||||
parser.add_argument("--num-epochs", type=int, default=5)
|
||||
# Image embed dimension
|
||||
parser.add_argument("--image-embed-dim", type=int, default=768)
|
||||
# Train-test split
|
||||
parser.add_argument("--train-percent", type=float, default=0.7)
|
||||
parser.add_argument("--val-percent", type=float, default=0.2)
|
||||
parser.add_argument("--test-percent", type=float, default=0.1)
|
||||
# LAION training(pre-computed embeddings)
|
||||
# DiffusionPriorNetwork(dpn) parameters
|
||||
parser.add_argument("--dpn-depth", type=int, default=6)
|
||||
parser.add_argument("--dpn-dim-head", type=int, default=64)
|
||||
parser.add_argument("--dpn-heads", type=int, default=8)
|
||||
# DiffusionPrior(dp) parameters
|
||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||
parser.add_argument("--dp-normformer", type=bool, default=False)
|
||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
|
||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||
parser.add_argument("--clip", type=str, default=None)
|
||||
parser.add_argument("--amp", type=bool, default=False)
|
||||
# Model checkpointing interval(minutes)
|
||||
parser.add_argument("--save-interval", type=int, default=30)
|
||||
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
|
||||
# Saved model path
|
||||
parser.add_argument("--pretrained-model-path", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = ({"learning_rate": args.learning_rate,
|
||||
"architecture": args.wandb_arch,
|
||||
"dataset": args.wandb_dataset,
|
||||
"weight_decay":args.weight_decay,
|
||||
"max_gradient_clipping_norm":args.max_grad_norm,
|
||||
"batch_size":args.batch_size,
|
||||
"epochs": args.num_epochs,
|
||||
"diffusion_prior_network":{"depth":args.dpn_depth,
|
||||
"dim_head":args.dpn_dim_head,
|
||||
"heads":args.dpn_heads,
|
||||
"normformer":args.dp_normformer},
|
||||
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
|
||||
"timesteps": args.dp_timesteps,
|
||||
"cond_drop_prob":args.dp_cond_drop_prob,
|
||||
"loss_type":args.dp_loss_type,
|
||||
"clip":args.clip}
|
||||
})
|
||||
|
||||
RESUME = False
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
DPRIOR_PATH = args.pretrained_model_path
|
||||
if(DPRIOR_PATH is not None):
|
||||
RESUME = True
|
||||
else:
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.wandb_project,
|
||||
config=config)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
train(args.image_embed_dim,
|
||||
args.image_embed_url,
|
||||
args.text_embed_url,
|
||||
args.batch_size,
|
||||
args.train_percent,
|
||||
args.val_percent,
|
||||
args.test_percent,
|
||||
args.num_epochs,
|
||||
args.dp_loss_type,
|
||||
args.clip,
|
||||
args.dp_condition_on_text_encodings,
|
||||
args.dp_timesteps,
|
||||
args.dp_normformer,
|
||||
args.dp_cond_drop_prob,
|
||||
args.dpn_depth,
|
||||
args.dpn_dim_head,
|
||||
args.dpn_heads,
|
||||
args.save_interval,
|
||||
args.save_path,
|
||||
device,
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
args.wandb_entity,
|
||||
args.wandb_project,
|
||||
args.learning_rate,
|
||||
args.max_grad_norm,
|
||||
args.weight_decay,
|
||||
args.dropout,
|
||||
args.amp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
train()
|
||||
|
||||
Reference in New Issue
Block a user