Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77)

* Val_loss changes - no rebased with lucidrains' master.

* Val Loss changes - now rebased with lucidrains' master

* train_diffusion_prior.py updates

* dalle2_pytorch.py updates

* __init__.py changes

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update dalle2_pytorch.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update train_diffusion_prior.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
This commit is contained in:
Kumar R
2022-05-09 21:23:29 +05:30
committed by GitHub
parent 53c189e46a
commit 8647cb5e76
4 changed files with 179 additions and 86 deletions

View File

@@ -927,7 +927,39 @@ The most significant parameters for the script are as follows:
### Sample wandb run log ### Sample wandb run log
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace= Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
### Loading and saving the Diffusion Prior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
## from dalle2_pytorch import load_diffusion_model, save_diffusion_model
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip) ## CLI (wip)

View File

@@ -1,4 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -4,6 +4,8 @@ from inspect import isfunction
from functools import partial from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
from collections import namedtuple from collections import namedtuple
from pathlib import Path
import time
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -32,6 +34,42 @@ from rotary_embedding_torch import RotaryEmbedding
from x_clip import CLIP from x_clip import CLIP
from coca_pytorch import CoCa from coca_pytorch import CoCa
# Diffusion Prior model loading and saving functions
def load_diffusion_model(dprior_path, device ):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# helper functions # helper functions
def exists(val): def exists(val):
@@ -1914,3 +1952,4 @@ class DALLE2(nn.Module):
return images[0] return images[0]
return images return images

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from embedding_reader import EmbeddingReader from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.optimizer import get_optimizer from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler from torch.cuda.amp import autocast,GradScaler
@@ -41,73 +41,55 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
avg_loss = (total_loss / total_samples) avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss}) wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict): def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6) cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size tstart = train_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS tend = train_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
# roll the text embeddings to simulate "unrelated" captions # roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1) rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx] text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \ text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True) text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled) test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
# prepare the text embedding # prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True) text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed) test_text_cond = dict(text_embed=text_embed)
# prepare image embeddings # prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device) test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \ test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True) test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings # predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop( predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond) (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / \ predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True) predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings # predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop( predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond) (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \ predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True) predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities # calculate similarities
original_similarity = cos( original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy() text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos( predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy() text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos( unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy() text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos( predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy() test_image_embeddings, predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
wandb.log( "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)}) "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean( "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
predicted_similarity)}) "Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
unrelated_similarity)})
wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean(
predicted_img_similarity)})
return np.mean(predicted_similarity - original_similarity)
def train(image_embed_dim, def train(image_embed_dim,
image_embed_url, image_embed_url,
@@ -129,6 +111,11 @@ def train(image_embed_dim,
save_interval, save_interval,
save_path, save_path,
device, device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate=0.001, learning_rate=0.001,
max_grad_norm=0.5, max_grad_norm=0.5,
weight_decay=0.01, weight_decay=0.01,
@@ -152,16 +139,21 @@ def train(image_embed_dim,
loss_type = dp_loss_type, loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device) condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
# Get image and text embeddings from the servers # Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================") print("==============Downloading embeddings - image and text====================")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy") image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy") text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count num_data_points = text_reader.count
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
### Training code ### ### Training code ###
scaler = GradScaler(enabled=amp) scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate) optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
@@ -172,6 +164,7 @@ def train(image_embed_dim,
train_set_size = int(train_percent*num_data_points) train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points) val_set_size = int(val_percent*num_data_points)
eval_start = train_set_size
for _ in range(epochs): for _ in range(epochs):
diffusion_prior.train() diffusion_prior.train()
@@ -192,9 +185,13 @@ def train(image_embed_dim,
if(int(time.time()-t) >= 60*save_interval): if(int(time.time()-t) >= 60*save_interval):
t = time.time() t = time.time()
save_model( save_diffusion_model(
save_path, save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict())) diffusion_prior,
optimizer,
scaler,
config,
image_embed_dim)
# Log to wandb # Log to wandb
wandb.log({"Training loss": loss.item(), wandb.log({"Training loss": loss.item(),
@@ -204,14 +201,22 @@ def train(image_embed_dim,
# Use NUM_TEST_EMBEDDINGS samples from the test set each time # Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model # Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0: if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior, report_cosine_sims(diffusion_prior,
image_reader, image_reader,
text_reader, text_reader,
train_set_size, train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS, NUM_TEST_EMBEDDINGS,
device) device)
wandb.log({"Cosine similarity difference": diff_cosine_sim}) ### Evaluate model(validation run) ###
eval_model(diffusion_prior,
device,
image_reader,
text_reader,
eval_start,
eval_start+NUM_TEST_EMBEDDINGS,
NUM_TEST_EMBEDDINGS,
dp_loss_type,
phase="Validation")
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
@@ -220,11 +225,6 @@ def train(image_embed_dim,
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
end=start+val_set_size
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
### Test run ### ### Test run ###
test_set_size = int(test_percent*train_set_size) test_set_size = int(test_percent*train_set_size)
start=train_set_size+val_set_size start=train_set_size+val_set_size
@@ -236,7 +236,6 @@ def main():
# Logging # Logging
parser.add_argument("--wandb-entity", type=str, default="laion") parser.add_argument("--wandb-entity", type=str, default="laion")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior") parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B") parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior") parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
# URLs for embeddings # URLs for embeddings
@@ -271,22 +270,40 @@ def main():
# Model checkpointing interval(minutes) # Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30) parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints") parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
# Saved model path
parser.add_argument("--pretrained-model-path", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
print("Setting up wandb logging... Please wait...") config = ({"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"weight_decay":args.weight_decay,
"max_gradient_clipping_norm":args.max_grad_norm,
"batch_size":args.batch_size,
"epochs": args.num_epochs,
"diffusion_prior_network":{"depth":args.dpn_depth,
"dim_head":args.dpn_dim_head,
"heads":args.dpn_heads,
"normformer":args.dp_normformer},
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
"timesteps": args.dp_timesteps,
"cond_drop_prob":args.dp_cond_drop_prob,
"loss_type":args.dp_loss_type,
"clip":args.clip}
})
wandb.init( RESUME = False
entity=args.wandb_entity, # Check if DPRIOR_PATH exists(saved model path)
project=args.wandb_project, DPRIOR_PATH = args.pretrained_model_path
config={ if(DPRIOR_PATH is not None):
"learning_rate": args.learning_rate, RESUME = True
"architecture": args.wandb_arch, else:
"dataset": args.wandb_dataset, wandb.init(
"epochs": args.num_epochs, entity=args.wandb_entity,
}) project=args.wandb_project,
config=config)
print("wandb logging setup done!")
# Obtain the utilized device. # Obtain the utilized device.
has_cuda = torch.cuda.is_available() has_cuda = torch.cuda.is_available()
@@ -315,6 +332,11 @@ def main():
args.save_interval, args.save_interval,
args.save_path, args.save_path,
device, device,
RESUME,
DPRIOR_PATH,
config,
atgs.wandb_entity,
args.wandb_project,
args.learning_rate, args.learning_rate,
args.max_grad_norm, args.max_grad_norm,
args.weight_decay, args.weight_decay,