## Dataloaders

To make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

### Decoder: Image Embedding Dataset

When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset can read two similar types of datasets.

First, it can read a [webdataset](https://github.com/webdataset/webdataset) whose `.tar` shards contain `.jpg` and `.npy` files, holding the images and their associated image embeddings respectively.

Alternatively, you can specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset, and the filename of each `.jpg` must correspond to the index of its embedding in the `.npy`. For example, `0001.tar` from the webdataset containing the image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) should be paralleled by an `img_emb_0001.npy` containing a NumPy array with that image's embedding at index 509.

Generating a dataset of this type:

1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.

Usage:

```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses brace expansion notation. This reads all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Include if the .npy files are not in the webdataset. Leave out or set to None otherwise
    num_workers=4,
    batch_size=32,
    shard_width=4,                                         # If a file in webdataset shard 3 is named 0003039.jpg, the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Shuffles the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffles the order in which the shards are read
    resample_shards=False,                                 # Samples shards with replacement. If True, an epoch is infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb.shape)  # torch.Size([32, 512])
    # Train decoder only as shown above

# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
    embedding_folder_url="path/or/url/to/embeddings/folder",
    shard_width=4,
    shuffle_shards=True,
    resample=False
)
```
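
The correspondence between an image filename and its embedding, as described for external embedding sources, can be sketched in a few lines. This is only an illustration of the naming convention, not the library's own implementation: `split_key` and `embedding_path` are hypothetical helpers, and the `img_emb_<shard>.npy` pattern is assumed from the example in this section.

```python
from typing import Tuple


def split_key(filename: str, shard_width: int = 4) -> Tuple[str, int]:
    """Split an image filename into (shard id, index into the shard's embeddings).

    Hypothetical helper: the first `shard_width` digits name the shard and
    the remaining digits are the row index of the embedding.
    """
    stem = filename.rsplit("/", 1)[-1].split(".", 1)[0]
    return stem[:shard_width], int(stem[shard_width:])


def embedding_path(folder: str, shard: str) -> str:
    """Build the path of the .npy file expected to hold a shard's embeddings.

    Assumes the `img_emb_<shard>.npy` naming used in the example above.
    """
    return f"{folder.rstrip('/')}/img_emb_{shard}.npy"


shard, index = split_key("00010509.jpg", shard_width=4)
print(shard, index)                         # 0001 509
print(embedding_path("embeddings", shard))  # embeddings/img_emb_0001.npy
```

With a real dataset, the embedding for that image would then be row `index` of the array loaded from the returned path, for instance with `numpy.load`.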