diff --git a/README.md b/README.md
index f7c6d86..83a7f31 100644
--- a/README.md
+++ b/README.md
@@ -382,38 +382,6 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
 ## Training on Preprocessed CLIP Embeddings

-## Using the train_diffusion_prior.py script
-This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
-Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
-## Usage
-```bash
-$ pyhon train_diffusion_prior.py
-```
-The most significant parameters for the script are as follows:
-
---image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-
---text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-
---image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
-
---learning-rate, default=1.1e-4
-
---weight-decay, default=6.02e-2
-
---max-grad-norm, default=0.5
-
---batch-size, default=10 ** 4
-
---num-epochs, default=5
-
---clip, default=None # Signals the prior to use pre-computed embeddings
-
-## Sample wandb run log
-
-Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
-
-
 It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`

 Working example below
@@ -619,47 +587,6 @@ images = dalle2(
 Now you'll just have to worry about training the Prior and the Decoder!

-## Dataloaders
-In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
-
-### Decoder: Image Embedding Dataset
-When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
-
-Generating a dataset of this type:
-1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
-2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
-3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
-
-Usage:
-```python
-from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
-
-# Create a dataloader directly.
-dataloader = create_image_embedding_dataloader(
-    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
-    embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
-    num_workers=4,
-    batch_size=32,
-    shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
-    shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
-    shuffle_shards=True, # Shuffle the order the shards are read in
-    resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
-)
-for img, emb in dataloader:
-    print(img.shape) # torch.Size([32, 3, 256, 256])
-    print(emb.shape) # torch.Size([32, 512])
-    # Train decoder only as shown above
-
-# Or create a dataset without a loader so you can configure it manually
-dataset = ImageEmbeddingDataset(
-    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
-    embedding_folder_url="path/or/url/to/embeddings/folder",
-    shard_width=4,
-    shuffle_shards=True,
-    resample=False
-)
-```
-
 ## Experimental

 ### DALL-E2 with Latent Diffusion
@@ -859,6 +786,87 @@ mock_image_embed = torch.randn(4, 512).cuda()
 images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
 ```

+### Decoder Dataloaders
+
+In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
+
+#### Decoder: Image Embedding Dataset
+
+When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset supports two similar layouts. First, it can read a [webdataset](https://github.com/webdataset/webdataset) whose `.tar`s contain `.jpg` images alongside `.npy` files holding the associated image embeddings. Alternatively, you can specify a source for the embeddings outside of the webdataset. In this case, the embeddings path should contain `.npy` files with the same shard numbers as the webdataset, and the filename of each `.jpg` must encode the index of its embedding in the `.npy`. So, for example, `0001.tar` from the webdataset containing the image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) should be paralleled by an `img_emb_0001.npy` containing a NumPy array with the embedding at index 509.
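+
+To make this naming convention concrete, here is a small illustrative helper (hypothetical, not part of `dalle2_pytorch`) that splits such a filename into its shard and embedding index, assuming a shard width of 4:
+
+```python
+def shard_and_index(jpg_name, shard_width = 4):
+    stem = jpg_name.split(".")[0]     # "00010509"
+    shard = stem[:shard_width]        # "0001" -> embedding lives in img_emb_0001.npy
+    index = int(stem[shard_width:])   # 509   -> row 509 of that array
+    return shard, index
+
+assert shard_and_index("00010509.jpg") == ("0001", 509)
+```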
+
+Generating a dataset of this type:
+1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
+2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
+3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
+
+Usage:
+
+```python
+from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
+
+# Create a dataloader directly.
+dataloader = create_image_embedding_dataloader(
+    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expansion notation. This specifies reading all tars from 0000.tar to 9999.tar
+    embeddings_url="path/or/url/to/embeddings/folder", # Included if the .npy files are not in the webdataset. Leave out or set to None otherwise
+    num_workers=4,
+    batch_size=32,
+    shard_width=4, # If a file in webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
+    shuffle_num=200, # Shuffles the data with a buffer size of 200
+    shuffle_shards=True, # Shuffle the order the shards are read in
+    resample_shards=False, # Sample shards with replacement. If True, an epoch will be infinite unless stopped manually
+)
+for img, emb in dataloader:
+    print(img.shape) # torch.Size([32, 3, 256, 256])
+    print(emb.shape) # torch.Size([32, 512])
+    # Train decoder only as shown above
+
+# Or create a dataset without a loader so you can configure it manually
+dataset = ImageEmbeddingDataset(
+    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
+    embedding_folder_url="path/or/url/to/embeddings/folder",
+    shard_width=4,
+    shuffle_shards=True,
+    resample=False
+)
+```
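+
+For reference, the training step alluded to by the `# Train decoder only as shown above` comment could look something like the following minimal sketch. It plugs the dataloader into the `DecoderTrainer` shown earlier in this README, and assumes the decoder accepts precomputed embeddings through `image_embed` and that only the first unet is being trained; adapt it to your configuration.
+
+```python
+for img, emb in dataloader:
+    loss = decoder_trainer(
+        img,
+        image_embed = emb,    # precomputed CLIP image embeddings from the dataset
+        unet_number = 1,      # which unet to train
+        max_batch_size = 4    # process the batch in chunks of at most 4
+    )
+    decoder_trainer.update(1) # update unet 1 (and its exponential moving average)
+```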
+
+## Scripts
+
+### Using the `train_diffusion_prior.py` script
+
+This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example in the "Training on Preprocessed CLIP Embeddings" section above illustrates this process.
+Please note that the script internally passes `text_embed` and `image_embed` to the DiffusionPrior, unlike that example.
+
+### Usage
+
+```bash
+$ python train_diffusion_prior.py
+```
+
+The most significant parameters for the script are as follows:
+
+* `--image-embed-url`, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"
+* `--text-embed-url`, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"
+* `--image-embed-dim`, default = 768 - corresponds to the ViT-L/14 embedding size; change it to match what your chosen ViT produces
+* `--learning-rate`, default = 1.1e-4
+* `--weight-decay`, default = 6.02e-2
+* `--max-grad-norm`, default = 0.5
+* `--batch-size`, default = 10 ** 4
+* `--num-epochs`, default = 5
+* `--clip`, default = None - signals the prior to use pre-computed embeddings
+
+### Sample wandb run log
+
+Please find a sample wandb run log at: https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
+
 ## CLI (wip)

 ```bash