diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index f73e0ac..107c53b 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -6,22 +6,24 @@
 import numpy as np
 import torch
 from torch import nn
-from torch.cuda.amp import autocast, GradScaler
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
-from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
-from dalle2_pytorch.optimizer import get_optimizer
+from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
 from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
 from embedding_reader import EmbeddingReader
 from tqdm import tqdm
 
+# constants
+
 NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
 REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
 
 tracker = WandbTracker()
 
+# functions
+
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
     with torch.no_grad():
@@ -126,47 +128,64 @@ def train(image_embed_dim,
           dropout=0.05,
           amp=False):
 
-    # DiffusionPriorNetwork
+    # diffusion prior network
+
     prior_network = DiffusionPriorNetwork(
-            dim = image_embed_dim,
-            depth = dpn_depth,
-            dim_head = dpn_dim_head,
-            heads = dpn_heads,
-            attn_dropout = dropout,
-            ff_dropout = dropout,
-            normformer = dp_normformer).to(device)
+        dim = image_embed_dim,
+        depth = dpn_depth,
+        dim_head = dpn_dim_head,
+        heads = dpn_heads,
+        attn_dropout = dropout,
+        ff_dropout = dropout,
+        normformer = dp_normformer
+    )
 
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
+    # diffusion prior with text embeddings and image embeddings pre-computed
+
     diffusion_prior = DiffusionPrior(
-            net = prior_network,
-            clip = clip,
-            image_embed_dim = image_embed_dim,
-            timesteps = dp_timesteps,
-            cond_drop_prob = dp_cond_drop_prob,
-            loss_type = dp_loss_type,
-            condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
+        net = prior_network,
+        clip = clip,
+        image_embed_dim = image_embed_dim,
+        timesteps = dp_timesteps,
+        cond_drop_prob = dp_cond_drop_prob,
+        loss_type = dp_loss_type,
+        condition_on_text_encodings = dp_condition_on_text_encodings
+    )
 
     # Load pre-trained model from DPRIOR_PATH
+
     if RESUME:
-        diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
-        wandb.init( entity=wandb_entity, project=wandb_project, config=config)
+        diffusion_prior = load_diffusion_model(DPRIOR_PATH, device)
+
+    # TODO, optimizer and scaler needs to be loaded as well
+
+    tracker.init(entity = wandb_entity, project = wandb_project, config = config)
+
+    # diffusion prior trainer
+
+    trainer = DiffusionPriorTrainer(
+        diffusion_prior = diffusion_prior,
+        lr = learning_rate,
+        wd = weight_decay,
+        max_grad_norm = max_grad_norm,
+        amp = amp,
+    ).to(device)
 
     # Create save_path if it doesn't exist
+
     if not os.path.exists(save_path):
        os.makedirs(save_path)
 
     # Get image and text embeddings from the servers
+
     print_ribbon("Downloading embeddings - image and text")
     image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
     text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
     num_data_points = text_reader.count
 
     ### Training code ###
-    scaler = GradScaler(enabled=amp)
-    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-    epochs = num_epochs
-    step = 0
+    epochs = num_epochs
 
     t = time.time()
 
     train_set_size = int(train_percent*num_data_points)
@@ -178,18 +197,17 @@ def train(image_embed_dim,
         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                                        text_reader(batch_size=batch_size, start=0, end=train_set_size)):
-            diffusion_prior.train()
+            trainer.train()
 
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
 
-            with autocast(enabled=amp):
-                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-                scaler.scale(loss).backward()
+            loss = trainer(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
 
             # Samples per second
-            step+=1
+
             samples_per_sec = batch_size*step/(time.time()-t)
+
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
@@ -197,8 +215,8 @@ def train(image_embed_dim,
                 save_diffusion_model(
                     save_path,
                     diffusion_prior,
-                    optimizer,
-                    scaler,
+                    trainer.optimizer,
+                    trainer.scaler,
                     config,
                     image_embed_dim)
@@ -227,17 +245,12 @@ def train(image_embed_dim,
                            dp_loss_type,
                            phase="Validation")
 
-            scaler.unscale_(optimizer)
-            nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
+            trainer.update()
 
     ### Test run ###
     test_set_size = int(test_percent*train_set_size)
-    start=train_set_size+val_set_size
-    end=num_data_points
+    start = train_set_size+val_set_size
+    end = num_data_points
     eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
 
 def main():
@@ -303,8 +316,11 @@ def main():
     })
 
     RESUME = False
+
     # Check if DPRIOR_PATH exists(saved model path)
+
     DPRIOR_PATH = args.pretrained_model_path
+
     if(DPRIOR_PATH is not None):
         RESUME = True
     else:
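
For reference, the patch above collapses the manual `autocast`/`GradScaler`/optimizer bookkeeping into the new `DiffusionPriorTrainer` wrapper: the trainer's forward pass computes the loss and runs the (scaled) backward, and `trainer.update()` replaces the removed unscale/clip/step/update/zero_grad block. Below is a minimal sketch of that flow in isolation, using random tensors in place of `EmbeddingReader` batches; the dimensions, hyperparameter values, and the `clip = None` shortcut are illustrative assumptions, not values from this patch.

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import DiffusionPriorTrainer

# illustrative sizes only -- the real script reads these from CLI args
image_embed_dim = 768
batch_size = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

prior_network = DiffusionPriorNetwork(
    dim = image_embed_dim,
    depth = 6,
    dim_head = 64,
    heads = 8,
    attn_dropout = 0.05,
    ff_dropout = 0.05
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = None,                           # embeddings are pre-computed, so no CLIP is attached here
    image_embed_dim = image_embed_dim,
    timesteps = 100,
    cond_drop_prob = 0.1,
    condition_on_text_encodings = False    # condition on text embeddings alone
)

# the trainer owns the optimizer, grad scaler, and grad clipping
# that the old script managed by hand
trainer = DiffusionPriorTrainer(
    diffusion_prior = diffusion_prior,
    lr = 1e-4,
    wd = 1e-2,
    max_grad_norm = 0.5,
    amp = False
).to(device)

# stand-ins for one (text, image) embedding batch from EmbeddingReader
text_embeds = torch.randn(batch_size, image_embed_dim, device = device)
image_embeds = torch.randn(batch_size, image_embed_dim, device = device)

trainer.train()
loss = trainer(text_embed = text_embeds, image_embed = image_embeds)  # forward + backward inside
trainer.update()  # grad clip + optimizer/scaler step + zero grad, all in one call
print(loss)
```

One consequence of this design worth noting: because the backward pass happens inside the trainer's forward and the optimizer state lives on the trainer (`trainer.optimizer`, `trainer.scaler`, as used by `save_diffusion_model` above), the patch's own TODO still applies on resume; loading the checkpointed prior weights alone does not restore the optimizer or scaler state.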