mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Migrate to text-conditioned prior training (#95)
* migrate to conditioned prior * unify reader logic with a wrapper (#1) * separate out reader logic * support both training methods * Update train prior to use embedding wrapper (#3) * Support Both Methods * bug fixes * small bug fixes * embedding only wrapper bug * use smaller val perc * final bug fix for embedding-only Co-authored-by: nousr <>
This commit is contained in:
@@ -5,9 +5,10 @@ import time
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
|
||||
@@ -17,8 +18,7 @@ from tqdm import tqdm
|
||||
|
||||
# constants
|
||||
|
||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
|
||||
|
||||
tracker = WandbTracker()
|
||||
|
||||
@@ -36,81 +36,106 @@ class Timer:
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
|
||||
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
|
||||
text_reader(batch_size=batch_size, start=start, end=end)):
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
batches = emb_images_tensor.shape[0]
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text = text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||
loss = model(**input_args)
|
||||
|
||||
total_loss += loss.item() * batches
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
diffusion_prior.eval()
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tstart = train_set_size
|
||||
tend = train_set_size+NUM_TEST_EMBEDDINGS
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
|
||||
text_data)
|
||||
text_cond = dict(text_embed=text_embedding,
|
||||
text_encodings=text_encodings, mask=text_mask)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed_shuffled = text_embedding.clone()
|
||||
|
||||
# roll the text to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
|
||||
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed_shuffled = text_embed.clone()
|
||||
# roll the text embeddings to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
|
||||
# prepare the text embedding
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed=text_embed)
|
||||
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@@ -118,29 +143,32 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--batch-size", default=10**4)
|
||||
@click.option("--num-data-points", default=250e6)
|
||||
@click.option("--batch-size", default=320)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.7)
|
||||
@click.option("--val-percent", default=0.2)
|
||||
@click.option("--test-percent", default=0.1)
|
||||
@click.option("--dpn-depth", default=6)
|
||||
@click.option("--train-percent", default=0.9)
|
||||
@click.option("--val-percent", default=1e-7)
|
||||
@click.option("--test-percent", default=0.0999999)
|
||||
@click.option("--dpn-depth", default=12)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=8)
|
||||
@click.option("--dp-condition-on-text-encodings", default=False)
|
||||
@click.option("--dp-timesteps", default=100)
|
||||
@click.option("--dp-normformer", default=False)
|
||||
@click.option("--dpn-heads", default=12)
|
||||
@click.option("--dp-condition-on-text-encodings", default=True)
|
||||
@click.option("--dp-timesteps", default=1000)
|
||||
@click.option("--dp-normformer", default=True)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default=None)
|
||||
@click.option("--clip", default="ViT-L/14")
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=30)
|
||||
@click.option("--save-interval", default=120)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
@click.option("--gpu-device", default=0)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
@@ -148,10 +176,12 @@ def train(
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
meta_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
num_data_points,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
@@ -170,7 +200,8 @@ def train(
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path
|
||||
pretrained_model_path,
|
||||
gpu_device
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
@@ -197,7 +228,7 @@ def train(
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = args.pretrained_model_path
|
||||
DPRIOR_PATH = pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
@@ -211,7 +242,7 @@ def train(
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device("cuda:0")
|
||||
device = torch.device(f"cuda:{gpu_device}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
@@ -227,11 +258,17 @@ def train(
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
else:
|
||||
clip_adapter = None
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
@@ -265,33 +302,37 @@ def train(
|
||||
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
|
||||
print_ribbon("Downloading embeddings - image and text")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
|
||||
### Training code ###
|
||||
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
train_set_size = int(train_percent*num_data_points)
|
||||
val_set_size = int(val_percent*num_data_points)
|
||||
eval_start = train_set_size
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||
|
||||
trainer.train()
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
diffusion_prior.train()
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text)
|
||||
|
||||
loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||
loss = trainer(**input_args)
|
||||
|
||||
# Samples per second
|
||||
|
||||
@@ -310,37 +351,23 @@ def train(
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
tracker.log({"Training loss": loss.item(),
|
||||
tracker.log({"Training loss": loss,
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior,
|
||||
image_reader,
|
||||
text_reader,
|
||||
train_set_size,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
device)
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior,
|
||||
device,
|
||||
image_reader,
|
||||
text_reader,
|
||||
eval_start,
|
||||
eval_start+NUM_TEST_EMBEDDINGS,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
dp_loss_type,
|
||||
phase="Validation")
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
test_set_size = int(test_percent*train_set_size)
|
||||
start = train_set_size+val_set_size
|
||||
end = num_data_points
|
||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
|
||||
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
||||
Reference in New Issue
Block a user