diff --git a/dalle2_pytorch/dataloaders/README.md b/dalle2_pytorch/dataloaders/README.md
index 0b89fda..c61e730 100644
--- a/dalle2_pytorch/dataloaders/README.md
+++ b/dalle2_pytorch/dataloaders/README.md
@@ -4,7 +4,7 @@ In order to make loading data simple and efficient, we include some general data
### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.

-Generating a dataset of this type:
+Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
)
```

+### Diffusion Prior: Prior Embedding Dataset
+When training the prior, it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
+
+To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()`, which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects for your training run.
+
+If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute the data to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
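+
+If you are not text-conditioning, `get_reader()` instead returns a pair of readers (image and text), and `make_splits()` additionally expects the `text_reader`. The `Usage` example below covers the text-conditioned case; a minimal sketch of the embedding-only setup might look like the following (the text embedding path is a placeholder):
+```python
+from dalle2_pytorch.dataloaders import get_reader, make_splits
+
+# embedding-only training: no captions, so paired text embeddings are required
+img_reader, txt_reader = get_reader(
+    text_conditioned=False, img_url="data/img_emb/", txt_url="data/text_emb/"
+)
+
+# train/eval/test loaders for a single process (rank=0, world_size=1 by default)
+train_loader, eval_loader, test_loader = make_splits(
+    text_conditioned=False,
+    batch_size=2,
+    num_data_points=10000,
+    train_split=0.5,
+    eval_split=0.25,
+    image_reader=img_reader,
+    text_reader=txt_reader,
+)
+```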
+ +Usage: +```python +from dalle2_pytorch.dataloaders import get_reader, make_splits + +# grab embeddings from some specified location +IMG_URL = "data/img_emb/" +META_URL = "data/meta/" + +reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL) + +# some config for training +TRAIN_ARGS = { + "world_size": 3, + "text_conditioned": True, + "start": 0, + "num_data_points": 10000, + "batch_size": 2, + "train_split": 0.5, + "eval_split": 0.25, + "image_reader": reader, +} + +# specifying a rank will handle allocation internally +rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS) +rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS) +rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS) +``` diff --git a/dalle2_pytorch/dataloaders/__init__.py b/dalle2_pytorch/dataloaders/__init__.py index 1e1cdf5..72af534 100644 --- a/dalle2_pytorch/dataloaders/__init__.py +++ b/dalle2_pytorch/dataloaders/__init__.py @@ -1,2 +1,2 @@ from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader -from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits +from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset diff --git a/dalle2_pytorch/dataloaders/embedding_wrapper.py b/dalle2_pytorch/dataloaders/embedding_wrapper.py deleted file mode 100644 index 1162f3b..0000000 --- a/dalle2_pytorch/dataloaders/embedding_wrapper.py +++ /dev/null @@ -1,180 +0,0 @@ -from torch.utils.data import IterableDataset -from torch import from_numpy -from clip import tokenize -from embedding_reader import EmbeddingReader - - -class PriorEmbeddingLoader(IterableDataset): - def __init__( - self, - text_conditioned: bool, - batch_size: int, - start: int, - stop: int, - image_reader, - text_reader: EmbeddingReader = None, - device: str = "cpu", - ) -> None: - super(PriorEmbeddingLoader).__init__() - - self.text_conditioned = text_conditioned - - if not self.text_conditioned: - self.text_reader = text_reader - - self.image_reader = image_reader - self.batch_size = batch_size - self.start = start - self.stop = stop - self.device = device - - def __iter__(self): - self.n = 0 - loader_args = dict( - batch_size=self.batch_size, - start=self.start, - end=self.stop, - show_progress=False, - ) - if self.text_conditioned: - self.loader = self.image_reader(**loader_args) - else: - self.loader = zip( - self.image_reader(**loader_args), self.text_reader(**loader_args) - ) - return self - - def __next__(self): - try: - return self.get_sample() - except StopIteration: - raise StopIteration - - def get_sample(self): - """ - pre-proocess data from either reader into a common format - """ - self.n += 1 - - if self.text_conditioned: - image_embedding, caption = next(self.loader) - - image_embedding = from_numpy(image_embedding).to(self.device) - tokenized_caption = tokenize( - caption["caption"].to_list(), truncate=True - ).to(self.device) - - return image_embedding, tokenized_caption - - else: - (image_embedding, _), (text_embedding, _) = next(self.loader) - - image_embedding = from_numpy(image_embedding).to(self.device) - text_embedding = from_numpy(text_embedding).to(self.device) - - return image_embedding, text_embedding - - -def make_splits( - text_conditioned: bool, - batch_size: int, - num_data_points: int, - train_split: float, - eval_split: float, - device: str, - img_url: str, - meta_url: str = None, - txt_url: str = None, -): - - assert img_url is not None, "Must 
supply some image embeddings" - - if text_conditioned: - assert meta_url is not None, "Must supply metadata url if text-conditioning" - image_reader = EmbeddingReader( - embeddings_folder=img_url, - file_format="parquet_npy", - meta_columns=["caption"], - metadata_folder=meta_url, - ) - - # compute split points - if num_data_points > image_reader.count: - print("Specified point count is larger than the number of points available...defaulting to max length of reader.") - num_data_points = image_reader.count - - train_set_size = int(train_split * num_data_points) - eval_set_size = int(eval_split * num_data_points) - eval_stop = int(train_set_size + eval_set_size) - - train_loader = PriorEmbeddingLoader( - text_conditioned=text_conditioned, - image_reader=image_reader, - batch_size=batch_size, - start=0, - stop=train_set_size, - device=device, - ) - eval_loader = PriorEmbeddingLoader( - text_conditioned=text_conditioned, - image_reader=image_reader, - batch_size=batch_size, - start=train_set_size, - stop=eval_stop, - device=device, - ) - test_loader = PriorEmbeddingLoader( - text_conditioned=text_conditioned, - image_reader=image_reader, - batch_size=batch_size, - start=eval_stop, - stop=int(num_data_points), - device=device, - ) - - else: - assert ( - txt_url is not None - ), "Must supply text embedding url if not text-conditioning" - - image_reader = EmbeddingReader(img_url, file_format="npy") - text_reader = EmbeddingReader(txt_url, file_format="npy") - - # compute split points - if num_data_points > image_reader.count: - print("Specified point count is larger than the number of points available...defaulting to max length of reader.") - num_data_points = image_reader.count - - train_set_size = int(train_split * num_data_points) - eval_set_size = int(eval_split * num_data_points) - eval_stop = int(train_set_size + eval_set_size) - - train_loader = PriorEmbeddingLoader( - text_conditioned=text_conditioned, - image_reader=image_reader, - text_reader=text_reader, - batch_size=batch_size, - start=0, - stop=train_set_size, - device=device, - ) - eval_loader = PriorEmbeddingLoader( - text_conditioned=text_conditioned, - image_reader=image_reader, - text_reader=text_reader, - batch_size=batch_size, - start=train_set_size, - stop=eval_stop, - device=device, - ) - test_loader = PriorEmbeddingLoader( - text_conditioned=text_conditioned, - image_reader=image_reader, - text_reader=text_reader, - batch_size=batch_size, - start=eval_stop, - stop=int(num_data_points), - device=device, - ) - - return train_loader, eval_loader, test_loader diff --git a/dalle2_pytorch/dataloaders/prior_loader.py b/dalle2_pytorch/dataloaders/prior_loader.py new file mode 100644 index 0000000..cbbfc57 --- /dev/null +++ b/dalle2_pytorch/dataloaders/prior_loader.py @@ -0,0 +1,273 @@ +from math import ceil +from clip import tokenize +from embedding_reader import EmbeddingReader +from torch import from_numpy +from torch.utils.data import IterableDataset, DataLoader + + +class PriorEmbeddingDataset(IterableDataset): + """ + PriorEmbeddingDataset is a wrapper of EmbeddingReader. + + It enables one to simplify the logic necessary to yield samples from + the different EmbeddingReader configurations available. 
+    """
+
+    def __init__(
+        self,
+        text_conditioned: bool,
+        batch_size: int,
+        start: int,
+        stop: int,
+        image_reader,
+        text_reader: EmbeddingReader = None,
+    ) -> None:
+        super(PriorEmbeddingDataset, self).__init__()
+
+        self.text_conditioned = text_conditioned
+
+        if not self.text_conditioned:
+            self.text_reader = text_reader
+
+        self.image_reader = image_reader
+        self.start = start
+        self.stop = stop
+        self.batch_size = batch_size
+
+    def __len__(self):
+        return self.stop - self.start
+
+    def __iter__(self):
+        # D.R.Y loader args
+        loader_args = dict(
+            batch_size=self.batch_size,
+            start=self.start,
+            end=self.stop,
+            show_progress=False,
+        )
+
+        # if the data requested is text conditioned, only load images
+        if self.text_conditioned:
+            self.loader = self.image_reader(**loader_args)
+        # otherwise, include text embeddings and bypass metadata
+        else:
+            self.loader = zip(
+                self.image_reader(**loader_args), self.text_reader(**loader_args)
+            )
+
+        # return the data loader in its formatted state
+        return self
+
+    def __next__(self):
+        try:
+            return self.get_sample()
+        except StopIteration:
+            raise StopIteration
+
+    def __str__(self):
+        return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
+
+    def get_sample(self):
+        """
+        pre-process data from either reader into a common format
+        """
+        if self.text_conditioned:
+            image_embedding, caption = next(self.loader)
+
+            image_embedding = from_numpy(image_embedding)
+            tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
+
+            return image_embedding, tokenized_caption
+
+        else:
+            (image_embedding, _), (text_embedding, _) = next(self.loader)
+
+            image_embedding = from_numpy(image_embedding)
+            text_embedding = from_numpy(text_embedding)
+
+            return image_embedding, text_embedding
+
+
+# helper functions
+
+
+def distribute_to_rank(start, stop, rank, world_size):
+    """
+    Distribute data to each rank given the world size.
+
+    Return:
+        - New start and stop points for this rank.
+    """
+    num_samples = int(stop - start)
+
+    per_rank = int(ceil(num_samples / float(world_size)))
+
+    assert (
+        per_rank > 0
+    ), f"Number of samples per rank must be larger than 0 (found: {per_rank})"
+
+    rank_start = start + rank * per_rank
+
+    rank_stop = min(rank_start + per_rank, stop)
+
+    new_length = rank_stop - rank_start
+
+    assert (
+        new_length > 0
+    ), "Calculated start and stop points result in a length of zero for this rank."
+
+    return rank_start, rank_stop
+
+
+def get_reader(
+    text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
+):
+    """
+    Create an EmbeddingReader object from the specified URLs.
+
+    get_reader() will always expect a url to image embeddings.
+
+    If text-conditioned, it will also expect a meta_url for the captions.
+    Otherwise, it will need txt_url for the matching text embeddings.
+
+    Returns an image_reader object if text-conditioned.
+    Otherwise it returns both an image_reader and a text_reader.
+    """
+
+    assert img_url is not None, "Must supply an image url"
+
+    if text_conditioned:
+        assert meta_url is not None, "Must supply meta url if text-conditioned"
+
+        image_reader = EmbeddingReader(
+            embeddings_folder=img_url,
+            file_format="parquet_npy",
+            # will assume the caption column exists and is the only one requested
+            meta_columns=["caption"],
+            metadata_folder=meta_url,
+        )
+
+        return image_reader
+
+    # otherwise we will require text embeddings as well and return two readers
+    assert (
+        txt_url is not None
+    ), "Must supply text embedding url if not text-conditioning"
+
+    image_reader = EmbeddingReader(img_url, file_format="npy")
+    text_reader = EmbeddingReader(txt_url, file_format="npy")
+
+    return image_reader, text_reader
+
+
+def make_splits(
+    text_conditioned: bool,
+    batch_size: int,
+    num_data_points: int,
+    train_split: float,
+    eval_split: float,
+    image_reader: EmbeddingReader,
+    text_reader: EmbeddingReader = None,
+    start=0,
+    rank=0,
+    world_size=1,
+):
+    """
+    Split an embedding reader object as needed.
+
+    NOTE: make_splits() will infer the test set size from your train and eval.
+
+    Input:
+        - text_conditioned: whether to prepare text-conditioned training data
+        - batch_size: the batch size for a single GPU
+        - num_data_points: the total number of data points you wish to train on
+        - train_split: the percentage of data you wish to train on
+        - eval_split: the percentage of data you wish to validate on
+        - image_reader: the image_reader you wish to split
+        - text_reader: the text_reader you want to split (if !text_conditioned)
+        - start: the starting point within your dataset
+        - rank: the rank of your worker
+        - world_size: the total world size of your distributed training run
+
+    Returns:
+        - PyTorch DataLoaders that yield tuples of (img, txt) data.
+    """
+
+    assert start < image_reader.count, "start position cannot exceed reader count."
+
+    # verify that the num_data_points does not exceed the max points
+    if num_data_points > (image_reader.count - start):
+        print(
+            "Specified count is larger than what's available...defaulting to reader's count."
+        )
+        num_data_points = image_reader.count
+
+    # compute split points
+    train_set_size = int(train_split * num_data_points)
+    eval_set_size = int(eval_split * num_data_points)
+    eval_start = train_set_size
+    eval_stop = int(eval_start + eval_set_size)
+
+    assert (
+        train_split + eval_split
+    ) < 1.0, "Specified train and eval split is too large to infer a test split."
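+
+    # as a concrete example (matching the README usage): num_data_points=10000,
+    # train_split=0.5, eval_split=0.25 gives train -> [0, 5000), eval -> [5000, 7500),
+    # and the remaining [7500, 10000) is inferred as the test split below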
+
+    # distribute to rank
+    rank_train_start, rank_train_stop = distribute_to_rank(
+        start, train_set_size, rank, world_size
+    )
+    rank_eval_start, rank_eval_stop = distribute_to_rank(
+        train_set_size, eval_stop, rank, world_size
+    )
+    rank_test_start, rank_test_stop = distribute_to_rank(
+        eval_stop, num_data_points, rank, world_size
+    )
+
+    # wrap up splits into a dict
+    train_split_args = dict(
+        start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
+    )
+    eval_split_args = dict(
+        start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
+    )
+    test_split_args = dict(
+        start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
+    )
+
+    if text_conditioned:
+        # add the text-conditioned args to a unified dict
+        reader_args = dict(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+        )
+
+        train_split_args = dict(**reader_args, **train_split_args)
+        eval_split_args = dict(**reader_args, **eval_split_args)
+        test_split_args = dict(**reader_args, **test_split_args)
+
+        train = PriorEmbeddingDataset(**train_split_args)
+        val = PriorEmbeddingDataset(**eval_split_args)
+        test = PriorEmbeddingDataset(**test_split_args)
+
+    else:
+        # add the non-conditioned args to a unified dict
+        reader_args = dict(
+            text_conditioned=text_conditioned,
+            image_reader=image_reader,
+            text_reader=text_reader,
+        )
+
+        train_split_args = dict(**reader_args, **train_split_args)
+        eval_split_args = dict(**reader_args, **eval_split_args)
+        test_split_args = dict(**reader_args, **test_split_args)
+
+        train = PriorEmbeddingDataset(**train_split_args)
+        val = PriorEmbeddingDataset(**eval_split_args)
+        test = PriorEmbeddingDataset(**test_split_args)
+
+    # the true batch size is specified in the PriorEmbeddingDataset
+    train_loader = DataLoader(train, batch_size=None)
+    eval_loader = DataLoader(val, batch_size=None)
+    test_loader = DataLoader(test, batch_size=None)
+
+    return train_loader, eval_loader, test_loader