Compare commits

7 Commits
0.2.1 ... 0.2.4

Author     SHA1        Message                                                                  Date
Phil Wang  9b322ea634  patch                                                                    2022-05-09 19:46:19 -07:00
Phil Wang  ba64ea45cc  0.2.3                                                                    2022-05-09 16:50:31 -07:00
Phil Wang  64f7be1926  some cleanup                                                             2022-05-09 16:50:21 -07:00
Phil Wang  db805e73e1  fix a bug with numerical stability in attention, sorry! 🐛               2022-05-09 16:23:37 -07:00
z          cb07b37970  Ensure Eval Mode In Metric Functions (#79)                               2022-05-09 16:05:40 -07:00
                         * add eval/train toggles
                         * train/eval flags
                         * shift train toggle
                         Co-authored-by: nousr <z@localhost.com>
Phil Wang  a774bfefe2  add attention and feedforward dropouts to train_diffusion_prior script  2022-05-09 13:57:15 -07:00
Phil Wang  2ae57f0cf5  cleanup                                                                  2022-05-09 13:51:26 -07:00
6 changed files with 62 additions and 46 deletions

README.md

@@ -933,7 +933,7 @@ Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/r
 Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
-## from dalle2_pytorch import load_diffusion_model, save_diffusion_model
+## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
 load_diffusion_model(dprior_path, device)
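
A minimal usage sketch of the helpers at their relocated import path; the checkpoint path below is a placeholder, not a file shipped with the repo.

# Hypothetical usage sketch; './checkpoints/prior.pth' is a placeholder path
import torch
from dalle2_pytorch.train import load_diffusion_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# rebuilds DiffusionPriorNetwork / DiffusionPrior from the saved hparams
# and loads the state dict onto the chosen device
diffusion_prior = load_diffusion_model('./checkpoints/prior.pth', device)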

dalle2_pytorch/__init__.py

@@ -1,4 +1,4 @@
-from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model
+from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
 from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
 from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer

dalle2_pytorch/dalle2_pytorch.py

@@ -5,7 +5,6 @@ from functools import partial
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
-import time

 import torch
 import torch.nn.functional as F
@@ -34,42 +33,6 @@ from rotary_embedding_torch import RotaryEmbedding
 from x_clip import CLIP
 from coca_pytorch import CoCa

-# Diffusion Prior model loading and saving functions
-def load_diffusion_model(dprior_path, device):
-    dprior_path = Path(dprior_path)
-    assert dprior_path.exists(), 'Dprior model file does not exist'
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    # Get hyperparameters of loaded model
-    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
-    dp_config = loaded_obj['hparams']['diffusion_prior']
-    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
-
-    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
-
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
-
-    # Load state dict from saved model
-    diffusion_prior.load_state_dict(loaded_obj['model'])
-
-    return diffusion_prior
-
-def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
-    # Saving State Dict
-    print("====================================== Saving checkpoint ======================================")
-    state_dict = dict(model = model.state_dict(),
-                      optimizer = optimizer.state_dict(),
-                      scaler = scaler.state_dict(),
-                      hparams = config,
-                      image_embed_dim = {"image_embed_dim": image_embed_dim})
-    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')
-
 # helper functions

 def exists(val):
@@ -677,7 +640,7 @@ class Attention(nn.Module):
         # attention

-        sim = sim - sim.amax(dim = -1, keepdim = True)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
@@ -1204,7 +1167,7 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)

-        sim = sim - sim.amax(dim = -1, keepdim = True)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
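
The two one-line changes above are the numerical-stability fix from commit db805e73e1: the row-wise max is subtracted before softmax so exp() cannot overflow, and since the shift cancels analytically, it is detached from the autograd graph. A standalone sketch of the same trick; stable_softmax is an illustrative helper name, not a function from the library:

# Sketch of the stabilization applied above, not the library's attention code
import torch

def stable_softmax(sim, dim = -1):
    # softmax is shift-invariant: softmax(x - c) == softmax(x) for any constant c,
    # so subtracting the row max avoids exp() overflow without changing the output;
    # detach() keeps autograd from differentiating through the (cancelling) shift
    sim = sim - sim.amax(dim = dim, keepdim = True).detach()
    return sim.softmax(dim = dim)

attn = stable_softmax(torch.randn(2, 8, 16, 16) * 1e4)  # huge logits, finite result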

dalle2_pytorch/train.py

@@ -1,3 +1,4 @@
+import time
 import copy
 from functools import partial
@@ -39,6 +40,50 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+# print helpers
+
+def print_ribbon(s, symbol = '=', repeat = 40):
+    flank = symbol * repeat
+    return f'{flank} {s} {flank}'
+
+# saving and loading functions
+
+# for diffusion prior
+
+def load_diffusion_model(dprior_path, device):
+    dprior_path = Path(dprior_path)
+    assert dprior_path.exists(), 'Dprior model file does not exist'
+    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
+
+    # Get hyperparameters of loaded model
+    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
+    dp_config = loaded_obj['hparams']['diffusion_prior']
+    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
+
+    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
+    # DiffusionPriorNetwork
+    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
+
+    # DiffusionPrior with text embeddings and image embeddings pre-computed
+    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
+
+    # Load state dict from saved model
+    diffusion_prior.load_state_dict(loaded_obj['model'])
+
+    return diffusion_prior
+
+def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
+    # Saving State Dict
+    print_ribbon('Saving checkpoint')
+    state_dict = dict(model = model.state_dict(),
+                      optimizer = optimizer.state_dict(),
+                      scaler = scaler.state_dict(),
+                      hparams = config,
+                      image_embed_dim = {"image_embed_dim": image_embed_dim})
+    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')
+
 # exponential moving average wrapper

 class EMA(nn.Module):
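
For reference, a sketch of inspecting the checkpoint that save_diffusion_model above writes; the file name is a placeholder (real names are time.time()-stamped). Note the embedding dimension is nested one level down in its own dict, which is why load_diffusion_model indexes it twice.

# Sketch of the checkpoint layout; '1652135000_saved_model.pth' is a placeholder
import torch

loaded_obj = torch.load('1652135000_saved_model.pth', map_location = 'cpu')

model_state = loaded_obj['model']                               # model.state_dict()
dpn_config  = loaded_obj['hparams']['diffusion_prior_network']  # network kwargs
dp_config   = loaded_obj['hparams']['diffusion_prior']          # prior kwargs

# the dimension sits inside its own dict, hence the double indexing
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']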

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.1',
+  version = '0.2.4',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

train_diffusion_prior.py

@@ -6,7 +6,8 @@ import numpy as np
 import torch
 from torch import nn
 from embedding_reader import EmbeddingReader
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
 from dalle2_pytorch.optimizer import get_optimizer
 from torch.cuda.amp import autocast,GradScaler
@@ -42,6 +43,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
         wandb.log({f'{phase} {loss_type}': avg_loss})

 def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
+    diffusion_prior.eval()
     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
@@ -119,6 +121,7 @@ def train(image_embed_dim,
           learning_rate=0.001,
           max_grad_norm=0.5,
           weight_decay=0.01,
+          dropout=0.05,
           amp=False):

     # DiffusionPriorNetwork
@@ -127,6 +130,8 @@ def train(image_embed_dim,
         depth = dpn_depth,
         dim_head = dpn_dim_head,
         heads = dpn_heads,
+        attn_dropout = dropout,
+        ff_dropout = dropout,
         normformer = dp_normformer).to(device)

     # DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -149,7 +154,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)

     # Get image and text embeddings from the servers
-    print("==============Downloading embeddings - image and text====================")
+    print_ribbon("Downloading embeddings - image and text")
     image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
     text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
     num_data_points = text_reader.count
@@ -167,10 +172,12 @@ def train(image_embed_dim,
     eval_start = train_set_size

     for _ in range(epochs):
+        diffusion_prior.train()

         for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                                        text_reader(batch_size=batch_size, start=0, end=train_set_size)):
+            diffusion_prior.train()

             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
@@ -244,6 +251,7 @@ def main():
     # Hyperparameters
     parser.add_argument("--learning-rate", type=float, default=1.1e-4)
     parser.add_argument("--weight-decay", type=float, default=6.02e-2)
+    parser.add_argument("--dropout", type=float, default=5e-2)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
     parser.add_argument("--batch-size", type=int, default=10**4)
     parser.add_argument("--num-epochs", type=int, default=5)
@@ -261,7 +269,6 @@ def main():
     # DiffusionPrior(dp) parameters
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
-    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
     parser.add_argument("--dp-normformer", type=bool, default=False)
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
@@ -340,6 +347,7 @@ def main():
           args.learning_rate,
           args.max_grad_norm,
           args.weight_decay,
+          args.dropout,
           args.amp)

 if __name__ == "__main__":
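
PR #79 above adds these toggles because the metric helpers (eval_model, report_cosine_sims) now switch the model to eval mode; calling .train() at the top of the epoch and again inside the batch loop re-enables dropout after any mid-loop evaluation. A minimal self-contained sketch of the pattern, where the model and evaluate() are stand-ins, not code from this repo:

# Minimal sketch of the train/eval toggling pattern added in PR #79
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.05))

@torch.no_grad()
def evaluate(model, x):
    model.eval()                  # dropout off -> deterministic metrics
    return model(x).pow(2).mean()

for epoch in range(2):
    model.train()                 # dropout back on for training
    for step in range(3):
        model.train()             # re-assert in case evaluation ran mid-loop
        out = model(torch.randn(4, 8))
        if step == 1:
            evaluate(model, torch.randn(4, 8))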