From 8647cb5e761e15756024ea5a1be3a820e8efe395 Mon Sep 17 00:00:00 2001 From: Kumar R Date: Mon, 9 May 2022 21:23:29 +0530 Subject: [PATCH] Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77) * Val_loss changes - no rebased with lucidrains' master. * Val Loss changes - now rebased with lucidrains' master * train_diffusion_prior.py updates * dalle2_pytorch.py updates * __init__.py changes * Update train_diffusion_prior.py * Update dalle2_pytorch.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update dalle2_pytorch.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md --- README.md | 34 +++++- dalle2_pytorch/__init__.py | 2 +- dalle2_pytorch/dalle2_pytorch.py | 39 +++++++ train_diffusion_prior.py | 190 +++++++++++++++++-------------- 4 files changed, 179 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 54983b7..b9c65c3 100644 --- a/README.md +++ b/README.md @@ -927,7 +927,39 @@ The most significant parameters for the script are as follows: ### Sample wandb run log -Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace= +Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j + +### Loading and saving the Diffusion Prior model + +Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory. 
+ +## from dalle2_pytorch import load_diffusion_model, save_diffusion_model + + load_diffusion_model(dprior_path, device) + + dprior_path : path to saved model(.pth) + + device : the cuda device you're running on + + save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim) + + save_path : path to save at + + model : object of Diffusion_Prior + + optimizer : optimizer object - see train_diffusion_prior.py for how to create one. + + e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate) + + scaler : a GradScaler object. + + e.g: scaler = GradScaler(enabled=amp) + + config : config object created in train_diffusion_prior.py - see file for example. + + image_embed_dim - the dimension of the image_embedding + + e.g: 768 ## CLI (wip) diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py index 60987bd..96eebca 100644 --- a/dalle2_pytorch/__init__.py +++ b/dalle2_pytorch/__init__.py @@ -1,4 +1,4 @@ -from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder +from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 981e8eb..0ab1f76 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -4,6 +4,8 @@ from inspect import isfunction from functools import partial from contextlib import contextmanager from collections import namedtuple +from pathlib import Path +import time import torch import torch.nn.functional as F @@ -32,6 +34,42 @@ from rotary_embedding_torch import RotaryEmbedding from x_clip import CLIP from coca_pytorch import CoCa +# Diffusion Prior model loading and saving functions + +def 
load_diffusion_model(dprior_path, device ): + + dprior_path = Path(dprior_path) + assert dprior_path.exists(), 'Dprior model file does not exist' + loaded_obj = torch.load(str(dprior_path), map_location='cpu') + + # Get hyperparameters of loaded model + dpn_config = loaded_obj['hparams']['diffusion_prior_network'] + dp_config = loaded_obj['hparams']['diffusion_prior'] + image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim'] + + # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters + + # DiffusionPriorNetwork + prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device) + + # DiffusionPrior with text embeddings and image embeddings pre-computed + diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device) + + # Load state dict from saved model + diffusion_prior.load_state_dict(loaded_obj['model']) + + return diffusion_prior + +def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim): + # Saving State Dict + print("====================================== Saving checkpoint ======================================") + state_dict = dict(model=model.state_dict(), + optimizer=optimizer.state_dict(), + scaler=scaler.state_dict(), + hparams = config, + image_embed_dim = {"image_embed_dim":image_embed_dim}) + torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth') + # helper functions def exists(val): @@ -1914,3 +1952,4 @@ class DALLE2(nn.Module): return images[0] return images + diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 657db7e..4f96ff7 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -6,7 +6,7 @@ import numpy as np import torch from torch import nn from embedding_reader import EmbeddingReader -from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork +from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, 
save_diffusion_model from dalle2_pytorch.optimizer import get_optimizer from torch.cuda.amp import autocast,GradScaler @@ -41,73 +41,55 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t avg_loss = (total_loss / total_samples) wandb.log({f'{phase} {loss_type}': avg_loss}) -def save_model(save_path, state_dict): - # Saving State Dict - print("====================================== Saving checkpoint ======================================") - torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth') +def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device): - -def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device): cos = nn.CosineSimilarity(dim=1, eps=1e-6) - tstart = train_set_size+val_set_size - tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS - - for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)): - # make a copy of the text embeddings for shuffling - text_embed = torch.tensor(embt[0]).to(device) - text_embed_shuffled = text_embed.clone() + tstart = train_set_size + tend = train_set_size+NUM_TEST_EMBEDDINGS + for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), + image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)): + # make a copy of the text embeddings for shuffling + text_embed = torch.tensor(embt[0]).to(device) + text_embed_shuffled = text_embed.clone() # roll the text embeddings to simulate "unrelated" captions - rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1) - text_embed_shuffled = text_embed_shuffled[rolled_idx] - text_embed_shuffled = text_embed_shuffled / \ - text_embed_shuffled.norm(dim=1, keepdim=True) - test_text_shuffled_cond = dict(text_embed=text_embed_shuffled) - + rolled_idx = 
torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1) + text_embed_shuffled = text_embed_shuffled[rolled_idx] + text_embed_shuffled = text_embed_shuffled / \ + text_embed_shuffled.norm(dim=1, keepdim=True) + test_text_shuffled_cond = dict(text_embed=text_embed_shuffled) # prepare the text embedding - text_embed = text_embed / text_embed.norm(dim=1, keepdim=True) - test_text_cond = dict(text_embed=text_embed) - + text_embed = text_embed / text_embed.norm(dim=1, keepdim=True) + test_text_cond = dict(text_embed=text_embed) # prepare image embeddings - test_image_embeddings = torch.tensor(embi[0]).to(device) - test_image_embeddings = test_image_embeddings / \ - test_image_embeddings.norm(dim=1, keepdim=True) - + test_image_embeddings = torch.tensor(embi[0]).to(device) + test_image_embeddings = test_image_embeddings / \ + test_image_embeddings.norm(dim=1, keepdim=True) # predict on the unshuffled text embeddings - predicted_image_embeddings = diffusion_prior.p_sample_loop( - (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond) - predicted_image_embeddings = predicted_image_embeddings / \ - predicted_image_embeddings.norm(dim=1, keepdim=True) - + predicted_image_embeddings = diffusion_prior.p_sample_loop( + (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond) + predicted_image_embeddings = predicted_image_embeddings / \ + predicted_image_embeddings.norm(dim=1, keepdim=True) # predict on the shuffled embeddings - predicted_unrelated_embeddings = diffusion_prior.p_sample_loop( - (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond) - predicted_unrelated_embeddings = predicted_unrelated_embeddings / \ - predicted_unrelated_embeddings.norm(dim=1, keepdim=True) - + predicted_unrelated_embeddings = diffusion_prior.p_sample_loop( + (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond) + predicted_unrelated_embeddings = predicted_unrelated_embeddings / \ + predicted_unrelated_embeddings.norm(dim=1, keepdim=True) # calculate similarities - original_similarity = 
cos( - text_embed, test_image_embeddings).cpu().numpy() - predicted_similarity = cos( - text_embed, predicted_image_embeddings).cpu().numpy() - unrelated_similarity = cos( - text_embed, predicted_unrelated_embeddings).cpu().numpy() - predicted_img_similarity = cos( - test_image_embeddings, predicted_image_embeddings).cpu().numpy() - - wandb.log( - {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)}) - wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean( - predicted_similarity)}) - wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean( - unrelated_similarity)}) - wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean( - predicted_img_similarity)}) - - return np.mean(predicted_similarity - original_similarity) - - + original_similarity = cos( + text_embed, test_image_embeddings).cpu().numpy() + predicted_similarity = cos( + text_embed, predicted_image_embeddings).cpu().numpy() + unrelated_similarity = cos( + text_embed, predicted_unrelated_embeddings).cpu().numpy() + predicted_img_similarity = cos( + test_image_embeddings, predicted_image_embeddings).cpu().numpy() + wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity), + "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity), + "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity), + "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity), + "Cosine similarity difference":np.mean(predicted_similarity - original_similarity)}) def train(image_embed_dim, image_embed_url, @@ -129,6 +111,11 @@ def train(image_embed_dim, save_interval, save_path, device, + RESUME, + DPRIOR_PATH, + config, + wandb_entity, + wandb_project, learning_rate=0.001, max_grad_norm=0.5, weight_decay=0.01, @@ -152,16 +139,21 @@ def train(image_embed_dim, loss_type = dp_loss_type, condition_on_text_encodings = 
dp_condition_on_text_encodings).to(device) + # Load pre-trained model from DPRIOR_PATH + if RESUME: + diffusion_prior=load_diffusion_model(DPRIOR_PATH,device) + wandb.init( entity=wandb_entity, project=wandb_project, config=config) + + # Create save_path if it doesn't exist + if not os.path.exists(save_path): + os.makedirs(save_path) + # Get image and text embeddings from the servers print("==============Downloading embeddings - image and text====================") image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy") text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy") num_data_points = text_reader.count - # Create save_path if it doesn't exist - if not os.path.exists(save_path): - os.makedirs(save_path) - ### Training code ### scaler = GradScaler(enabled=amp) optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate) @@ -172,6 +164,7 @@ def train(image_embed_dim, train_set_size = int(train_percent*num_data_points) val_set_size = int(val_percent*num_data_points) + eval_start = train_set_size for _ in range(epochs): diffusion_prior.train() @@ -192,9 +185,13 @@ def train(image_embed_dim, if(int(time.time()-t) >= 60*save_interval): t = time.time() - save_model( + save_diffusion_model( save_path, - dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict())) + diffusion_prior, + optimizer, + scaler, + config, + image_embed_dim) # Log to wandb wandb.log({"Training loss": loss.item(), @@ -204,14 +201,22 @@ def train(image_embed_dim, # Use NUM_TEST_EMBEDDINGS samples from the test set each time # Get embeddings from the most recently saved model if(step % REPORT_METRICS_EVERY) == 0: - diff_cosine_sim = report_cosine_sims(diffusion_prior, + report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, - val_set_size, NUM_TEST_EMBEDDINGS, device) - wandb.log({"Cosine similarity difference": diff_cosine_sim}) + ### 
Evaluate model(validation run) ### + eval_model(diffusion_prior, + device, + image_reader, + text_reader, + eval_start, + eval_start+NUM_TEST_EMBEDDINGS, + NUM_TEST_EMBEDDINGS, + dp_loss_type, + phase="Validation") scaler.unscale_(optimizer) nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) @@ -220,11 +225,6 @@ def train(image_embed_dim, scaler.update() optimizer.zero_grad() - ### Evaluate model(validation run) ### - start = train_set_size - end=start+val_set_size - eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation") - ### Test run ### test_set_size = int(test_percent*train_set_size) start=train_set_size+val_set_size @@ -236,7 +236,6 @@ def main(): # Logging parser.add_argument("--wandb-entity", type=str, default="laion") parser.add_argument("--wandb-project", type=str, default="diffusion-prior") - parser.add_argument("--wandb-name", type=str, default="laion-dprior") parser.add_argument("--wandb-dataset", type=str, default="LAION-5B") parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior") # URLs for embeddings @@ -271,22 +270,40 @@ def main(): # Model checkpointing interval(minutes) parser.add_argument("--save-interval", type=int, default=30) parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints") + # Saved model path + parser.add_argument("--pretrained-model-path", type=str, default=None) args = parser.parse_args() - print("Setting up wandb logging... 
Please wait...") + config = ({"learning_rate": args.learning_rate, + "architecture": args.wandb_arch, + "dataset": args.wandb_dataset, + "weight_decay":args.weight_decay, + "max_gradient_clipping_norm":args.max_grad_norm, + "batch_size":args.batch_size, + "epochs": args.num_epochs, + "diffusion_prior_network":{"depth":args.dpn_depth, + "dim_head":args.dpn_dim_head, + "heads":args.dpn_heads, + "normformer":args.dp_normformer}, + "diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings, + "timesteps": args.dp_timesteps, + "cond_drop_prob":args.dp_cond_drop_prob, + "loss_type":args.dp_loss_type, + "clip":args.clip} + }) - wandb.init( - entity=args.wandb_entity, - project=args.wandb_project, - config={ - "learning_rate": args.learning_rate, - "architecture": args.wandb_arch, - "dataset": args.wandb_dataset, - "epochs": args.num_epochs, - }) + RESUME = False + # Check if DPRIOR_PATH exists(saved model path) + DPRIOR_PATH = args.pretrained_model_path + if(DPRIOR_PATH is not None): + RESUME = True + else: + wandb.init( + entity=args.wandb_entity, + project=args.wandb_project, + config=config) - print("wandb logging setup done!") # Obtain the utilized device. has_cuda = torch.cuda.is_available() @@ -315,6 +332,11 @@ def main(): args.save_interval, args.save_path, device, + RESUME, + DPRIOR_PATH, + config, + args.wandb_entity, + args.wandb_project, args.learning_rate, args.max_grad_norm, args.weight_decay,