add gradient clipping, make sure weight decay is configurable, make sure learning rate is actually passed into get_optimizer, make sure model is set to training mode at beginning of each epoch

Phil Wang
2022-05-01 11:55:38 -07:00
parent 53ce6dfdf6
commit 35cd63982d


@@ -1,28 +1,30 @@
-import argparse
 import os
-from dalle2_pytorch import DiffusionPrior
-from embedding_reader import EmbeddingReader
-from dalle2_pytorch import DiffusionPriorNetwork
-from dalle2_pytorch.optimizer import get_optimizer
 import math
-import time
-from tqdm import tqdm
+import argparse
 import torch
 from torch import nn
+from embedding_reader import EmbeddingReader
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.optimizer import get_optimizer
+import time
+from tqdm import tqdm
 import wandb
 os.environ["WANDB_SILENT"] = "true"
 
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
+    model.eval()
     with torch.no_grad():
         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                                        text_reader(batch_size=batch_size, start=start, end=end)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            model.eval()
             loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
             # Log to wandb
-            wandb.log({phase + " " + loss_type: loss})
+            wandb.log({f'{phase} {loss_type}': loss})
 
 def save_model(save_path,state_dict):
     # Saving State Dict
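For context, the mode handling this hunk moves toward is the usual PyTorch pattern: switch to eval mode once before validation, disable autograd with torch.no_grad() for the whole loop, and switch back to train mode before the next training epoch. A minimal sketch, assuming a model that takes text_embed/image_embed keywords like the diff's DiffusionPrior and an iterable of embedding batches (both names here are illustrative, not the repository's):

import torch

def validate(model, batches, device):
    # put dropout/batchnorm layers into inference behaviour once, outside the loop
    model.eval()
    losses = []
    with torch.no_grad():  # no gradient bookkeeping during validation
        for text_embed, image_embed in batches:
            loss = model(text_embed=text_embed.to(device), image_embed=image_embed.to(device))
            losses.append(loss.item())
    # the caller is expected to call model.train() again before the next training epoch
    return sum(losses) / max(len(losses), 1)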
@@ -48,7 +50,9 @@ def train(image_embed_dim,
           save_interval,
           save_path,
           device,
-          learning_rate=0.01):
+          learning_rate=0.001,
+          max_grad_norm=0.5,
+          weight_decay=0.01):
 
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
@@ -78,14 +82,18 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
-    optimizer = get_optimizer(diffusion_prior.parameters())
+    optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
     step = 0
     t = time.time()
     train_set_size = int(train_percent*num_data_points)
     val_set_size = int(val_percent*num_data_points)
 
     for _ in range(epochs):
+        diffusion_prior.train()
         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                                        text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
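The call site now passes wd and lr through to dalle2_pytorch.optimizer.get_optimizer; those two keyword names come straight from the diff, but the helper itself is not shown here. A rough stand-in that behaves the way the call site assumes (decoupled weight decay via AdamW when wd is non-zero, plain Adam otherwise) could look like this; the real get_optimizer may differ:

import torch
from torch.optim import Adam, AdamW

def get_optimizer_sketch(params, lr=1e-3, wd=1e-2):
    # assumption: AdamW when weight decay is requested, Adam otherwise
    if wd > 0:
        return AdamW(params, lr=lr, weight_decay=wd)
    return Adam(params, lr=lr)

# usage mirroring the diff: optimizer = get_optimizer_sketch(model.parameters(), lr=0.001, wd=0.01)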
@@ -104,6 +112,8 @@ def train(image_embed_dim,
wandb.log({"Training loss": loss.item(), wandb.log({"Training loss": loss.item(),
"Steps": step, "Steps": step,
"Samples per second": samples_per_sec}) "Samples per second": samples_per_sec})
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
optimizer.step() optimizer.step()
### Evaluate model(validation run) ### ### Evaluate model(validation run) ###
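Gradient clipping only does anything if it runs after loss.backward() has populated the gradients and before optimizer.step() consumes them, which is exactly where the new line sits. A minimal, self-contained sketch of that ordering with torch.nn.utils.clip_grad_norm_ (the model and data below are placeholders, not the repository's):

import torch
from torch import nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
max_grad_norm = 0.5

x, y = torch.randn(4, 8), torch.randn(4, 1)
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()                                               # gradients are now populated
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)   # rescale if total norm exceeds 0.5
optimizer.step()                                              # update using the clipped gradients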
@@ -129,7 +139,9 @@ def main():
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/") parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/") parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# Hyperparameters # Hyperparameters
parser.add_argument("--learning-rate", type=float, default=0.01) parser.add_argument("--learning-rate", type=float, default=0.001)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4) parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5) parser.add_argument("--num-epochs", type=int, default=5)
# Image embed dimension # Image embed dimension
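The three hyperparameter flags map directly onto attributes of the parsed namespace (argparse turns --learning-rate into args.learning_rate, and so on), which is what the train() call in the next hunk relies on. A quick, self-contained check of the new defaults, outside the actual script:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--learning-rate", type=float, default=0.001)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--max-grad-norm", type=float, default=0.5)

args = parser.parse_args([])  # empty argv: fall back to the defaults above
print(args.learning_rate, args.weight_decay, args.max_grad_norm)  # -> 0.001 0.01 0.5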
@@ -193,7 +205,9 @@ def main():
           args.save_interval,
           args.save_path,
           device,
-          args.learning_rate)
+          args.learning_rate,
+          args.max_grad_norm,
+          args.weight_decay)
 
 if __name__ == "__main__":
     main()