Dataloaders

In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

Decoder: Image Embedding Dataset

When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset can read two similar formats. First, it can read a webdataset whose .tar shards contain .jpg files for the images and .npy files for the associated image embeddings. Alternatively, you can specify a source for the embeddings outside of the webdataset. In that case, the embeddings path should contain .npy files with the same shard numbers as the webdataset, and the filename of each .jpg should correspond to the index of its embedding within the shard's .npy. For example, shard 0001.tar containing the image 00010509.jpg (the first 4 digits are the shard number and the last 4 are the index) should be paralleled by an img_emb_0001.npy containing a NumPy array with the embedding at index 509.
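
As a concrete illustration of this correspondence, here is a small sketch (the filenames are hypothetical) of how an image filename maps to a row of the matching embedding shard:

import numpy as np

# 00010509.jpg -> shard "0001", index 509 (with a shard width of 4)
image_filename = "00010509.jpg"
shard_width = 4

shard = image_filename[:shard_width]                     # "0001"
index = int(image_filename[shard_width:].split(".")[0])  # 509

# The matching embedding sits at the same index of that shard's .npy file
embeddings = np.load(f"img_emb_{shard}.npy")             # shape (images_in_shard, embedding_dim)
image_embedding = embeddings[index]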

Generating a dataset of this type:

  1. Use img2dataset to generate a webdataset (a sketch of this step follows the list).
  2. Use clip-retrieval to convert the images to embeddings.
  3. Use embedding-dataset-reordering to reorder the embeddings into the expected format.
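
As a rough sketch of step 1 (the input file and column names are placeholders; see the img2dataset documentation for the full set of options):

from img2dataset import download

# Download the images listed in a parquet file and pack them into webdataset shards.
# "links.parquet", "url", and "caption" are placeholder names for your own data.
download(
    url_list="links.parquet",
    input_format="parquet",
    url_col="url",
    caption_col="caption",
    output_format="webdataset",
    output_folder="my-webdataset",
    image_size=256,
)

# Steps 2 and 3 are performed with the clip-retrieval and embedding-dataset-reordering
# command line tools respectively; see their documentation for the available options.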

Usage:

from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
    shard_width=4,                                         # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffle the order the shards are read in
    resample_shards=False,                                 # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb.shape)  # torch.Size([32, 512])
    # Train decoder only as shown above

# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
    embedding_folder_url="path/or/url/to/embeddings/folder",
    shard_width=4,
    shuffle_shards=True,
    resample=False
)
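
One possible way to batch the manually constructed dataset is to hand it to a standard PyTorch DataLoader (a minimal sketch, assuming the dataset yields (image, embedding) tensor pairs as shown above):

from torch.utils.data import DataLoader

# The dataset is iterable, so shuffling is configured on the dataset itself
# (shuffle_shards above) rather than on the loader.
dataloader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)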

Diffusion Prior: Prior Embedding Dataset

When training the prior, it is much more efficient to work with pre-computed embeddings. The PriorEmbeddingDataset class lets you use the same script (with minimal modification) for both embedding-only and text-conditioned prior training, saving you from writing much of the boilerplate yourself.

To use the PriorEmbeddingDataset, all you need to do is make a single call to get_reader(), which will create the EmbeddingReader object(s) for you. Afterwards, you can use make_splits() to cleanly create DataLoader objects for your training run.

If you are training in a distributed manner, make_splits() also accepts rank and world_size arguments to distribute the data correctly across processes. These default to rank=0 and world_size=1, so single-process training can safely ignore them.
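
If you are launching one process per GPU with torch.distributed, the rank and world size can be read from the initialized process group rather than hardcoded (a sketch; train_args is assumed to hold the remaining make_splits keyword arguments):

import torch.distributed as dist

rank = dist.get_rank()
world_size = dist.get_world_size()
train_loader, eval_loader, test_loader = make_splits(rank=rank, world_size=world_size, **train_args)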

Usage:

from dalle2_pytorch.dataloaders import get_reader, make_splits

# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"

reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)

# some config for training
TRAIN_ARGS = {
    "world_size": 3,
    "text_conditioned": True,
    "start": 0,
    "num_data_points": 10000,
    "batch_size": 2,
    "train_split": 0.5,
    "eval_split": 0.25,
    "image_reader": reader,
}

# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
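
Each returned split is an ordinary DataLoader (see make_splits() above), so training simply iterates over it; exactly what each batch contains depends on whether the reader is text conditioned. A minimal sketch, with a hypothetical train_step():

for batch in rank0_train:
    # with text_conditioned=True, each batch pairs image embeddings with the
    # associated text data used for conditioning
    train_step(batch)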