From 387c5bf77494dec8d4d566343e893bc84ef30d03 Mon Sep 17 00:00:00 2001
From: zion <51308183+nousr@users.noreply.github.com>
Date: Sun, 29 May 2022 16:25:53 -0700
Subject: [PATCH] quick patch for new prior loader (#123)

---
 train_diffusion_prior.py | 60 ++++++++++++++++++++++++++--------------
 1 file changed, 40 insertions(+), 20 deletions(-)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 3a625de..76df513 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -7,15 +7,13 @@ import torch
 import clip
 
 from torch import nn
-from dalle2_pytorch.dataloaders import make_splits
+from dalle2_pytorch.dataloaders import make_splits, get_reader
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
 from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
 from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
 from dalle2_pytorch.utils import Timer, print_ribbon
 
-from embedding_reader import EmbeddingReader
-
 from tqdm import tqdm
 
 # constants
@@ -31,7 +29,7 @@ def exists(val):
 
 # functions
 
-def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
+def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation"):
     model.eval()
 
     with torch.no_grad():
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
         total_samples = 0.
 
         for image_embeddings, text_data in tqdm(dataloader):
+            image_embeddings = image_embeddings.to(device)
+            text_data = text_data.to(device)
 
             batches = image_embeddings.shape[0]
 
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
 
     tracker.log({f'{phase} {loss_type}': avg_loss})
 
-def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
+def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
    diffusion_prior.eval()
 
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
 
    for test_image_embeddings, text_data in tqdm(dataloader):
+        test_image_embeddings = test_image_embeddings.to(device)
+        text_data = text_data.to(device)
 
        # we are text conditioned, we produce an embedding from the tokenized text
        if text_conditioned:
@@ -240,7 +242,7 @@ def train(
     # Training loop
 
     # diffusion prior network
-    prior_network = DiffusionPriorNetwork( 
+    prior_network = DiffusionPriorNetwork(
         dim = image_embed_dim,
         depth = dpn_depth,
         dim_head = dpn_dim_head,
@@ -249,16 +251,16 @@ def train(
         ff_dropout = dropout,
         normformer = dp_normformer
     )
-    
+
     # Load clip model if text-conditioning
     if dp_condition_on_text_encodings:
         clip_adapter = OpenAIClipAdapter(clip)
     else:
         clip_adapter = None
-    
+
     # diffusion prior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior( 
+    diffusion_prior = DiffusionPrior(
         net = prior_network,
         clip = clip_adapter,
         image_embed_dim = image_embed_dim,
@@ -296,28 +298,46 @@ def train(
     # Utilize wrapper to abstract away loader logic
     print_ribbon("Downloading Embeddings")
 
-    loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
-                       train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
+    reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
 
     if dp_condition_on_text_encodings:
-        loader_args = dict(**loader_args, meta_url=meta_url)
+        reader_args = dict(**reader_args, meta_url=meta_url)
+        img_reader = get_reader(**reader_args)
+        train_loader, eval_loader, test_loader = make_splits(
+            text_conditioned=dp_condition_on_text_encodings,
+            batch_size=batch_size,
+            num_data_points=num_data_points,
+            train_split=train_percent,
+            eval_split=val_percent,
+            image_reader=img_reader
+        )
     else:
-        loader_args = dict(**loader_args, txt_url=text_embed_url)
-
-    train_loader, eval_loader, test_loader = make_splits(**loader_args)
+        reader_args = dict(**reader_args, txt_url=text_embed_url)
+        img_reader, txt_reader = get_reader(**reader_args)
+        train_loader, eval_loader, test_loader = make_splits(
+            text_conditioned=dp_condition_on_text_encodings,
+            batch_size=batch_size,
+            num_data_points=num_data_points,
+            train_split=train_percent,
+            eval_split=val_percent,
+            image_reader=img_reader,
+            text_reader=txt_reader
+        )
 
     ### Training code ###
-    step = 1 
+    step = 1
     timer = Timer()
     epochs = num_epochs
 
     for _ in range(epochs):
         for image, text in tqdm(train_loader):
             diffusion_prior.train()
-
+
+            image = image.to(device)
+            text = text.to(device)
+
             input_args = dict(image_embed=image)
 
             if dp_condition_on_text_encodings:
                 input_args = dict(**input_args, text = text)
@@ -350,9 +370,9 @@ def train(
             # Use NUM_TEST_EMBEDDINGS samples from the test set each time
             # Get embeddings from the most recently saved model
             if(step % REPORT_METRICS_EVERY) == 0:
-                report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
+                report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
                 ### Evaluate model(validation run) ###
-                eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
+                eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
             step += 1
             trainer.update()
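
Reviewer note: below is a minimal sketch of how the reworked loader plumbing is
meant to be driven, using only the get_reader/make_splits call patterns visible
in this patch. The URLs, batch size, data-point count, and split fractions are
placeholders, not values taken from the repo.

    import torch
    from dalle2_pytorch.dataloaders import get_reader, make_splits

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Text-conditioned prior: a single reader over image embeddings + metadata.
    img_reader = get_reader(
        text_conditioned=True,
        img_url="path/to/img_emb/",   # placeholder path
        meta_url="path/to/metadata/"  # placeholder path
    )
    train_loader, eval_loader, test_loader = make_splits(
        text_conditioned=True,
        batch_size=256,               # placeholder
        num_data_points=1_000_000,    # placeholder
        train_split=0.9,              # placeholder
        eval_split=0.05,              # placeholder
        image_reader=img_reader
    )

    # Embedding-conditioned prior: paired image/text embedding readers.
    img_reader, txt_reader = get_reader(
        text_conditioned=False,
        img_url="path/to/img_emb/",   # placeholder path
        txt_url="path/to/txt_emb/"    # placeholder path
    )
    train_loader, eval_loader, test_loader = make_splits(
        text_conditioned=False,
        batch_size=256,
        num_data_points=1_000_000,
        train_split=0.9,
        eval_split=0.05,
        image_reader=img_reader,
        text_reader=txt_reader
    )

    # Unlike the old make_splits wrapper, the new reader args carry no device,
    # so batches come back on CPU; the patch moves the .to(device) transfer
    # into the training, eval, and cosine-similarity loops.
    for image, text in train_loader:
        image, text = image.to(device), text.to(device)
        break

The key behavioural change to watch in review: device placement moved out of
loader construction (the old loader_args carried device=device) and into the
loops, so every consumer of a loader must now transfer batches explicitly.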