diff --git a/setup.py b/setup.py
index e7f460f..9b27a98 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,8 @@ setup(
     'tqdm',
     'vector-quantize-pytorch',
     'x-clip>=0.5.1',
-    'youtokentome'
+    'youtokentome',
+    'embedding-reader'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
new file mode 100644
index 0000000..6b50bec
--- /dev/null
+++ b/train_diffusion_prior.py
@@ -0,0 +1,199 @@
+import argparse
+import os
+import time
+
+import torch
+import wandb
+
+from embedding_reader import EmbeddingReader
+
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.optimizer import get_optimizer
+
+os.environ["WANDB_SILENT"] = "true"
+
+def eval_model(model, device, image_reader, text_reader, start, end, batch_size, loss_type, phase="Validation"):
+    # Switch to eval mode once, before iterating over the held-out rows
+    model.eval()
+    with torch.no_grad():
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
+                                        text_reader(batch_size=batch_size, start=start, end=end)):
+            emb_images_tensor = torch.tensor(emb_images[0]).to(device)
+            emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+            loss = model(text_embed=emb_text_tensor, image_embed=emb_images_tensor)
+
+            # Log to wandb
+            wandb.log({phase + " " + loss_type: loss.item()})
+
+def save_model(save_path, state_dict):
+    # Saving state dict, timestamped so successive checkpoints don't overwrite each other
+    print("====================================== Saving checkpoint ======================================")
+    torch.save(state_dict, os.path.join(save_path, str(time.time()) + '_saved_model.pth'))
+
+def train(image_embed_dim,
+          image_embed_url,
+          text_embed_url,
+          batch_size,
+          train_percent,
+          val_percent,
+          test_percent,
+          num_epochs,
+          dp_loss_type,
+          clip,
+          dp_condition_on_text_encodings,
+          dp_timesteps,
+          dp_cond_drop_prob,
+          dpn_depth,
+          dpn_dim_head,
+          dpn_heads,
+          save_interval,
+          save_path,
+          device,
+          learning_rate=0.01):
+
+    # DiffusionPriorNetwork
+    prior_network = DiffusionPriorNetwork(
+        dim=image_embed_dim,
+        depth=dpn_depth,
+        dim_head=dpn_dim_head,
+        heads=dpn_heads).to(device)
+
+    # DiffusionPrior with text embeddings and image embeddings pre-computed
+    diffusion_prior = DiffusionPrior(
+        net=prior_network,
+        clip=clip,
+        image_embed_dim=image_embed_dim,
+        timesteps=dp_timesteps,
+        cond_drop_prob=dp_cond_drop_prob,
+        loss_type=dp_loss_type,
+        condition_on_text_encodings=dp_condition_on_text_encodings).to(device)
+
+    # Get image and text embeddings from the servers
+    print("==============Downloading embeddings - image and text====================")
+    image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
+    text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
+    num_data_points = text_reader.count
+
+    # Create save_path if it doesn't exist
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    ### Training code ###
+    # Pass the CLI learning rate through; previously it was accepted but never used
+    optimizer = get_optimizer(diffusion_prior.parameters(), lr=learning_rate)
+    epochs = num_epochs
+
+    step = 0
+    t = time.time()
+    train_set_size = int(train_percent * num_data_points)
+    val_set_size = int(val_percent * num_data_points)
+    for _ in range(epochs):
+        diffusion_prior.train()
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
+                                        text_reader(batch_size=batch_size, start=0, end=train_set_size)):
+            emb_images_tensor = torch.tensor(emb_images[0]).to(device)
+            emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+            optimizer.zero_grad()
+            loss = diffusion_prior(text_embed=emb_text_tensor, image_embed=emb_images_tensor)
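+            # NOTE: with pre-computed embeddings, DiffusionPrior.forward returns
+            # the diffusion training loss (l1/l2/huber, per --dp-loss-type) between
+            # the network's prediction and the target image embedding, so it can
+            # be backpropagated directly.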
+            loss.backward()
+            optimizer.step()
+
+            # Samples per second
+            step += 1
+            samples_per_sec = batch_size * step / (time.time() - t)
+            # Save checkpoint every save_interval minutes
+            if int(time.time() - t) >= 60 * save_interval:
+                t = time.time()
+                save_model(save_path, diffusion_prior.state_dict())
+            # Log to wandb
+            wandb.log({"Training loss": loss.item(),
+                       "Steps": step,
+                       "Samples per second": samples_per_sec})
+
+    ### Evaluate model (validation run) ###
+    start = train_set_size
+    end = start + val_set_size
+    eval_model(diffusion_prior, device, image_reader, text_reader, start, end, batch_size, dp_loss_type, phase="Validation")
+
+    ### Test run ###
+    # The test split is simply whatever remains after the train and val rows
+    start = train_set_size + val_set_size
+    end = num_data_points
+    eval_model(diffusion_prior, device, image_reader, text_reader, start, end, batch_size, dp_loss_type, phase="Test")
+
+def main():
+    parser = argparse.ArgumentParser()
+    # Logging
+    parser.add_argument("--wandb-entity", type=str, default="laion")
+    parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
+    parser.add_argument("--wandb-name", type=str, default="laion-dprior")
+    parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
+    parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
+    # URLs for embeddings
+    parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
+    parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
+    # Hyperparameters
+    parser.add_argument("--learning-rate", type=float, default=0.01)
+    parser.add_argument("--batch-size", type=int, default=10**4)
+    parser.add_argument("--num-epochs", type=int, default=5)
+    # Image embed dimension
+    parser.add_argument("--image-embed-dim", type=int, default=768)
+    # Train-val-test split
+    parser.add_argument("--train-percent", type=float, default=0.7)
+    parser.add_argument("--val-percent", type=float, default=0.2)
+    parser.add_argument("--test-percent", type=float, default=0.1)
+    # LAION training (pre-computed embeddings)
+    # DiffusionPriorNetwork (dpn) parameters
+    parser.add_argument("--dpn-depth", type=int, default=6)
+    parser.add_argument("--dpn-dim-head", type=int, default=64)
+    parser.add_argument("--dpn-heads", type=int, default=8)
+    # DiffusionPrior (dp) parameters
+    # NOTE: argparse's type=bool treats any non-empty string (including "False")
+    # as True, so a store_true flag is used instead
+    parser.add_argument("--dp-condition-on-text-encodings", action="store_true", default=False)
+    parser.add_argument("--dp-timesteps", type=int, default=100)
+    parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
+    parser.add_argument("--dp-loss-type", type=str, default="l2")
+    parser.add_argument("--clip", type=str, default=None)
+    # Model checkpointing interval (minutes)
+    parser.add_argument("--save-interval", type=int, default=30)
+    parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
+
+    args = parser.parse_args()
+    print("Setting up wandb logging... Please wait...")
+    wandb.init(
+        entity=args.wandb_entity,
+        project=args.wandb_project,
+        name=args.wandb_name,
+        config={
+            "learning_rate": args.learning_rate,
+            "architecture": args.wandb_arch,
+            "dataset": args.wandb_dataset,
+            "epochs": args.num_epochs,
+        })
+    print("wandb logging setup done!")
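+    # Illustrative split arithmetic (assuming the default 0.7/0.2/0.1 percents):
+    # with num_data_points = 1_000_000 embedding rows, train() reads rows
+    # [0, 700_000) for training, [700_000, 900_000) for validation, and the
+    # remaining rows [900_000, 1_000_000) for the test pass.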
+    # Select the compute device
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        torch.cuda.set_device(device)
+    else:
+        device = torch.device("cpu")
+
+    # Training loop
+    train(args.image_embed_dim,
+          args.image_embed_url,
+          args.text_embed_url,
+          args.batch_size,
+          args.train_percent,
+          args.val_percent,
+          args.test_percent,
+          args.num_epochs,
+          args.dp_loss_type,
+          args.clip,
+          args.dp_condition_on_text_encodings,
+          args.dp_timesteps,
+          args.dp_cond_drop_prob,
+          args.dpn_depth,
+          args.dpn_dim_head,
+          args.dpn_heads,
+          args.save_interval,
+          args.save_path,
+          device,
+          args.learning_rate)
+
+if __name__ == "__main__":
+    main()
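+
+# Example invocation (illustrative; the values shown are just the argparse
+# defaults above spelled out explicitly):
+#   python train_diffusion_prior.py \
+#       --image-embed-dim 768 \
+#       --batch-size 10000 \
+#       --num-epochs 5 \
+#       --dp-loss-type l2 \
+#       --save-path ./diffusion_prior_checkpoints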