Make sure the optimizer and scaler are reloaded on resume in the diffusion prior training script; move argument parsing from argparse to click

This commit is contained in:
Phil Wang
2022-05-15 10:48:10 -07:00
parent 71d0c4edae
commit aa6772dcff
2 changed files with 140 additions and 101 deletions

View File

@@ -117,7 +117,7 @@ def load_diffusion_model(dprior_path, device):
# Load state dict from saved model # Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model']) diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim): def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict # Saving State Dict

View File

@@ -1,7 +1,7 @@
import os from pathlib import Path
import click
import math import math
import time import time
import argparse
import numpy as np import numpy as np
import torch import torch
@@ -22,6 +22,17 @@ REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting du
tracker = WandbTracker() tracker = WandbTracker()
# helper functions
class Timer:
    """Stopwatch that reports the seconds elapsed since the last reset.

    The timer starts running as soon as it is constructed. Readings come
    from ``time.monotonic()`` rather than ``time.time()`` so that system
    clock adjustments (NTP sync, manual changes) can never make an interval
    negative or artificially large.
    """

    def __init__(self):
        # Start timing immediately; reset() records the starting point.
        self.reset()

    def reset(self):
        """Restart the timer from the current moment."""
        # NOTE: this is a monotonic clock reading, not a Unix timestamp; it
        # is only meaningful as a reference point for elapsed().
        self.last_time = time.monotonic()

    def elapsed(self):
        """Return the seconds (float) elapsed since the last reset."""
        return time.monotonic() - self.last_time
# functions # functions
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"): def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
@@ -155,7 +166,7 @@ def train(image_embed_dim,
# Load pre-trained model from DPRIOR_PATH # Load pre-trained model from DPRIOR_PATH
if RESUME: if RESUME:
diffusion_prior = load_diffusion_model(DPRIOR_PATH, device) diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
# TODO, optimizer and scaler needs to be loaded as well # TODO, optimizer and scaler needs to be loaded as well
@@ -171,10 +182,15 @@ def train(image_embed_dim,
amp = amp, amp = amp,
).to(device) ).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist # Create save_path if it doesn't exist
if not os.path.exists(save_path): Path(save_path).mkdir(exist_ok = True, parents = True)
os.makedirs(save_path)
# Get image and text embeddings from the servers # Get image and text embeddings from the servers
@@ -185,8 +201,8 @@ def train(image_embed_dim,
### Training code ### ### Training code ###
timer = Timer()
epochs = num_epochs epochs = num_epochs
t = time.time()
train_set_size = int(train_percent*num_data_points) train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points) val_set_size = int(val_percent*num_data_points)
@@ -202,15 +218,15 @@ def train(image_embed_dim,
emb_images_tensor = torch.tensor(emb_images[0]).to(device) emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device) emb_text_tensor = torch.tensor(emb_text[0]).to(device)
loss = trainer(text_embed = emb_text_tensor,image_embed = emb_images_tensor) loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
# Samples per second # Samples per second
samples_per_sec = batch_size*step/(time.time()-t) samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes # Save checkpoint every save_interval minutes
if(int(time.time()-t) >= 60*save_interval): if(int(timer.elapsed()) >= 60 * save_interval):
t = time.time() timer.reset()
save_diffusion_model( save_diffusion_model(
save_path, save_path,
@@ -253,67 +269,89 @@ def train(image_embed_dim,
end = num_data_points end = num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test") eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
def main(): @click.command()
parser = argparse.ArgumentParser() @click.option("--wandb-entity", default="laion")
# Logging @click.option("--wandb-project", default="diffusion-prior")
parser.add_argument("--wandb-entity", type=str, default="laion") @click.option("--wandb-dataset", default="LAION-5B")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior") @click.option("--wandb-arch", default="DiffusionPrior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B") @click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior") @click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# URLs for embeddings @click.option("--learning-rate", default=1.1e-4)
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/") @click.option("--weight-decay", default=6.02e-2)
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/") @click.option("--dropout", default=5e-2)
# Hyperparameters @click.option("--max-grad-norm", default=0.5)
parser.add_argument("--learning-rate", type=float, default=1.1e-4) @click.option("--batch-size", default=10**4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2) @click.option("--num-epochs", default=5)
parser.add_argument("--dropout", type=float, default=5e-2) @click.option("--image-embed-dim", default=768)
parser.add_argument("--max-grad-norm", type=float, default=0.5) @click.option("--train-percent", default=0.7)
parser.add_argument("--batch-size", type=int, default=10**4) @click.option("--val-percent", default=0.2)
parser.add_argument("--num-epochs", type=int, default=5) @click.option("--test-percent", default=0.1)
# Image embed dimension @click.option("--dpn-depth", default=6)
parser.add_argument("--image-embed-dim", type=int, default=768) @click.option("--dpn-dim-head", default=64)
# Train-test split @click.option("--dpn-heads", default=8)
parser.add_argument("--train-percent", type=float, default=0.7) @click.option("--dp-condition-on-text-encodings", default=False)
parser.add_argument("--val-percent", type=float, default=0.2) @click.option("--dp-timesteps", default=100)
parser.add_argument("--test-percent", type=float, default=0.1) @click.option("--dp-normformer", default=False)
# LAION training(pre-computed embeddings) @click.option("--dp-cond-drop-prob", default=0.1)
# DiffusionPriorNetwork(dpn) parameters @click.option("--dp-loss-type", default="l2")
parser.add_argument("--dpn-depth", type=int, default=6) @click.option("--clip", default=None)
parser.add_argument("--dpn-dim-head", type=int, default=64) @click.option("--amp", default=False)
parser.add_argument("--dpn-heads", type=int, default=8) @click.option("--save-interval", default=30)
# DiffusionPrior(dp) parameters @click.option("--save-path", default="./diffusion_prior_checkpoints")
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False) @click.option("--pretrained-model-path", default=None)
parser.add_argument("--dp-timesteps", type=int, default=100) def main(
parser.add_argument("--dp-normformer", type=bool, default=False) wandb_entity,
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1) wandb_project,
parser.add_argument("--dp-loss-type", type=str, default="l2") wandb_dataset,
parser.add_argument("--clip", type=str, default=None) wandb_arch,
parser.add_argument("--amp", type=bool, default=False) image_embed_url,
# Model checkpointing interval(minutes) text_embed_url,
parser.add_argument("--save-interval", type=int, default=30) learning_rate,
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints") weight_decay,
# Saved model path dropout,
parser.add_argument("--pretrained-model-path", type=str, default=None) max_grad_norm,
batch_size,
args = parser.parse_args() num_epochs,
image_embed_dim,
config = ({"learning_rate": args.learning_rate, train_percent,
"architecture": args.wandb_arch, val_percent,
"dataset": args.wandb_dataset, test_percent,
"weight_decay":args.weight_decay, dpn_depth,
"max_gradient_clipping_norm":args.max_grad_norm, dpn_dim_head,
"batch_size":args.batch_size, dpn_heads,
"epochs": args.num_epochs, dp_condition_on_text_encodings,
"diffusion_prior_network":{"depth":args.dpn_depth, dp_timesteps,
"dim_head":args.dpn_dim_head, dp_normformer,
"heads":args.dpn_heads, dp_cond_drop_prob,
"normformer":args.dp_normformer}, dp_loss_type,
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings, clip,
"timesteps": args.dp_timesteps, amp,
"cond_drop_prob":args.dp_cond_drop_prob, save_interval,
"loss_type":args.dp_loss_type, save_path,
"clip":args.clip} pretrained_model_path
}) ):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
RESUME = False RESUME = False
@@ -325,9 +363,10 @@ def main():
RESUME = True RESUME = True
else: else:
tracker.init( tracker.init(
entity=args.wandb_entity, entity = wandb_entity,
project=args.wandb_project, project = wandb_project,
config=config) config = config
)
# Obtain the utilized device. # Obtain the utilized device.
@@ -337,36 +376,36 @@ def main():
torch.cuda.set_device(device) torch.cuda.set_device(device)
# Training loop # Training loop
train(args.image_embed_dim, train(image_embed_dim,
args.image_embed_url, image_embed_url,
args.text_embed_url, text_embed_url,
args.batch_size, batch_size,
args.train_percent, train_percent,
args.val_percent, val_percent,
args.test_percent, test_percent,
args.num_epochs, num_epochs,
args.dp_loss_type, dp_loss_type,
args.clip, clip,
args.dp_condition_on_text_encodings, dp_condition_on_text_encodings,
args.dp_timesteps, dp_timesteps,
args.dp_normformer, dp_normformer,
args.dp_cond_drop_prob, dp_cond_drop_prob,
args.dpn_depth, dpn_depth,
args.dpn_dim_head, dpn_dim_head,
args.dpn_heads, dpn_heads,
args.save_interval, save_interval,
args.save_path, save_path,
device, device,
RESUME, RESUME,
DPRIOR_PATH, DPRIOR_PATH,
config, config,
args.wandb_entity, wandb_entity,
args.wandb_project, wandb_project,
args.learning_rate, learning_rate,
args.max_grad_norm, max_grad_norm,
args.weight_decay, weight_decay,
args.dropout, dropout,
args.amp) amp)
if __name__ == "__main__": if __name__ == "__main__":
main() main()