quick patch for new prior loader (#123)

This commit is contained in:
zion
2022-05-29 16:25:53 -07:00
committed by GitHub
parent a13d2d89c5
commit 387c5bf774

View File

@@ -7,15 +7,13 @@ import torch
import clip import clip
from torch import nn from torch import nn
from dalle2_pytorch.dataloaders import make_splits from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.utils import Timer, print_ribbon
from embedding_reader import EmbeddingReader
from tqdm import tqdm from tqdm import tqdm
# constants # constants
@@ -31,7 +29,7 @@ def exists(val):
# functions # functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"): def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
total_samples = 0. total_samples = 0.
for image_embeddings, text_data in tqdm(dataloader): for image_embeddings, text_data in tqdm(dataloader):
image_embeddings = image_embeddings.to(device)
text_data = text_data.to(device)
batches = image_embeddings.shape[0] batches = image_embeddings.shape[0]
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
tracker.log({f'{phase} {loss_type}': avg_loss}) tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned): def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
diffusion_prior.eval() diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6) cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader): for test_image_embeddings, text_data in tqdm(dataloader):
test_image_embeddings = test_image_embeddings.to(device)
text_data = text_data.to(device)
# we are text conditioned, we produce an embedding from the tokenized text # we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned: if text_conditioned:
@@ -240,7 +242,7 @@ def train(
# Training loop # Training loop
# diffusion prior network # diffusion prior network
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = image_embed_dim, dim = image_embed_dim,
depth = dpn_depth, depth = dpn_depth,
dim_head = dpn_dim_head, dim_head = dpn_dim_head,
@@ -249,16 +251,16 @@ def train(
ff_dropout = dropout, ff_dropout = dropout,
normformer = dp_normformer normformer = dp_normformer
) )
# Load clip model if text-conditioning # Load clip model if text-conditioning
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip) clip_adapter = OpenAIClipAdapter(clip)
else: else:
clip_adapter = None clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed # diffusion prior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior( diffusion_prior = DiffusionPrior(
net = prior_network, net = prior_network,
clip = clip_adapter, clip = clip_adapter,
image_embed_dim = image_embed_dim, image_embed_dim = image_embed_dim,
@@ -296,28 +298,46 @@ def train(
# Utilize wrapper to abstract away loader logic # Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings") print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points, reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url) reader_args = dict(**reader_args, meta_url=meta_url)
img_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader
)
else: else:
loader_args = dict(**loader_args, txt_url=text_embed_url) reader_args = dict(**reader_args, txt_url=text_embed_url)
img_reader, txt_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(**loader_args) train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader,
text_reader=txt_reader
)
### Training code ### ### Training code ###
step = 1 step = 1
timer = Timer() timer = Timer()
epochs = num_epochs epochs = num_epochs
for _ in range(epochs): for _ in range(epochs):
for image, text in tqdm(train_loader): for image, text in tqdm(train_loader):
diffusion_prior.train() diffusion_prior.train()
image = image.to(device)
text = text.to(device)
input_args = dict(image_embed=image) input_args = dict(image_embed=image)
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text) input_args = dict(**input_args, text = text)
@@ -350,9 +370,9 @@ def train(
# Use NUM_TEST_EMBEDDINGS samples from the test set each time # Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model # Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0: if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings) report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
### Evaluate model(validation run) ### ### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation") eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
step += 1 step += 1
trainer.update() trainer.update()