mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
quick patch for new prior loader (#123)
This commit is contained in:
@@ -7,15 +7,13 @@ import torch
|
|||||||
import clip
|
import clip
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from dalle2_pytorch.dataloaders import make_splits
|
from dalle2_pytorch.dataloaders import make_splits, get_reader
|
||||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
||||||
|
|
||||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
|
||||||
from embedding_reader import EmbeddingReader
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
@@ -31,7 +29,7 @@ def exists(val):
|
|||||||
|
|
||||||
# functions
|
# functions
|
||||||
|
|
||||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
|||||||
total_samples = 0.
|
total_samples = 0.
|
||||||
|
|
||||||
for image_embeddings, text_data in tqdm(dataloader):
|
for image_embeddings, text_data in tqdm(dataloader):
|
||||||
|
image_embeddings = image_embeddings.to(device)
|
||||||
|
text_data = text_data.to(device)
|
||||||
|
|
||||||
batches = image_embeddings.shape[0]
|
batches = image_embeddings.shape[0]
|
||||||
|
|
||||||
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
|||||||
|
|
||||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||||
|
|
||||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||||
diffusion_prior.eval()
|
diffusion_prior.eval()
|
||||||
|
|
||||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||||
|
|
||||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||||
|
test_image_embeddings = test_image_embeddings.to(device)
|
||||||
|
text_data = text_data.to(device)
|
||||||
|
|
||||||
# we are text conditioned, we produce an embedding from the tokenized text
|
# we are text conditioned, we produce an embedding from the tokenized text
|
||||||
if text_conditioned:
|
if text_conditioned:
|
||||||
@@ -296,15 +298,31 @@ def train(
|
|||||||
|
|
||||||
# Utilize wrapper to abstract away loader logic
|
# Utilize wrapper to abstract away loader logic
|
||||||
print_ribbon("Downloading Embeddings")
|
print_ribbon("Downloading Embeddings")
|
||||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
|
||||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
|
||||||
|
|
||||||
if dp_condition_on_text_encodings:
|
if dp_condition_on_text_encodings:
|
||||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
reader_args = dict(**reader_args, meta_url=meta_url)
|
||||||
|
img_reader = get_reader(**reader_args)
|
||||||
|
train_loader, eval_loader, test_loader = make_splits(
|
||||||
|
text_conditioned=dp_condition_on_text_encodings,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_data_points=num_data_points,
|
||||||
|
train_split=train_percent,
|
||||||
|
eval_split=val_percent,
|
||||||
|
image_reader=img_reader
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
reader_args = dict(**reader_args, txt_url=text_embed_url)
|
||||||
|
img_reader, txt_reader = get_reader(**reader_args)
|
||||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
train_loader, eval_loader, test_loader = make_splits(
|
||||||
|
text_conditioned=dp_condition_on_text_encodings,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_data_points=num_data_points,
|
||||||
|
train_split=train_percent,
|
||||||
|
eval_split=val_percent,
|
||||||
|
image_reader=img_reader,
|
||||||
|
text_reader=txt_reader
|
||||||
|
)
|
||||||
|
|
||||||
### Training code ###
|
### Training code ###
|
||||||
|
|
||||||
@@ -315,9 +333,11 @@ def train(
|
|||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
|
|
||||||
for image, text in tqdm(train_loader):
|
for image, text in tqdm(train_loader):
|
||||||
|
|
||||||
diffusion_prior.train()
|
diffusion_prior.train()
|
||||||
|
|
||||||
|
image = image.to(device)
|
||||||
|
text = text.to(device)
|
||||||
|
|
||||||
input_args = dict(image_embed=image)
|
input_args = dict(image_embed=image)
|
||||||
if dp_condition_on_text_encodings:
|
if dp_condition_on_text_encodings:
|
||||||
input_args = dict(**input_args, text = text)
|
input_args = dict(**input_args, text = text)
|
||||||
@@ -350,9 +370,9 @@ def train(
|
|||||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||||
# Get embeddings from the most recently saved model
|
# Get embeddings from the most recently saved model
|
||||||
if(step % REPORT_METRICS_EVERY) == 0:
|
if(step % REPORT_METRICS_EVERY) == 0:
|
||||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
|
||||||
### Evaluate model(validation run) ###
|
### Evaluate model(validation run) ###
|
||||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
trainer.update()
|
trainer.update()
|
||||||
|
|||||||
Reference in New Issue
Block a user