From aa6772dcff5275aa24d73264997f35d1c4b1f022 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 15 May 2022 10:48:10 -0700 Subject: [PATCH] make sure optimizer and scaler is reloaded on resume for training diffusion prior script, move argparse to click --- dalle2_pytorch/train.py | 2 +- train_diffusion_prior.py | 239 +++++++++++++++++++++++---------------- 2 files changed, 140 insertions(+), 101 deletions(-) diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index 768e72d..c8bb946 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -117,7 +117,7 @@ def load_diffusion_model(dprior_path, device): # Load state dict from saved model diffusion_prior.load_state_dict(loaded_obj['model']) - return diffusion_prior + return diffusion_prior, loaded_obj def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim): # Saving State Dict diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 107c53b..4ab2d09 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -1,7 +1,7 @@ -import os +from pathlib import Path +import click import math import time -import argparse import numpy as np import torch @@ -22,6 +22,17 @@ REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting du tracker = WandbTracker() +# helpers functions + +class Timer: + def __init__(self): + self.reset() + + def reset(self): + self.last_time = time.time() + + def elapsed(self): + return time.time() - self.last_time # functions def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"): @@ -155,7 +166,7 @@ def train(image_embed_dim, # Load pre-trained model from DPRIOR_PATH if RESUME: - diffusion_prior = load_diffusion_model(DPRIOR_PATH, device) + diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device) # TODO, optimizer and scaler needs to be loaded as well @@ -171,10 +182,15 @@ def train(image_embed_dim, amp = amp, ).to(device) + # load optimizer and scaler + + if RESUME: + trainer.optimizer.load_state_dict(loaded_obj['optimizer']) + trainer.scaler.load_state_dict(loaded_obj['scaler']) + # Create save_path if it doesn't exist - if not os.path.exists(save_path): - os.makedirs(save_path) + Path(save_path).mkdir(exist_ok = True, parents = True) # Get image and text embeddings from the servers @@ -185,8 +201,8 @@ def train(image_embed_dim, ### Training code ### + timer = Timer() epochs = num_epochs - t = time.time() train_set_size = int(train_percent*num_data_points) val_set_size = int(val_percent*num_data_points) @@ -202,15 +218,15 @@ def train(image_embed_dim, emb_images_tensor = torch.tensor(emb_images[0]).to(device) emb_text_tensor = torch.tensor(emb_text[0]).to(device) - loss = trainer(text_embed = emb_text_tensor,image_embed = emb_images_tensor) + loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor) # Samples per second - samples_per_sec = batch_size*step/(time.time()-t) + samples_per_sec = batch_size * step / timer.elapsed() # Save checkpoint every save_interval minutes - if(int(time.time()-t) >= 60*save_interval): - t = time.time() + if(int(timer.elapsed()) >= 60 * save_interval): + timer.reset() save_diffusion_model( save_path, @@ -253,67 +269,89 @@ def train(image_embed_dim, end = num_data_points eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test") -def main(): - parser = argparse.ArgumentParser() - # Logging - parser.add_argument("--wandb-entity", type=str, default="laion") - parser.add_argument("--wandb-project", type=str, default="diffusion-prior") - parser.add_argument("--wandb-dataset", type=str, default="LAION-5B") - parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior") - # URLs for embeddings - parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/") - parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/") - # Hyperparameters - parser.add_argument("--learning-rate", type=float, default=1.1e-4) - parser.add_argument("--weight-decay", type=float, default=6.02e-2) - parser.add_argument("--dropout", type=float, default=5e-2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--batch-size", type=int, default=10**4) - parser.add_argument("--num-epochs", type=int, default=5) - # Image embed dimension - parser.add_argument("--image-embed-dim", type=int, default=768) - # Train-test split - parser.add_argument("--train-percent", type=float, default=0.7) - parser.add_argument("--val-percent", type=float, default=0.2) - parser.add_argument("--test-percent", type=float, default=0.1) - # LAION training(pre-computed embeddings) - # DiffusionPriorNetwork(dpn) parameters - parser.add_argument("--dpn-depth", type=int, default=6) - parser.add_argument("--dpn-dim-head", type=int, default=64) - parser.add_argument("--dpn-heads", type=int, default=8) - # DiffusionPrior(dp) parameters - parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False) - parser.add_argument("--dp-timesteps", type=int, default=100) - parser.add_argument("--dp-normformer", type=bool, default=False) - parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1) - parser.add_argument("--dp-loss-type", type=str, default="l2") - parser.add_argument("--clip", type=str, default=None) - parser.add_argument("--amp", type=bool, default=False) - # Model checkpointing interval(minutes) - parser.add_argument("--save-interval", type=int, default=30) - parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints") - # Saved model path - parser.add_argument("--pretrained-model-path", type=str, default=None) - - args = parser.parse_args() - - config = ({"learning_rate": args.learning_rate, - "architecture": args.wandb_arch, - "dataset": args.wandb_dataset, - "weight_decay":args.weight_decay, - "max_gradient_clipping_norm":args.max_grad_norm, - "batch_size":args.batch_size, - "epochs": args.num_epochs, - "diffusion_prior_network":{"depth":args.dpn_depth, - "dim_head":args.dpn_dim_head, - "heads":args.dpn_heads, - "normformer":args.dp_normformer}, - "diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings, - "timesteps": args.dp_timesteps, - "cond_drop_prob":args.dp_cond_drop_prob, - "loss_type":args.dp_loss_type, - "clip":args.clip} - }) +@click.command() +@click.option("--wandb-entity", default="laion") +@click.option("--wandb-project", default="diffusion-prior") +@click.option("--wandb-dataset", default="LAION-5B") +@click.option("--wandb-arch", default="DiffusionPrior") +@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/") +@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/") +@click.option("--learning-rate", default=1.1e-4) +@click.option("--weight-decay", default=6.02e-2) +@click.option("--dropout", default=5e-2) +@click.option("--max-grad-norm", default=0.5) +@click.option("--batch-size", default=10**4) +@click.option("--num-epochs", default=5) +@click.option("--image-embed-dim", default=768) +@click.option("--train-percent", default=0.7) +@click.option("--val-percent", default=0.2) +@click.option("--test-percent", default=0.1) +@click.option("--dpn-depth", default=6) +@click.option("--dpn-dim-head", default=64) +@click.option("--dpn-heads", default=8) +@click.option("--dp-condition-on-text-encodings", default=False) +@click.option("--dp-timesteps", default=100) +@click.option("--dp-normformer", default=False) +@click.option("--dp-cond-drop-prob", default=0.1) +@click.option("--dp-loss-type", default="l2") +@click.option("--clip", default=None) +@click.option("--amp", default=False) +@click.option("--save-interval", default=30) +@click.option("--save-path", default="./diffusion_prior_checkpoints") +@click.option("--pretrained-model-path", default=None) +def main( + wandb_entity, + wandb_project, + wandb_dataset, + wandb_arch, + image_embed_url, + text_embed_url, + learning_rate, + weight_decay, + dropout, + max_grad_norm, + batch_size, + num_epochs, + image_embed_dim, + train_percent, + val_percent, + test_percent, + dpn_depth, + dpn_dim_head, + dpn_heads, + dp_condition_on_text_encodings, + dp_timesteps, + dp_normformer, + dp_cond_drop_prob, + dp_loss_type, + clip, + amp, + save_interval, + save_path, + pretrained_model_path +): + config = { + "learning_rate": learning_rate, + "architecture": wandb_arch, + "dataset": wandb_dataset, + "weight_decay": weight_decay, + "max_gradient_clipping_norm": max_grad_norm, + "batch_size": batch_size, + "epochs": num_epochs, + "diffusion_prior_network": { + "depth": dpn_depth, + "dim_head": dpn_dim_head, + "heads": dpn_heads, + "normformer": dp_normformer + }, + "diffusion_prior": { + "condition_on_text_encodings": dp_condition_on_text_encodings, + "timesteps": dp_timesteps, + "cond_drop_prob": dp_cond_drop_prob, + "loss_type": dp_loss_type, + "clip": clip + } + } RESUME = False @@ -325,9 +363,10 @@ def main(): RESUME = True else: tracker.init( - entity=args.wandb_entity, - project=args.wandb_project, - config=config) + entity = wandb_entity, + project = wandb_project, + config = config + ) # Obtain the utilized device. @@ -337,36 +376,36 @@ def main(): torch.cuda.set_device(device) # Training loop - train(args.image_embed_dim, - args.image_embed_url, - args.text_embed_url, - args.batch_size, - args.train_percent, - args.val_percent, - args.test_percent, - args.num_epochs, - args.dp_loss_type, - args.clip, - args.dp_condition_on_text_encodings, - args.dp_timesteps, - args.dp_normformer, - args.dp_cond_drop_prob, - args.dpn_depth, - args.dpn_dim_head, - args.dpn_heads, - args.save_interval, - args.save_path, + train(image_embed_dim, + image_embed_url, + text_embed_url, + batch_size, + train_percent, + val_percent, + test_percent, + num_epochs, + dp_loss_type, + clip, + dp_condition_on_text_encodings, + dp_timesteps, + dp_normformer, + dp_cond_drop_prob, + dpn_depth, + dpn_dim_head, + dpn_heads, + save_interval, + save_path, device, RESUME, DPRIOR_PATH, config, - args.wandb_entity, - args.wandb_project, - args.learning_rate, - args.max_grad_norm, - args.weight_decay, - args.dropout, - args.amp) + wandb_entity, + wandb_project, + learning_rate, + max_grad_norm, + weight_decay, + dropout, + amp) if __name__ == "__main__": main()