Compare commits

...

4 Commits

4 changed files with 215 additions and 156 deletions

README.md

@@ -814,8 +814,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()

 # prior networks (with transformer)

@@ -842,7 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
     ema_update_every = 10,
 )

-loss = diffusion_prior_trainer(text, images)
+loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior

 # after much of the above three lines in a loop
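Worth noting for readers of this hunk: `max_batch_size` does not shrink the effective batch. The trainer splits the incoming batch into forward-pass-sized chunks and sums the losses, each weighted by its share of the batch, which is why the mock batch above can grow from 4 to 32. A minimal standalone sketch of that arithmetic (plain Python, not the library's code; the numbers mirror the example above):

from math import ceil

batch_size = 32        # full batch handed to the trainer
max_batch_size = 4     # cap on a single forward pass

num_chunks = ceil(batch_size / max_batch_size)            # -> 8 forward passes
chunk_fracs = [max_batch_size / batch_size] * num_chunks  # each chunk contributes 1/8

# each chunk's loss is scaled by its fraction, so the accumulated total
# matches a single full-batch loss
assert sum(chunk_fracs) == 1.0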

dalle2_pytorch/train.py

@@ -66,15 +66,24 @@ def split(t, split_size = None):
     return TypeError

-def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
-    batch_size = len(x)
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+    assert exists(first_tensor)
+
+    batch_size = len(first_tensor)
     split_size = default(split_size, batch_size)
     chunk_size = ceil(batch_size / split_size)

     dict_len = len(kwargs)
     dict_keys = kwargs.keys()
-
-    all_args = (x, *args, *kwargs.values())
-    len_all_args = len(all_args)
     split_kwargs_index = len_all_args - dict_len

     split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
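The practical effect of the new signature: callers no longer need a leading positional tensor, because the batch size is inferred from the first tensor found among all positional and keyword arguments. A small standalone sketch of that lookup (mirroring `find_first` above; the embedding shapes are hypothetical):

import torch

def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

# purely keyword-based calls now work, e.g. trainer(text_embed = ..., image_embed = ...)
kwargs = dict(text_embed = torch.randn(32, 768), image_embed = torch.randn(32, 768))

first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), tuple(kwargs.values()))
print(len(first_tensor))  # 32 -> used as the batch size to split against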
@@ -117,7 +126,7 @@ def load_diffusion_model(dprior_path, device):
     # Load state dict from saved model
     diffusion_prior.load_state_dict(loaded_obj['model'])

-    return diffusion_prior
+    return diffusion_prior, loaded_obj

 def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
     # Saving State Dict
@@ -258,14 +267,13 @@ class DiffusionPriorTrainer(nn.Module):
     def forward(
         self,
-        x,
         *args,
         max_batch_size = None,
         **kwargs
     ):
         total_loss = 0.

-        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
                 loss = loss * chunk_size_frac
@@ -385,15 +393,14 @@ class DecoderTrainer(nn.Module):
     def forward(
         self,
-        x,
-        *,
+        *args,
         unet_number,
         max_batch_size = None,
         **kwargs
     ):
         total_loss = 0.

-        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
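One subtlety in this signature change: dropping the bare `*` does not make `unet_number` positional, since any parameter declared after `*args` is still keyword-only. A tiny standalone illustration:

# parameters after *args remain keyword-only, so unet_number must still be named
def forward(*args, unet_number, max_batch_size = None):
    return len(args), unet_number

print(forward("images", unet_number = 1))   # (1, 1)
# forward("images", 1) would raise TypeError: unet_number is missing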

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.30',
+  version = '0.2.31',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

train_diffusion_prior.py

@@ -1,27 +1,43 @@
import os from pathlib import Path
import click
import math import math
import time import time
import argparse
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from embedding_reader import EmbeddingReader from embedding_reader import EmbeddingReader
from tqdm import tqdm from tqdm import tqdm
# constants
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
tracker = WandbTracker() tracker = WandbTracker()
# helpers functions
def exists(val):
val is not None
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# functions
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"): def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
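The new Timer class replaces the raw `t = time.time()` bookkeeping used further down for throughput reporting and checkpoint cadence. A quick standalone demonstration:

import time

class Timer:
    def __init__(self):
        self.reset()

    def reset(self):
        self.last_time = time.time()

    def elapsed(self):
        return time.time() - self.last_time

timer = Timer()
time.sleep(1)
print(int(timer.elapsed()))  # ~1 second since construction
timer.reset()                # e.g. after writing a checkpoint
print(int(timer.elapsed()))  # ~0, the window restarts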
@@ -126,48 +142,67 @@ def train(image_embed_dim,
           dropout=0.05,
           amp=False):

-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork(
-            dim = image_embed_dim,
-            depth = dpn_depth,
-            dim_head = dpn_dim_head,
-            heads = dpn_heads,
-            attn_dropout = dropout,
-            ff_dropout = dropout,
-            normformer = dp_normformer).to(device)
+    # diffusion prior network
+
+    prior_network = DiffusionPriorNetwork(
+        dim = image_embed_dim,
+        depth = dpn_depth,
+        dim_head = dpn_dim_head,
+        heads = dpn_heads,
+        attn_dropout = dropout,
+        ff_dropout = dropout,
+        normformer = dp_normformer
+    )

-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(
-            net = prior_network,
-            clip = clip,
-            image_embed_dim = image_embed_dim,
-            timesteps = dp_timesteps,
-            cond_drop_prob = dp_cond_drop_prob,
-            loss_type = dp_loss_type,
-            condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
+    # diffusion prior with text embeddings and image embeddings pre-computed
+
+    diffusion_prior = DiffusionPrior(
+        net = prior_network,
+        clip = clip,
+        image_embed_dim = image_embed_dim,
+        timesteps = dp_timesteps,
+        cond_drop_prob = dp_cond_drop_prob,
+        loss_type = dp_loss_type,
+        condition_on_text_encodings = dp_condition_on_text_encodings
+    )

     # Load pre-trained model from DPRIOR_PATH
     if RESUME:
-        diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
-        wandb.init( entity=wandb_entity, project=wandb_project, config=config)
+        diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
+        tracker.init(entity = wandb_entity, project = wandb_project, config = config)
+
+    # diffusion prior trainer
+
+    trainer = DiffusionPriorTrainer(
+        diffusion_prior = diffusion_prior,
+        lr = learning_rate,
+        wd = weight_decay,
+        max_grad_norm = max_grad_norm,
+        amp = amp,
+    ).to(device)
+
+    # load optimizer and scaler
+
+    if RESUME:
+        trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
+        trainer.scaler.load_state_dict(loaded_obj['scaler'])

     # Create save_path if it doesn't exist
-    if not os.path.exists(save_path):
-        os.makedirs(save_path)
+    Path(save_path).mkdir(exist_ok = True, parents = True)

     # Get image and text embeddings from the servers
     print_ribbon("Downloading embeddings - image and text")
     image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
     text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
     num_data_points = text_reader.count

     ### Training code ###
-    scaler = GradScaler(enabled=amp)
-    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-    epochs = num_epochs
-    step = 0
-    t = time.time()
+
+    timer = Timer()
+    epochs = num_epochs

     train_set_size = int(train_percent*num_data_points)
     val_set_size = int(val_percent*num_data_points)
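Because load_diffusion_model now also returns the raw checkpoint dict, the optimizer and scaler state can be restored into the new DiffusionPriorTrainer alongside the weights. A hedged sketch of that resume path (the helper name `resume_trainer` and its shape are hypothetical; the keys 'model', 'optimizer' and 'scaler' are the ones save_diffusion_model writes):

import torch

def resume_trainer(diffusion_prior, trainer, checkpoint_path):
    # checkpoint keys mirror what save_diffusion_model stores
    loaded_obj = torch.load(checkpoint_path, map_location = 'cpu')
    diffusion_prior.load_state_dict(loaded_obj['model'])
    trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
    trainer.scaler.load_state_dict(loaded_obj['scaler'])
    return trainer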
@@ -178,27 +213,26 @@ def train(image_embed_dim,
         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                 text_reader(batch_size=batch_size, start=0, end=train_set_size)):

-            diffusion_prior.train()
+            trainer.train()

             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)

-            with autocast(enabled=amp):
-                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-                scaler.scale(loss).backward()
+            loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)

             # Samples per second
-            step+=1
-            samples_per_sec = batch_size*step/(time.time()-t)
+            samples_per_sec = batch_size * step / timer.elapsed()

             # Save checkpoint every save_interval minutes
-            if(int(time.time()-t) >= 60*save_interval):
-                t = time.time()
+            if(int(timer.elapsed()) >= 60 * save_interval):
+                timer.reset()

                 save_diffusion_model(
                     save_path,
                     diffusion_prior,
-                    optimizer,
-                    scaler,
+                    trainer.optimizer,
+                    trainer.scaler,
                     config,
                     image_embed_dim)
@@ -227,91 +261,109 @@ def train(image_embed_dim,
                        dp_loss_type,
                        phase="Validation")

-            scaler.unscale_(optimizer)
-            nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
+            trainer.update()

     ### Test run ###
     test_set_size = int(test_percent*train_set_size)
-    start=train_set_size+val_set_size
-    end=num_data_points
+    start = train_set_size+val_set_size
+    end = num_data_points
     eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")

-def main():
-    parser = argparse.ArgumentParser()
-    # Logging
-    parser.add_argument("--wandb-entity", type=str, default="laion")
-    parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
-    parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
-    parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
-    # URLs for embeddings
-    parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-    parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-    # Hyperparameters
-    parser.add_argument("--learning-rate", type=float, default=1.1e-4)
-    parser.add_argument("--weight-decay", type=float, default=6.02e-2)
-    parser.add_argument("--dropout", type=float, default=5e-2)
-    parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--batch-size", type=int, default=10**4)
-    parser.add_argument("--num-epochs", type=int, default=5)
-    # Image embed dimension
-    parser.add_argument("--image-embed-dim", type=int, default=768)
-    # Train-test split
-    parser.add_argument("--train-percent", type=float, default=0.7)
-    parser.add_argument("--val-percent", type=float, default=0.2)
-    parser.add_argument("--test-percent", type=float, default=0.1)
-    # LAION training(pre-computed embeddings)
-    # DiffusionPriorNetwork(dpn) parameters
-    parser.add_argument("--dpn-depth", type=int, default=6)
-    parser.add_argument("--dpn-dim-head", type=int, default=64)
-    parser.add_argument("--dpn-heads", type=int, default=8)
-    # DiffusionPrior(dp) parameters
-    parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
-    parser.add_argument("--dp-timesteps", type=int, default=100)
-    parser.add_argument("--dp-normformer", type=bool, default=False)
-    parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
-    parser.add_argument("--dp-loss-type", type=str, default="l2")
-    parser.add_argument("--clip", type=str, default=None)
-    parser.add_argument("--amp", type=bool, default=False)
-    # Model checkpointing interval(minutes)
-    parser.add_argument("--save-interval", type=int, default=30)
-    parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
-    # Saved model path
-    parser.add_argument("--pretrained-model-path", type=str, default=None)
-
-    args = parser.parse_args()
-
-    config = ({"learning_rate": args.learning_rate,
-               "architecture": args.wandb_arch,
-               "dataset": args.wandb_dataset,
-               "weight_decay":args.weight_decay,
-               "max_gradient_clipping_norm":args.max_grad_norm,
-               "batch_size":args.batch_size,
-               "epochs": args.num_epochs,
-               "diffusion_prior_network":{"depth":args.dpn_depth,
-                                          "dim_head":args.dpn_dim_head,
-                                          "heads":args.dpn_heads,
-                                          "normformer":args.dp_normformer},
-               "diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
-                                  "timesteps": args.dp_timesteps,
-                                  "cond_drop_prob":args.dp_cond_drop_prob,
-                                  "loss_type":args.dp_loss_type,
-                                  "clip":args.clip}
-              })
-
-    RESUME = False
+@click.command()
+@click.option("--wandb-entity", default="laion")
+@click.option("--wandb-project", default="diffusion-prior")
+@click.option("--wandb-dataset", default="LAION-5B")
+@click.option("--wandb-arch", default="DiffusionPrior")
+@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
+@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
+@click.option("--learning-rate", default=1.1e-4)
+@click.option("--weight-decay", default=6.02e-2)
+@click.option("--dropout", default=5e-2)
+@click.option("--max-grad-norm", default=0.5)
+@click.option("--batch-size", default=10**4)
+@click.option("--num-epochs", default=5)
+@click.option("--image-embed-dim", default=768)
+@click.option("--train-percent", default=0.7)
+@click.option("--val-percent", default=0.2)
+@click.option("--test-percent", default=0.1)
+@click.option("--dpn-depth", default=6)
+@click.option("--dpn-dim-head", default=64)
+@click.option("--dpn-heads", default=8)
+@click.option("--dp-condition-on-text-encodings", default=False)
+@click.option("--dp-timesteps", default=100)
+@click.option("--dp-normformer", default=False)
+@click.option("--dp-cond-drop-prob", default=0.1)
+@click.option("--dp-loss-type", default="l2")
+@click.option("--clip", default=None)
+@click.option("--amp", default=False)
+@click.option("--save-interval", default=30)
+@click.option("--save-path", default="./diffusion_prior_checkpoints")
+@click.option("--pretrained-model-path", default=None)
+def main(
+    wandb_entity,
+    wandb_project,
+    wandb_dataset,
+    wandb_arch,
+    image_embed_url,
+    text_embed_url,
+    learning_rate,
+    weight_decay,
+    dropout,
+    max_grad_norm,
+    batch_size,
+    num_epochs,
+    image_embed_dim,
+    train_percent,
+    val_percent,
+    test_percent,
+    dpn_depth,
+    dpn_dim_head,
+    dpn_heads,
+    dp_condition_on_text_encodings,
+    dp_timesteps,
+    dp_normformer,
+    dp_cond_drop_prob,
+    dp_loss_type,
+    clip,
+    amp,
+    save_interval,
+    save_path,
+    pretrained_model_path
+):
+    config = {
+        "learning_rate": learning_rate,
+        "architecture": wandb_arch,
+        "dataset": wandb_dataset,
+        "weight_decay": weight_decay,
+        "max_gradient_clipping_norm": max_grad_norm,
+        "batch_size": batch_size,
+        "epochs": num_epochs,
+        "diffusion_prior_network": {
+            "depth": dpn_depth,
+            "dim_head": dpn_dim_head,
+            "heads": dpn_heads,
+            "normformer": dp_normformer
+        },
+        "diffusion_prior": {
+            "condition_on_text_encodings": dp_condition_on_text_encodings,
+            "timesteps": dp_timesteps,
+            "cond_drop_prob": dp_cond_drop_prob,
+            "loss_type": dp_loss_type,
+            "clip": clip
+        }
+    }
+
     # Check if DPRIOR_PATH exists(saved model path)
-    DPRIOR_PATH = args.pretrained_model_path
-    if(DPRIOR_PATH is not None):
-        RESUME = True
-    else:
+    DPRIOR_PATH = pretrained_model_path
+    RESUME = exists(DPRIOR_PATH)
+
+    if not RESUME:
         tracker.init(
-            entity=args.wandb_entity,
-            project=args.wandb_project,
-            config=config)
+            entity = wandb_entity,
+            project = wandb_project,
+            config = config
+        )
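The argparse-to-click migration above follows one mechanical rule: each `@click.option("--some-name", default=...)` becomes a `some_name` parameter of the decorated function, with the option's type inferred from its default. A minimal standalone example of the pattern (script and option names here are illustrative only):

import click

@click.command()
@click.option("--learning-rate", default = 1.1e-4)   # becomes learning_rate
@click.option("--num-epochs", default = 5)           # becomes num_epochs
def main(learning_rate, num_epochs):
    click.echo(f"lr = {learning_rate}, epochs = {num_epochs}")

if __name__ == "__main__":
    main()   # e.g. python example.py --learning-rate 2e-4 --num-epochs 10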
# Obtain the utilized device.

@@ -321,36 +373,36 @@ def main():
     torch.cuda.set_device(device)

     # Training loop
-    train(args.image_embed_dim,
-          args.image_embed_url,
-          args.text_embed_url,
-          args.batch_size,
-          args.train_percent,
-          args.val_percent,
-          args.test_percent,
-          args.num_epochs,
-          args.dp_loss_type,
-          args.clip,
-          args.dp_condition_on_text_encodings,
-          args.dp_timesteps,
-          args.dp_normformer,
-          args.dp_cond_drop_prob,
-          args.dpn_depth,
-          args.dpn_dim_head,
-          args.dpn_heads,
-          args.save_interval,
-          args.save_path,
+    train(image_embed_dim,
+          image_embed_url,
+          text_embed_url,
+          batch_size,
+          train_percent,
+          val_percent,
+          test_percent,
+          num_epochs,
+          dp_loss_type,
+          clip,
+          dp_condition_on_text_encodings,
+          dp_timesteps,
+          dp_normformer,
+          dp_cond_drop_prob,
+          dpn_depth,
+          dpn_dim_head,
+          dpn_heads,
+          save_interval,
+          save_path,
           device,
           RESUME,
           DPRIOR_PATH,
           config,
-          args.wandb_entity,
-          args.wandb_project,
-          args.learning_rate,
-          args.max_grad_norm,
-          args.weight_decay,
-          args.dropout,
-          args.amp)
+          wandb_entity,
+          wandb_project,
+          learning_rate,
+          max_grad_norm,
+          weight_decay,
+          dropout,
+          amp)

 if __name__ == "__main__":
     main()