Patch summary (text before the "diff --git" header is ignored by patch tools):
  * eval_model: switch to eval mode once, before the loop, instead of per batch;
    build the wandb key with an f-string.
  * train: new max_grad_norm and weight_decay parameters, default learning rate
    lowered from 0.01 to 0.001, optimizer now receives wd/lr, the model is put
    in train mode at the start of every epoch, and gradients are norm-clipped
    right before optimizer.step().
  * main: matching --weight-decay and --max-grad-norm CLI flags, forwarded
    positionally to train().

Fix relative to the previous revision of this patch: gradient clipping is
torch.nn.utils.clip_grad_norm_; torch.nn.init has no such function, so the
earlier "nn.init.clip_grad_norm_" would raise AttributeError at runtime.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy. Hunk line counts agree with the @@ headers, but
confirm context-line whitespace against the target file before applying.
The post-image blob hash on the "index" line predates the fix above.

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 6b50bec..78e98ab 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -1,28 +1,30 @@
-import argparse
 import os
-from dalle2_pytorch import DiffusionPrior
-from embedding_reader import EmbeddingReader
-from dalle2_pytorch import DiffusionPriorNetwork
-from dalle2_pytorch.optimizer import get_optimizer
 import math
-import time
-from tqdm import tqdm
+import argparse
+
 import torch
 from torch import nn
+from embedding_reader import EmbeddingReader
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.optimizer import get_optimizer
+
+import time
+from tqdm import tqdm
+
 import wandb
 os.environ["WANDB_SILENT"] = "true"
 
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
+    model.eval()
     with torch.no_grad():
         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                 text_reader(batch_size=batch_size, start=start, end=end)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            model.eval()
-            loss = model(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+            loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
 
             # Log to wandb
-            wandb.log({phase + " " + loss_type: loss})
+            wandb.log({f'{phase} {loss_type}': loss})
 
 def save_model(save_path,state_dict):
     # Saving State Dict
@@ -48,7 +50,9 @@ def train(image_embed_dim,
           save_interval,
           save_path,
           device,
-          learning_rate=0.01):
+          learning_rate=0.001,
+          max_grad_norm=0.5,
+          weight_decay=0.01):
 
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
@@ -78,14 +82,18 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
-    optimizer = get_optimizer(diffusion_prior.parameters())
+    optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
 
     step = 0
     t = time.time()
+
     train_set_size = int(train_percent*num_data_points)
     val_set_size = int(val_percent*num_data_points)
+
     for _ in range(epochs):
+        diffusion_prior.train()
+
         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                 text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
@@ -104,6 +112,8 @@ def train(image_embed_dim,
             wandb.log({"Training loss": loss.item(),
                         "Steps": step,
                         "Samples per second": samples_per_sec})
+
+            nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
             optimizer.step()
 
         ### Evaluate model(validation run) ###
@@ -129,7 +139,9 @@ def main():
     parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
     parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
     # Hyperparameters
-    parser.add_argument("--learning-rate", type=float, default=0.01)
+    parser.add_argument("--learning-rate", type=float, default=0.001)
+    parser.add_argument("--weight-decay", type=float, default=0.01)
+    parser.add_argument("--max-grad-norm", type=float, default=0.5)
     parser.add_argument("--batch-size", type=int, default=10**4)
     parser.add_argument("--num-epochs", type=int, default=5)
     # Image embed dimension
@@ -193,7 +205,9 @@ def main():
           args.save_interval,
           args.save_path,
           device,
-          args.learning_rate)
+          args.learning_rate,
+          args.max_grad_norm,
+          args.weight_decay)
 
 if __name__ == "__main__":
   main()