Compare commits

..

9 Commits

7 changed files with 171 additions and 31 deletions

View File

@@ -732,8 +732,8 @@ clip = CLIP(

 # mock data

-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()

 # decoder (with unet)
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
 )

 for unet_number in (1, 2):
-    loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
-    loss.backward()
+    loss = decoder_trainer(
+        images,
+        text = text,
+        unet_number = unet_number, # which unet to train on
+        max_batch_size = 4         # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
+    )

     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
 )

 loss = diffusion_prior_trainer(text, images)
-loss.backward()

 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior

 # after much of the above three lines in a loop
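Note on the two README hunks above: with the new `max_batch_size` argument the trainer forward now splits the incoming batch, backpropagates each chunk internally, and returns the summed loss as a plain float, which is why the explicit `loss.backward()` lines are removed. A minimal, self-contained sketch of that gradient-accumulation pattern (the model, loss, and helper name here are illustrative, not the repository's API):

import torch
from torch import nn

def accumulated_backward(model, images, max_batch_size):
    # hypothetical helper: forward / backward in chunks of at most `max_batch_size`
    batch_size = images.shape[0]
    total_loss = 0.
    for chunk in images.split(max_batch_size, dim = 0):
        loss = model(chunk).pow(2).mean()            # stand-in loss
        loss = loss * (chunk.shape[0] / batch_size)  # weight by chunk fraction so summed gradients match a full-batch pass
        loss.backward()                              # gradients accumulate in .grad across chunks
        total_loss += loss.item()
    return total_loss                                # a float, so the caller no longer calls .backward()

model = nn.Conv2d(3, 3, 1)
print(accumulated_backward(model, torch.randn(32, 3, 8, 8), max_batch_size = 4))  # 32 / 4 == 8 chunks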
@@ -1004,6 +1007,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [x] cross embed layers for downsampling, as an option
+- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1011,12 +1015,12 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
-- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
 - [ ] decoder needs one day worth of refactor for tech debt
+- [ ] allow for unet to be able to condition non-cross attention style as well

 ## Citations

View File

@@ -1163,6 +1163,7 @@ class CrossAttention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
+        norm_context = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -1172,7 +1173,7 @@ class CrossAttention(nn.Module):
         context_dim = default(context_dim, dim)

         self.norm = LayerNorm(dim)
-        self.norm_context = LayerNorm(context_dim)
+        self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
         self.dropout = nn.Dropout(dropout)

         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -1378,6 +1379,9 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()

+        self.norm_cond = nn.LayerNorm(cond_dim)
+        self.norm_mid_cond = nn.LayerNorm(cond_dim)
+
         # text encoding conditioning (optional)

         self.text_to_cond = None
@@ -1593,6 +1597,11 @@ class Unet(nn.Module):
         mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)

+        # normalize conditioning tokens
+
+        c = self.norm_cond(c)
+        mid_c = self.norm_mid_cond(mid_c)
+
         # go through the layers of the unet, down and up

         hiddens = []

View File

@@ -10,13 +10,14 @@ def get_optimizer(
     lr = 2e-5,
     wd = 1e-2,
     betas = (0.9, 0.999),
+    eps = 1e-8,
     filter_by_requires_grad = False
 ):
     if filter_by_requires_grad:
         params = list(filter(lambda t: t.requires_grad, params))

     if wd == 0:
-        return Adam(params, lr = lr, betas = betas)
+        return Adam(params, lr = lr, betas = betas, eps = eps)

     params = set(params)
     wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
         {'params': list(no_wd_params), 'weight_decay': 0},
     ]

-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
+    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
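For reference, a short usage sketch of the updated `get_optimizer` above (the model is illustrative; the import path matches the `from dalle2_pytorch.optimizer import get_optimizer` line that appears further down this diff):

import torch
from torch import nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Linear(512, 512)

# wd == 0 returns plain Adam; otherwise AdamW with decayable / non-decayable param groups
optimizer = get_optimizer(
    model.parameters(),
    lr = 2e-5,
    wd = 1e-2,
    eps = 1e-8   # now forwarded to Adam / AdamW instead of relying on the implicit default
)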

View File

@@ -0,0 +1,49 @@
+import os
+import torch
+from torch import nn
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+# base class
+
+class BaseTracker(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def init(self, config, **kwargs):
+        raise NotImplementedError
+
+    def log(self, log, **kwargs):
+        raise NotImplementedError
+
+# basic stdout class
+
+class ConsoleTracker(BaseTracker):
+    def init(self, **config):
+        print(config)
+
+    def log(self, log, **kwargs):
+        print(log)
+
+# basic wandb class
+
+class WandbTracker(BaseTracker):
+    def __init__(self):
+        super().__init__()
+        try:
+            import wandb
+        except ImportError as e:
+            print('`pip install wandb` to use the wandb experiment tracker')
+            raise e
+
+        os.environ["WANDB_SILENT"] = "true"
+        self.wandb = wandb
+
+    def init(self, **config):
+        self.wandb.init(**config)
+
+    def log(self, log, **kwargs):
+        self.wandb.log(log, **kwargs)
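A brief usage sketch of the tracker interface defined in this new file, mirroring how the training script further down this diff switches from calling `wandb` directly to calling a tracker (the config keys are illustrative; `WandbTracker` simply forwards them to `wandb.init`):

from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker

tracker = ConsoleTracker()                 # or WandbTracker() if wandb is installed
tracker.init(project = 'prior-training')   # illustrative config; printed by ConsoleTracker, passed to wandb.init by WandbTracker
tracker.log({'loss': 0.123})               # identical call for either backend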

View File

@@ -1,6 +1,8 @@
 import time
 import copy
+from math import ceil
 from functools import partial
+from collections.abc import Iterable

 import torch
 from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
 def exists(val):
     return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,47 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    return TypeError
+
+def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
+    batch_size = len(x)
+    split_size = default(split_size, batch_size)
+    chunk_size = ceil(batch_size / split_size)
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+    all_args = (x, *args, *kwargs.values())
+    len_all_args = len(all_args)
+    split_kwargs_index = len_all_args - dict_len
+
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        chunk_size_frac = chunk_size / batch_size
+        yield chunk_size_frac, (chunked_args, chunked_kwargs)
+
 # print helpers

 def print_ribbon(s, symbol = '=', repeat = 40):
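To make the new helpers concrete, a small illustrative run of `split_args_and_kwargs` (assuming the definitions from the hunk above are in scope; the tensor shapes are arbitrary):

import torch

images = torch.randn(10, 3, 64, 64)
text = torch.randint(0, 49408, (10, 256))

for frac, (args, kwargs) in split_args_and_kwargs(images, split_size = 4, text = text):
    (chunk_images,) = args
    chunk_text = kwargs['text']
    print(frac, chunk_images.shape[0], chunk_text.shape[0])

# prints 0.4 4 4, 0.4 4 4, 0.2 2 2 - the fractions sum to 1, so chunk losses scaled by them reproduce a full-batch average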
@@ -90,7 +136,7 @@ class EMA(nn.Module):
     def __init__(
         self,
         model,
-        beta = 0.99,
+        beta = 0.9999,
         update_after_step = 1000,
         update_every = 10,
     ):
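On the `beta` bump above: assuming the `EMA` class follows the standard update `ema = beta * ema + (1 - beta) * online` (the full implementation is not shown in this hunk), the effective averaging window is roughly `1 / (1 - beta)`, so 0.99 tracks only the last ~100 updates while 0.9999 averages over ~10,000:

def ema_update(ema_value, online_value, beta = 0.9999):
    # standard exponential moving average step; beta = 0.9999 -> ~10,000-step horizon, beta = 0.99 -> ~100
    return ema_value * beta + online_value * (1 - beta)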
@@ -147,6 +193,7 @@ class DiffusionPriorTrainer(nn.Module):
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
+        eps = 1e-6,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -173,6 +220,7 @@ class DiffusionPriorTrainer(nn.Module):
             diffusion_prior.parameters(),
             lr = lr,
             wd = wd,
+            eps = eps,
             **kwargs
         )
@@ -180,6 +228,8 @@ class DiffusionPriorTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

+        self.register_buffer('step', torch.tensor([0.]))
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
@@ -192,6 +242,8 @@ class DiffusionPriorTrainer(nn.Module):
         if self.use_ema:
             self.ema_diffusion_prior.update()

+        self.step += 1
+
     @torch.inference_mode()
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@@ -206,13 +258,22 @@ class DiffusionPriorTrainer(nn.Module):
     def forward(
         self,
+        x,
         *args,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.diffusion_prior(*args, **kwargs)
-            return self.scaler.scale(loss / divisor)
+        total_loss = 0.
+
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
+                loss = loss * chunk_size_frac
+
+            total_loss += loss.item()
+            self.scaler.scale(loss).backward()
+
+        return total_loss

 # decoder trainer
@@ -223,6 +284,7 @@ class DecoderTrainer(nn.Module):
         use_ema = True,
         lr = 2e-5,
         wd = 1e-2,
+        eps = 1e-8,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -247,13 +309,14 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))

-        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
                 wd = unet_wd,
+                eps = unet_eps,
                 **kwargs
             )
@@ -269,6 +332,8 @@ class DecoderTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

+        self.register_buffer('step', torch.tensor([0.]))
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
@@ -299,6 +364,8 @@ class DecoderTrainer(nn.Module):
         ema_unet = self.ema_unets[index]
         ema_unet.update()

+        self.step += 1
+
     @torch.no_grad()
     def sample(self, *args, **kwargs):
         if self.use_ema:
@@ -321,9 +388,17 @@ class DecoderTrainer(nn.Module):
         x,
         *,
         unet_number,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.decoder(x, unet_number = unet_number, **kwargs)
-            return self.scale(loss / divisor, unet_number = unet_number)
+        total_loss = 0.
+
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss = loss * chunk_size_frac
+
+            total_loss += loss.item()
+            self.scale(loss, unet_number = unet_number).backward()
+
+        return total_loss

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.20',
+  version = '0.2.30',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -1,24 +1,26 @@
 import os
 import math
+import time
 import argparse
 import numpy as np

 import torch
 from torch import nn
-from embedding_reader import EmbeddingReader
+from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
 from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
 from dalle2_pytorch.optimizer import get_optimizer
-from torch.cuda.amp import autocast,GradScaler
-import time
+from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
+from embedding_reader import EmbeddingReader

 from tqdm import tqdm
-import wandb
-
-os.environ["WANDB_SILENT"] = "true"

 NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
 REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training

+tracker = WandbTracker()
+
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
@@ -40,7 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
             total_samples += batches
     avg_loss = (total_loss / total_samples)
-    wandb.log({f'{phase} {loss_type}': avg_loss})
+    tracker.log({f'{phase} {loss_type}': avg_loss})

 def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
     diffusion_prior.eval()
@@ -87,7 +89,7 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
         text_embed, predicted_unrelated_embeddings).cpu().numpy()
     predicted_img_similarity = cos(
         test_image_embeddings, predicted_image_embeddings).cpu().numpy()
-    wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
+    tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
         "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
         "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
         "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
@@ -201,7 +203,7 @@ def train(image_embed_dim,
                     image_embed_dim)
             # Log to wandb
-            wandb.log({"Training loss": loss.item(),
+            tracker.log({"Training loss": loss.item(),
                 "Steps": step,
                 "Samples per second": samples_per_sec})
             # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
@@ -306,7 +308,7 @@ def main():
     if(DPRIOR_PATH is not None):
         RESUME = True
     else:
-        wandb.init(
+        tracker.init(
             entity=args.wandb_entity,
             project=args.wandb_project,
             config=config)
@@ -351,4 +353,4 @@ def main():
          args.amp)

 if __name__ == "__main__":
     main()