diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 3c2bf80..ce2439e 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -36,7 +36,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
         avg_loss = (total_loss / total_samples)
         wandb.log({f'{phase} {loss_type}': avg_loss})
 
-def save_model(save_path,state_dict):
+def save_model(save_path, state_dict):
     # Saving State Dict
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
@@ -62,7 +62,8 @@ def train(image_embed_dim,
           device,
           learning_rate=0.001,
           max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):
 
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
@@ -92,6 +93,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
+    scaler = GradScaler(enabled=amp)
     optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
 
@@ -108,23 +110,33 @@ def train(image_embed_dim,
                 text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+            scaler.scale(loss).backward()
+
             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
+
             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                        "Steps": step,
                        "Samples per second": samples_per_sec})
 
+            scaler.unscale_(optimizer)
             nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
 
         ### Evaluate model(validation run) ###
         start = train_set_size
@@ -171,12 +183,15 @@ def main():
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)
     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
 
     args = parser.parse_args()
 
+    print("Setting up wandb logging... Please wait...")
+
     wandb.init(
       entity=args.wandb_entity,
       project=args.wandb_project,
@@ -186,6 +201,7 @@ def main():
       "dataset": args.wandb_dataset,
       "epochs": args.num_epochs,
       })
 
+    print("wandb logging setup done!")
 
     # Obtain the utilized device.
@@ -216,7 +232,8 @@ def main():
           device,
           args.learning_rate,
           args.max_grad_norm,
-          args.weight_decay)
+          args.weight_decay,
+          args.amp)
 
 if __name__ == "__main__":
     main()
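For context, the pattern this diff wires into the training loop is PyTorch's native mixed-precision recipe: run the forward pass under autocast, backpropagate through a GradScaler, unscale before gradient clipping, and let the scaler drive the optimizer step. Below is a minimal, self-contained sketch of that step under stated assumptions: ToyPrior, the AdamW optimizer, and the embedding shapes are stand-ins invented for illustration, not the repo's DiffusionPrior or get_optimizer, and the clipping call uses torch.nn.utils.clip_grad_norm_, which is where that helper lives in PyTorch.

# Minimal sketch of the autocast/GradScaler step introduced by the diff.
# ToyPrior, AdamW, and the tensor shapes below are illustrative stand-ins,
# not the project's DiffusionPrior training code.
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

def train_step_amp(model, optimizer, scaler, text_embed, image_embed,
                   max_grad_norm=0.5, amp=True):
    # Forward pass under autocast; eligible ops run in reduced precision.
    with autocast(enabled=amp):
        loss = model(text_embed, image_embed)

    # Backward on the scaled loss so small gradients do not underflow in fp16.
    scaler.scale(loss).backward()

    # Unscale before clipping so the threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # Step through the scaler (it skips the step if gradients overflowed),
    # update the scale factor, and clear gradients for the next batch.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss.item()

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp = device == "cuda"  # autocast/GradScaler become no-ops when disabled

    # Stand-in "prior": maps concatenated text+image embeddings to an MSE loss.
    class ToyPrior(nn.Module):
        def __init__(self, dim=768):
            super().__init__()
            self.net = nn.Linear(2 * dim, dim)

        def forward(self, text_embed, image_embed):
            pred = self.net(torch.cat([text_embed, image_embed], dim=-1))
            return nn.functional.mse_loss(pred, image_embed)

    model = ToyPrior().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    scaler = GradScaler(enabled=amp)

    text_embed = torch.randn(8, 768, device=device)
    image_embed = torch.randn(8, 768, device=device)
    loss = train_step_amp(model, optimizer, scaler, text_embed, image_embed, amp=amp)
    print(f"loss: {loss:.4f}")

Restoring a checkpoint saved in the diff's format is the mirror image: call load_state_dict on the model, optimizer, and scaler with the "model", "optimizer", and "scaler" entries of the saved dict so scaled training resumes cleanly. Note also that argparse's type=bool converts any non-empty string to True, so "--amp False" on the command line would still enable AMP; action="store_true" is the usual way to expose such a flag.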