mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
quick patch for new prior loader (#123)
This commit is contained in:
@@ -7,15 +7,13 @@ import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch.dataloaders import make_splits, get_reader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# constants
|
||||
@@ -31,7 +29,7 @@ def exists(val):
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
||||
total_samples = 0.
|
||||
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
image_embeddings = image_embeddings.to(device)
|
||||
text_data = text_data.to(device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||
diffusion_prior.eval()
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
test_image_embeddings = test_image_embeddings.to(device)
|
||||
text_data = text_data.to(device)
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
@@ -296,15 +298,31 @@ def train(
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
reader_args = dict(**reader_args, meta_url=meta_url)
|
||||
img_reader = get_reader(**reader_args)
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=dp_condition_on_text_encodings,
|
||||
batch_size=batch_size,
|
||||
num_data_points=num_data_points,
|
||||
train_split=train_percent,
|
||||
eval_split=val_percent,
|
||||
image_reader=img_reader
|
||||
)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
reader_args = dict(**reader_args, txt_url=text_embed_url)
|
||||
img_reader, txt_reader = get_reader(**reader_args)
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=dp_condition_on_text_encodings,
|
||||
batch_size=batch_size,
|
||||
num_data_points=num_data_points,
|
||||
train_split=train_percent,
|
||||
eval_split=val_percent,
|
||||
image_reader=img_reader,
|
||||
text_reader=txt_reader
|
||||
)
|
||||
|
||||
### Training code ###
|
||||
|
||||
@@ -315,9 +333,11 @@ def train(
|
||||
for _ in range(epochs):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
image = image.to(device)
|
||||
text = text.to(device)
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
@@ -350,9 +370,9 @@ def train(
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
Reference in New Issue
Block a user