diff --git a/dalle2_pytorch/dataloaders/__init__.py b/dalle2_pytorch/dataloaders/__init__.py
index 0fb1d7e..1e1cdf5 100644
--- a/dalle2_pytorch/dataloaders/__init__.py
+++ b/dalle2_pytorch/dataloaders/__init__.py
@@ -1 +1,2 @@
-from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
\ No newline at end of file
+from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
+from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
diff --git a/dalle2_pytorch/dataloaders/embedding_wrapper.py b/dalle2_pytorch/dataloaders/embedding_wrapper.py
new file mode 100644
index 0000000..1162f3b
--- /dev/null
+++ b/dalle2_pytorch/dataloaders/embedding_wrapper.py
@@ -0,0 +1,180 @@
+from torch.utils.data import IterableDataset
+from torch import from_numpy
+from clip import tokenize
+from embedding_reader import EmbeddingReader
+
+
+class PriorEmbeddingLoader(IterableDataset):
+    def __init__(
+        self,
+        text_conditioned: bool,
+        batch_size: int,
+        start: int,
+        stop: int,
+        image_reader,
+        text_reader: EmbeddingReader = None,
+        device: str = "cpu",
+    ) -> None:
+        super(PriorEmbeddingLoader, self).__init__()
+
+        self.text_conditioned = text_conditioned
+
+        if not self.text_conditioned:
+            self.text_reader = text_reader
+
+        self.image_reader = image_reader
+        self.batch_size = batch_size
+        self.start = start
+        self.stop = stop
+        self.device = device
+
+    def __iter__(self):
+        self.n = 0
+        loader_args = dict(
+            batch_size=self.batch_size,
+            start=self.start,
+            end=self.stop,
+            show_progress=False,
+        )
+        if self.text_conditioned:
+            self.loader = self.image_reader(**loader_args)
+        else:
+            self.loader = zip(
+                self.image_reader(**loader_args), self.text_reader(**loader_args)
+            )
+        return self
+
+    def __next__(self):
+        try:
+            return self.get_sample()
+        except StopIteration:
+            raise StopIteration
+
+    def get_sample(self):
+        """
+        pre-process data from either reader into a common format
+        """
+        self.n += 1
+
+        if self.text_conditioned:
+            image_embedding, caption = next(self.loader)
+
+            image_embedding = from_numpy(image_embedding).to(self.device)
+            tokenized_caption = tokenize(
+                caption["caption"].to_list(), truncate=True
+            ).to(self.device)
+
+            return image_embedding, tokenized_caption
+
+        else:
+            (image_embedding, _), (text_embedding, _) = next(self.loader)
+
+            image_embedding = from_numpy(image_embedding).to(self.device)
+            text_embedding = from_numpy(text_embedding).to(self.device)
+
+            return image_embedding, text_embedding
+
+
+def make_splits(
+    text_conditioned: bool,
+    batch_size: int,
+    num_data_points: int,
+    train_split: float,
+    eval_split: float,
+    device: str,
+    img_url: str,
+    meta_url: str = None,
+    txt_url: str = None,
+):
+
+    assert img_url is not None, "Must supply some image embeddings"
+
+    if text_conditioned:
+        assert meta_url is not None, "Must supply metadata url if text-conditioning"
+        image_reader = EmbeddingReader(
+            embeddings_folder=img_url,
+            file_format="parquet_npy",
+            meta_columns=["caption"],
+            metadata_folder=meta_url,
+        )
+
+        # compute split points
+        if num_data_points > image_reader.count:
+            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
+            num_data_points = image_reader.count
+
+        train_set_size = int(train_split * num_data_points)
+        eval_set_size = int(eval_split * num_data_points)
+        eval_stop = int(train_set_size + eval_set_size)
+
+        train_loader = PriorEmbeddingLoader(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            batch_size=batch_size,
+            start=0,
+            stop=train_set_size,
+            device=device,
+        )
+        eval_loader = PriorEmbeddingLoader(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            batch_size=batch_size,
+            start=train_set_size,
+            stop=eval_stop,
+            device=device,
+        )
+        test_loader = PriorEmbeddingLoader(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            batch_size=batch_size,
+            start=eval_stop,
+            stop=int(num_data_points),
+            device=device,
+        )
+
+    else:
+        assert (
+            txt_url is not None
+        ), "Must supply text embedding url if not text-conditioning"
+
+        image_reader = EmbeddingReader(img_url, file_format="npy")
+        text_reader = EmbeddingReader(txt_url, file_format="npy")
+
+        # compute split points
+        if num_data_points > image_reader.count:
+            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
+            num_data_points = image_reader.count
+
+        train_set_size = int(train_split * num_data_points)
+        eval_set_size = int(eval_split * num_data_points)
+        eval_stop = int(train_set_size + eval_set_size)
+
+        train_loader = PriorEmbeddingLoader(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            text_reader=text_reader,
+            batch_size=batch_size,
+            start=0,
+            stop=train_set_size,
+            device=device,
+        )
+        eval_loader = PriorEmbeddingLoader(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            text_reader=text_reader,
+            batch_size=batch_size,
+            start=train_set_size,
+            stop=eval_stop,
+            device=device,
+        )
+        test_loader = PriorEmbeddingLoader(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            text_reader=text_reader,
+            batch_size=batch_size,
+            start=eval_stop,
+            stop=int(num_data_points),
+            device=device,
+        )
+
+    return train_loader, eval_loader, test_loader
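Reviewer note, not part of the patch: a minimal usage sketch of the new wrapper, assuming the paths, batch size, and split fractions below are hypothetical placeholders. make_splits returns three PriorEmbeddingLoader iterables covering disjoint index ranges of the same EmbeddingReader; with text_conditioned=True each batch is an (image_embedding, tokenized_caption) pair, otherwise an (image_embedding, text_embedding) pair.

from dalle2_pytorch.dataloaders import make_splits

# hypothetical arguments, purely for illustration
train_loader, eval_loader, test_loader = make_splits(
    text_conditioned=True,         # read captions from parquet metadata and CLIP-tokenize them
    batch_size=64,
    num_data_points=10_000,        # clamped (with a printed warning) to image_reader.count if too large
    train_split=0.8,
    eval_split=0.1,                # whatever remains (~0.1 here) becomes the test split
    device="cpu",
    img_url="path/to/img_emb/",    # parquet_npy embedding folder
    meta_url="path/to/metadata/",  # required whenever text_conditioned=True
)

for image_embed, tokenized_caption in train_loader:
    # image_embed: float tensor of shape (batch_size, clip_dim), already moved to `device`
    # tokenized_caption: CLIP token ids, ready for diffusion_prior.clip.embed_text(...)
    break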
diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index b8d3b2e..c06654d 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -5,9 +5,10 @@ import time
 
 import numpy as np
 import torch
+import clip
 from torch import nn
-
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.dataloaders import make_splits
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
 from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
 from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
 
@@ -17,8 +18,7 @@ from tqdm import tqdm
 
 # constants
 
-NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
-REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
+REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
 
 tracker = WandbTracker()
 
@@ -36,81 +36,106 @@ class Timer:
     def elapsed(self):
         return time.time() - self.last_time
 
+
 # functions
 
-def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
+def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
     model.eval()
+
     with torch.no_grad():
         total_loss = 0.
         total_samples = 0.
 
-        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
-                                        text_reader(batch_size=batch_size, start=start, end=end)):
+        for image_embeddings, text_data in tqdm(dataloader):
 
-            emb_images_tensor = torch.tensor(emb_images[0]).to(device)
-            emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+            batches = image_embeddings.shape[0]
 
-            batches = emb_images_tensor.shape[0]
+            input_args = dict(image_embed=image_embeddings)
+            if text_conditioned:
+                input_args = dict(**input_args, text = text_data)
+            else:
+                input_args = dict(**input_args, text_embed=text_data)
 
-            loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
+            loss = model(**input_args)
 
-            total_loss += loss.item() * batches
+            total_loss += loss * batches
             total_samples += batches
 
         avg_loss = (total_loss / total_samples)
+
         tracker.log({f'{phase} {loss_type}': avg_loss})
 
-def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
+def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
     diffusion_prior.eval()
     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
-    tstart = train_set_size
-    tend = train_set_size+NUM_TEST_EMBEDDINGS
+    for test_image_embeddings, text_data in tqdm(dataloader):
+
+        # we are text conditioned, we produce an embedding from the tokenized text
+        if text_conditioned:
+            text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
+                text_data)
+            text_cond = dict(text_embed=text_embedding,
+                             text_encodings=text_encodings, mask=text_mask)
+        else:
+            text_embedding = text_data
+            text_cond = dict(text_embed=text_embedding)
+
+        # make a copy of the text embeddings for shuffling
+        text_embed_shuffled = text_embedding.clone()
+
+        # roll the text to simulate "unrelated" captions
+        rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
+        text_embed_shuffled = text_embed_shuffled[rolled_idx]
+        text_embed_shuffled = text_embed_shuffled / \
+            text_embed_shuffled.norm(dim=1, keepdim=True)
+
+        if text_conditioned:
+            text_encodings_shuffled = text_encodings[rolled_idx]
+            text_mask_shuffled = text_mask[rolled_idx]
+        else:
+            text_encodings_shuffled = None
+            text_mask_shuffled = None
+
+        text_cond_shuffled = dict(text_embed=text_embed_shuffled,
+                                  text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
 
-    for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
-                          image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
-        # make a copy of the text embeddings for shuffling
-        text_embed = torch.tensor(embt[0]).to(device)
-        text_embed_shuffled = text_embed.clone()
-        # roll the text embeddings to simulate "unrelated" captions
-        rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
-        text_embed_shuffled = text_embed_shuffled[rolled_idx]
-        text_embed_shuffled = text_embed_shuffled / \
-            text_embed_shuffled.norm(dim=1, keepdim=True)
-        test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
         # prepare the text embedding
-        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
-        test_text_cond = dict(text_embed=text_embed)
+        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
+
         # prepare image embeddings
-        test_image_embeddings = torch.tensor(embi[0]).to(device)
-        test_image_embeddings = test_image_embeddings / \
-            test_image_embeddings.norm(dim=1, keepdim=True)
+        test_image_embeddings = test_image_embeddings / \
+            test_image_embeddings.norm(dim=1, keepdim=True)
+
         # predict on the unshuffled text embeddings
-        predicted_image_embeddings = diffusion_prior.p_sample_loop(
-            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
-        predicted_image_embeddings = predicted_image_embeddings / \
-            predicted_image_embeddings.norm(dim=1, keepdim=True)
+        predicted_image_embeddings = diffusion_prior.p_sample_loop(
+            test_image_embeddings.shape, text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / \
+            predicted_image_embeddings.norm(dim=1, keepdim=True)
+
         # predict on the shuffled embeddings
-        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
-            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
-        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
-            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
+        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
+            test_image_embeddings.shape, text_cond_shuffled)
+        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
+            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
+
         # calculate similarities
-        original_similarity = cos(
+        original_similarity = cos(
            text_embed, test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(
+        predicted_similarity = cos(
            text_embed, predicted_image_embeddings).cpu().numpy()
-        unrelated_similarity = cos(
+        unrelated_similarity = cos(
            text_embed, predicted_unrelated_embeddings).cpu().numpy()
-        predicted_img_similarity = cos(
+        predicted_img_similarity = cos(
            test_image_embeddings, predicted_image_embeddings).cpu().numpy()
-        tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
+        tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
                     "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
                     "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
                     "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
                     "Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
+
 
 @click.command()
 @click.option("--wandb-entity", default="laion")
 @click.option("--wandb-project", default="diffusion-prior")
@@ -118,29 +143,32 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
 @click.option("--wandb-arch", default="DiffusionPrior")
 @click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
 @click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
+@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
 @click.option("--learning-rate", default=1.1e-4)
 @click.option("--weight-decay", default=6.02e-2)
 @click.option("--dropout", default=5e-2)
 @click.option("--max-grad-norm", default=0.5)
-@click.option("--batch-size", default=10**4)
+@click.option("--num-data-points", default=250e6)
+@click.option("--batch-size", default=320)
 @click.option("--num-epochs", default=5)
 @click.option("--image-embed-dim", default=768)
-@click.option("--train-percent", default=0.7)
-@click.option("--val-percent", default=0.2)
-@click.option("--test-percent", default=0.1)
-@click.option("--dpn-depth", default=6)
+@click.option("--train-percent", default=0.9)
+@click.option("--val-percent", default=1e-7)
+@click.option("--test-percent", default=0.0999999)
+@click.option("--dpn-depth", default=12)
 @click.option("--dpn-dim-head", default=64)
-@click.option("--dpn-heads", default=8)
-@click.option("--dp-condition-on-text-encodings", default=False)
-@click.option("--dp-timesteps", default=100)
-@click.option("--dp-normformer", default=False)
+@click.option("--dpn-heads", default=12)
+@click.option("--dp-condition-on-text-encodings", default=True)
+@click.option("--dp-timesteps", default=1000)
+@click.option("--dp-normformer", default=True)
 @click.option("--dp-cond-drop-prob", default=0.1)
 @click.option("--dp-loss-type", default="l2")
-@click.option("--clip", default=None)
+@click.option("--clip", default="ViT-L/14")
 @click.option("--amp", default=False)
-@click.option("--save-interval", default=30)
+@click.option("--save-interval", default=120)
 @click.option("--save-path", default="./diffusion_prior_checkpoints")
 @click.option("--pretrained-model-path", default=None)
+@click.option("--gpu-device", default=0)
 def train(
     wandb_entity,
     wandb_project,
@@ -148,10 +176,12 @@ def train(
     wandb_arch,
     image_embed_url,
     text_embed_url,
+    meta_url,
     learning_rate,
     weight_decay,
     dropout,
     max_grad_norm,
+    num_data_points,
     batch_size,
     num_epochs,
     image_embed_dim,
@@ -170,7 +200,8 @@ def train(
     amp,
     save_interval,
     save_path,
-    pretrained_model_path
+    pretrained_model_path,
+    gpu_device
 ):
     config = {
         "learning_rate": learning_rate,
@@ -197,7 +228,7 @@ def train(
 
     # Check if DPRIOR_PATH exists(saved model path)
 
-    DPRIOR_PATH = args.pretrained_model_path
+    DPRIOR_PATH = pretrained_model_path
     RESUME = exists(DPRIOR_PATH)
 
     if not RESUME:
@@ -211,7 +242,7 @@ def train(
 
     has_cuda = torch.cuda.is_available()
     if has_cuda:
-        device = torch.device("cuda:0")
+        device = torch.device(f"cuda:{gpu_device}")
         torch.cuda.set_device(device)
 
     # Training loop
@@ -227,11 +258,17 @@ def train(
         normformer = dp_normformer
     )
 
+    # Load clip model if text-conditioning
+    if dp_condition_on_text_encodings:
+        clip_adapter = OpenAIClipAdapter(clip)
+    else:
+        clip_adapter = None
+
    # diffusion prior with text embeddings and image embeddings pre-computed
    diffusion_prior = DiffusionPrior(
        net = prior_network,
-        clip = clip,
+        clip = clip_adapter,
        image_embed_dim = image_embed_dim,
        timesteps = dp_timesteps,
        cond_drop_prob = dp_cond_drop_prob,
@@ -265,33 +302,37 @@ def train(
 
     Path(save_path).mkdir(exist_ok = True, parents = True)
 
-    # Get image and text embeddings from the servers
+    # Utilize wrapper to abstract away loader logic
+    print_ribbon("Downloading Embeddings")
+    loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
+                       train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
 
-    print_ribbon("Downloading embeddings - image and text")
-    image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
-    text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
-    num_data_points = text_reader.count
+    if dp_condition_on_text_encodings:
+        loader_args = dict(**loader_args, meta_url=meta_url)
+    else:
+        loader_args = dict(**loader_args, txt_url=text_embed_url)
+
+    train_loader, eval_loader, test_loader = make_splits(**loader_args)
 
     ### Training code ###
+    step = 1
     timer = Timer()
     epochs = num_epochs
 
-    train_set_size = int(train_percent*num_data_points)
-    val_set_size = int(val_percent*num_data_points)
-    eval_start = train_set_size
-
     for _ in range(epochs):
-        for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
-                                       text_reader(batch_size=batch_size, start=0, end=train_set_size)):
-
-            trainer.train()
+        for image, text in tqdm(train_loader):
-            emb_images_tensor = torch.tensor(emb_images[0]).to(device)
-            emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+            diffusion_prior.train()
+
+            input_args = dict(image_embed=image)
+            if dp_condition_on_text_encodings:
+                input_args = dict(**input_args, text = text)
+            else:
+                input_args = dict(**input_args, text_embed=text)
 
-            loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
+            loss = trainer(**input_args)
 
             # Samples per second
@@ -310,37 +351,23 @@ def train(
                                   image_embed_dim)
 
             # Log to wandb
-            tracker.log({"Training loss": loss.item(),
+            tracker.log({"Training loss": loss,
                         "Steps": step,
                         "Samples per second": samples_per_sec})
 
             # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
             # Use NUM_TEST_EMBEDDINGS samples from the test set each time
             # Get embeddings from the most recently saved model
             if(step % REPORT_METRICS_EVERY) == 0:
-                report_cosine_sims(diffusion_prior,
-                    image_reader,
-                    text_reader,
-                    train_set_size,
-                    NUM_TEST_EMBEDDINGS,
-                    device)
+                report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
                 ### Evaluate model(validation run) ###
-                eval_model(diffusion_prior,
-                    device,
-                    image_reader,
-                    text_reader,
-                    eval_start,
-                    eval_start+NUM_TEST_EMBEDDINGS,
-                    NUM_TEST_EMBEDDINGS,
-                    dp_loss_type,
-                    phase="Validation")
+                eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
 
+            step += 1
             trainer.update()
 
     ### Test run ###
-    test_set_size = int(test_percent*train_set_size)
-    start = train_set_size+val_set_size
-    end = num_data_points
-    eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
+    eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
+
 
 if __name__ == "__main__":
     train()
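Reviewer note, not part of the patch: the loaders always yield (image_embedding, text_data) pairs, and both eval_model and the training loop above dispatch on the text-conditioning flag to decide whether text_data reaches the prior as raw CLIP tokens (text=) or as a precomputed embedding (text_embed=). A condensed sketch of that pattern follows; average_prior_loss is a hypothetical helper name, and the loader and prior are assumed to be constructed exactly as in the script.

import torch

def average_prior_loss(diffusion_prior, loader, text_conditioned):
    # One optimization-free pass over a PriorEmbeddingLoader, mirroring eval_model():
    # tokenized captions go in via `text=`, precomputed embeddings via `text_embed=`.
    total_loss, total_samples = 0.0, 0
    diffusion_prior.eval()
    with torch.no_grad():
        for image_embed, text_data in loader:
            kwargs = dict(image_embed=image_embed)
            if text_conditioned:
                kwargs["text"] = text_data        # CLIP token ids; the prior embeds them with its clip adapter
            else:
                kwargs["text_embed"] = text_data  # precomputed text embeddings
            loss = diffusion_prior(**kwargs)
            total_loss += loss.item() * image_embed.shape[0]
            total_samples += image_embed.shape[0]
    return total_loss / total_samples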