mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
cleanup to use diffusion prior trainer
This commit is contained in:
@@ -6,22 +6,24 @@ import numpy as np
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
|
||||||
|
|
||||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
|
||||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
|
|
||||||
from embedding_reader import EmbeddingReader
|
from embedding_reader import EmbeddingReader
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||||
|
|
||||||
tracker = WandbTracker()
|
tracker = WandbTracker()
|
||||||
|
|
||||||
|
# functions
|
||||||
|
|
||||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -126,7 +128,8 @@ def train(image_embed_dim,
|
|||||||
dropout=0.05,
|
dropout=0.05,
|
||||||
amp=False):
|
amp=False):
|
||||||
|
|
||||||
# DiffusionPriorNetwork
|
# diffusion prior network
|
||||||
|
|
||||||
prior_network = DiffusionPriorNetwork(
|
prior_network = DiffusionPriorNetwork(
|
||||||
dim = image_embed_dim,
|
dim = image_embed_dim,
|
||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
@@ -134,9 +137,11 @@ def train(image_embed_dim,
|
|||||||
heads = dpn_heads,
|
heads = dpn_heads,
|
||||||
attn_dropout = dropout,
|
attn_dropout = dropout,
|
||||||
ff_dropout = dropout,
|
ff_dropout = dropout,
|
||||||
normformer = dp_normformer).to(device)
|
normformer = dp_normformer
|
||||||
|
)
|
||||||
|
|
||||||
|
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
|
||||||
diffusion_prior = DiffusionPrior(
|
diffusion_prior = DiffusionPrior(
|
||||||
net = prior_network,
|
net = prior_network,
|
||||||
clip = clip,
|
clip = clip,
|
||||||
@@ -144,29 +149,43 @@ def train(image_embed_dim,
|
|||||||
timesteps = dp_timesteps,
|
timesteps = dp_timesteps,
|
||||||
cond_drop_prob = dp_cond_drop_prob,
|
cond_drop_prob = dp_cond_drop_prob,
|
||||||
loss_type = dp_loss_type,
|
loss_type = dp_loss_type,
|
||||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||||
|
)
|
||||||
|
|
||||||
# Load pre-trained model from DPRIOR_PATH
|
# Load pre-trained model from DPRIOR_PATH
|
||||||
|
|
||||||
if RESUME:
|
if RESUME:
|
||||||
diffusion_prior = load_diffusion_model(DPRIOR_PATH, device)
|
diffusion_prior = load_diffusion_model(DPRIOR_PATH, device)
|
||||||
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
|
|
||||||
|
# TODO, optimizer and scaler needs to be loaded as well
|
||||||
|
|
||||||
|
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||||
|
|
||||||
|
# diffusion prior trainer
|
||||||
|
|
||||||
|
trainer = DiffusionPriorTrainer(
|
||||||
|
diffusion_prior = diffusion_prior,
|
||||||
|
lr = learning_rate,
|
||||||
|
wd = weight_decay,
|
||||||
|
max_grad_norm = max_grad_norm,
|
||||||
|
amp = amp,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
# Create save_path if it doesn't exist
|
# Create save_path if it doesn't exist
|
||||||
|
|
||||||
if not os.path.exists(save_path):
|
if not os.path.exists(save_path):
|
||||||
os.makedirs(save_path)
|
os.makedirs(save_path)
|
||||||
|
|
||||||
# Get image and text embeddings from the servers
|
# Get image and text embeddings from the servers
|
||||||
|
|
||||||
print_ribbon("Downloading embeddings - image and text")
|
print_ribbon("Downloading embeddings - image and text")
|
||||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||||
num_data_points = text_reader.count
|
num_data_points = text_reader.count
|
||||||
|
|
||||||
### Training code ###
|
### Training code ###
|
||||||
scaler = GradScaler(enabled=amp)
|
|
||||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
|
||||||
epochs = num_epochs
|
|
||||||
|
|
||||||
step = 0
|
epochs = num_epochs
|
||||||
t = time.time()
|
t = time.time()
|
||||||
|
|
||||||
train_set_size = int(train_percent*num_data_points)
|
train_set_size = int(train_percent*num_data_points)
|
||||||
@@ -178,18 +197,17 @@ def train(image_embed_dim,
|
|||||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||||
|
|
||||||
diffusion_prior.train()
|
trainer.train()
|
||||||
|
|
||||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||||
|
|
||||||
with autocast(enabled=amp):
|
loss = trainer(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
||||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
# Samples per second
|
# Samples per second
|
||||||
step+=1
|
|
||||||
samples_per_sec = batch_size*step/(time.time()-t)
|
samples_per_sec = batch_size*step/(time.time()-t)
|
||||||
|
|
||||||
# Save checkpoint every save_interval minutes
|
# Save checkpoint every save_interval minutes
|
||||||
if(int(time.time()-t) >= 60*save_interval):
|
if(int(time.time()-t) >= 60*save_interval):
|
||||||
t = time.time()
|
t = time.time()
|
||||||
@@ -197,8 +215,8 @@ def train(image_embed_dim,
|
|||||||
save_diffusion_model(
|
save_diffusion_model(
|
||||||
save_path,
|
save_path,
|
||||||
diffusion_prior,
|
diffusion_prior,
|
||||||
optimizer,
|
trainer.optimizer,
|
||||||
scaler,
|
trainer.scaler,
|
||||||
config,
|
config,
|
||||||
image_embed_dim)
|
image_embed_dim)
|
||||||
|
|
||||||
@@ -227,12 +245,7 @@ def train(image_embed_dim,
|
|||||||
dp_loss_type,
|
dp_loss_type,
|
||||||
phase="Validation")
|
phase="Validation")
|
||||||
|
|
||||||
scaler.unscale_(optimizer)
|
trainer.update()
|
||||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
|
||||||
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
### Test run ###
|
### Test run ###
|
||||||
test_set_size = int(test_percent*train_set_size)
|
test_set_size = int(test_percent*train_set_size)
|
||||||
@@ -303,8 +316,11 @@ def main():
|
|||||||
})
|
})
|
||||||
|
|
||||||
RESUME = False
|
RESUME = False
|
||||||
|
|
||||||
# Check if DPRIOR_PATH exists(saved model path)
|
# Check if DPRIOR_PATH exists(saved model path)
|
||||||
|
|
||||||
DPRIOR_PATH = args.pretrained_model_path
|
DPRIOR_PATH = args.pretrained_model_path
|
||||||
|
|
||||||
if(DPRIOR_PATH is not None):
|
if(DPRIOR_PATH is not None):
|
||||||
RESUME = True
|
RESUME = True
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user