DALLE2 Training Configurations
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the example decoder configuration.
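
As a rough orientation, the sections described below can be pictured as one JSON document. The sketch below builds such a document in Python and round-trips it through a file. The top-level key names (`decoder`, `data`, `train`, `evaluate`, `tracker`, `load`), the file name `decoder_config.json`, and the concrete values are assumptions for illustration only; the example decoder configuration is the authority on the exact layout.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical skeleton: the section key names are assumed; the option names
# inside each section are taken from the tables in this document.
config = {
    "decoder": {
        # Unet configs are nested under the decoder as a list.
        "unets": [{"dim": 128, "image_embed_dim": 768, "dim_mults": [1, 2, 4, 8]}],
        "image_sizes": [64],
        "image_size": 64,
    },
    "data": {"webdataset_base_url": "protocol://path/to/shard/{}.tar"},
    "train": {"epochs": 20, "lr": 1e-4},
    "evaluate": {"n_evaluation_samples": 1000},
    "tracker": {"tracker_type": "console"},
    "load": {"source": None},
}


def write_and_read(cfg: dict, directory: str) -> dict:
    """Write the config as JSON and read it back, as a trainer script might."""
    path = Path(directory) / "decoder_config.json"  # assumed file name
    path.write_text(json.dumps(cfg, indent=4))
    return json.loads(path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    loaded = write_and_read(config, tmp)
```

Note that Python's `None` becomes `null` in the JSON file and comes back as `None` on load.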
Unet:
This is the configuration for a single unet. Unet configs are nested under the decoder config as a list of unets.
| Option | Required | Default | Description |
|---|---|---|---|
| `dim` | Yes | N/A | The starting channels of the unet. |
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
Any parameter from the Unet constructor can also be given here.
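
To make the relationship between `dim` and `dim_mults` concrete, here is a minimal sketch. It assumes that each growth factor multiplies `dim` to give the channel count of one resolution stage, which is the usual reading of "growth factors" but should be checked against the Unet constructor. The helper `stage_channels` and the values `128` and `768` are hypothetical.

```python
from typing import List, Sequence


def stage_channels(dim: int, dim_mults: Sequence[int] = (1, 2, 4, 8)) -> List[int]:
    """Channel count per stage, assuming channels = dim * growth factor."""
    return [dim * mult for mult in dim_mults]


# A single unet config that relies on the default dim_mults.
unet_config = {"dim": 128, "image_embed_dim": 768}

channels = stage_channels(
    unet_config["dim"],
    unet_config.get("dim_mults", (1, 2, 4, 8)),
)
print(channels)  # [128, 256, 512, 1024]
```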
Decoder:
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
|---|---|---|---|
| `unets` | Yes | N/A | A list of unets, using the configuration above. |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
Any parameter from the Decoder constructor can also be given here.
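
The constraints in the table above can be checked before a run starts. The following is a minimal sketch of such a check, not the library's own validation: `check_decoder_config` is a hypothetical helper, and the unet values are illustrative.

```python
ALLOWED_LOSS_TYPES = {"l1", "l2", "huber"}
ALLOWED_BETA_SCHEDULES = {"cosine", "linear", "quadratic", "jsd", "sigmoid"}


def check_decoder_config(decoder: dict) -> dict:
    """Fill in the documented defaults and raise ValueError on a violation."""
    for key in ("unets", "image_sizes", "image_size"):
        if key not in decoder:
            raise ValueError(f"missing required decoder option: {key!r}")
    if len(decoder["image_sizes"]) != len(decoder["unets"]):
        raise ValueError("image_sizes must have one entry per unet")

    # Defaults copied from the table; user values override them.
    resolved = {
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": "cosine",
        "learned_variance": True,
        **decoder,
    }
    if resolved["loss_type"] not in ALLOWED_LOSS_TYPES:
        raise ValueError(f"unknown loss_type: {resolved['loss_type']!r}")
    if resolved["beta_schedule"] not in ALLOWED_BETA_SCHEDULES:
        raise ValueError(f"unknown beta_schedule: {resolved['beta_schedule']!r}")
    return resolved


decoder_config = {
    "unets": [
        {"dim": 128, "image_embed_dim": 768},
        {"dim": 64, "image_embed_dim": 768},
    ],
    "image_sizes": [64, 256],  # one resolution per unet
    "image_size": 256,         # documented as unused, but still required
    "loss_type": "huber",
}
resolved = check_decoder_config(decoder_config)
```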
Data:
Settings for creation of the dataloaders.
| Option | Required | Default | Description |
|---|---|---|---|
| `webdataset_base_url` | Yes | N/A | The URL of a shard in the webdataset, with the shard number replaced by `{}`. [1] |
| `embeddings_url` | No | N/A | The URL of the folder containing embeddings shards. Not required if embeddings are in the webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will load. |
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will load. |
| `shard_width` | No | `6` | Defines the width of one webdataset shard number. [2] |
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard. [3] |
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets, making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
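
The interplay between `webdataset_base_url`, `shard_width`, and `index_width` is easiest to see with the file names from the footnotes. The sketch below assumes the `{}` placeholder is filled with the zero-padded shard number via `str.format`; that mechanism, and the helper names `shard_url` and `split_sample_name`, are assumptions for illustration rather than the library's implementation.

```python
from typing import Tuple


def shard_url(base_url: str, shard: int, shard_width: int = 6) -> str:
    """Fill the `{}` placeholder with the shard number padded to shard_width."""
    return base_url.format(str(shard).zfill(shard_width))


def split_sample_name(
    filename: str, shard_width: int = 6, index_width: int = 4
) -> Tuple[str, str]:
    """Split a sample file name into (shard number, index inside the shard)."""
    stem = filename.rsplit(".", 1)[0]  # drop the extension
    if len(stem) != shard_width + index_width:
        raise ValueError(
            f"{filename!r} does not match shard_width={shard_width} "
            f"and index_width={index_width}"
        )
    return stem[:shard_width], stem[shard_width:]


# The footnote examples use 5-digit shard numbers, so shard_width is 5 here.
url = shard_url("protocol://path/to/shard/{}.tar", 104, shard_width=5)
shard, index = split_sample_name("001045945.jpg", shard_width=5, index_width=4)

print(url)           # protocol://path/to/shard/00104.tar
print(shard, index)  # 00104 5945
```

With the default `shard_width` of 6, the same shard would be written `000104`, so the widths in the config need to match how the dataset was actually written.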
Train:
Settings for controlling the training hyperparameters.
| Option | Required | Default | Description |
|---|---|---|---|
| `epochs` | No | `20` | The number of epochs in the training run. |
| `lr` | No | `1e-4` | The learning rate. |
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm` | No | `0.5` | The maximum gradient norm used for clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. `None` means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. `None` means the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The EMA coefficient. |
| `save_all` | No | `False` | If `True`, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If `True`, overwrites `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If `True`, overwrites `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If an entry is false, the corresponding unet is frozen. A value of `None` trains all unets. |
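
Several of these options interact with the data settings: resampling needs `epoch_samples`, resampling and shuffling are mutually exclusive, and `unet_training_mask` must match the number of unets. A minimal sketch of these checks follows. The helpers `check_train_config` and `trainable_unets` are hypothetical and only encode the rules stated in the tables; they are not the trainer's own code.

```python
from typing import List


def check_train_config(train: dict, data: dict, num_unets: int) -> None:
    """Raise ValueError if the documented cross-option rules are violated."""
    shuffle = data.get("shuffle_train", True)     # documented default
    resample = data.get("resample_train", False)  # documented default
    if shuffle and resample:
        raise ValueError("resample_train cannot be enabled with shuffle_train")
    if resample and train.get("epoch_samples") is None:
        raise ValueError("epoch_samples must be set when resampling")
    mask = train.get("unet_training_mask")
    if mask is not None and len(mask) != num_unets:
        raise ValueError("unet_training_mask needs one entry per unet")


def trainable_unets(train: dict, num_unets: int) -> List[int]:
    """Indices of unets that are trained; a mask of None trains all of them."""
    mask = train.get("unet_training_mask")
    if mask is None:
        return list(range(num_unets))
    return [i for i, flag in enumerate(mask) if flag]


train_config = {
    "epochs": 20,
    "epoch_samples": 500000,               # required because resampling is on
    "unet_training_mask": [True, False],   # freeze the second unet
}
data_config = {"shuffle_train": False, "resample_train": True}

check_train_config(train_config, data_config, num_unets=2)
active = trainable_unets(train_config, num_unets=2)
print(active)  # [0]
```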
Evaluate:
Defines which evaluation metrics will be used to test the model. A metric is enabled by setting its configuration. The configuration keys for each metric are defined by the corresponding torchmetrics constructor.
| Option | Required | Default | Description |
|---|---|---|---|
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the Frechet Inception Distance metric. |
| `IS` | No | `None` | Setting to an object enables the Inception Score metric. |
| `KID` | No | `None` | Setting to an object enables the Kernel Inception Distance metric. |
| `LPIPS` | No | `None` | Setting to an object enables the Learned Perceptual Image Patch Similarity metric. |
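
As an illustration of "setting to an object", the sketch below enables two metrics and leaves the others at `None`. The keys `feature` and `net_type` are, to the best of my knowledge, arguments of the torchmetrics `FrechetInceptionDistance` and `LearnedPerceptualImagePatchSimilarity` constructors, but they are used here as assumptions and should be verified against the installed torchmetrics version. The helper `enabled_metrics` is hypothetical.

```python
METRIC_KEYS = ("FID", "IS", "KID", "LPIPS")

evaluate_config = {
    "n_evaluation_samples": 1000,
    "FID": {"feature": 64},        # assumed torchmetrics constructor argument
    "LPIPS": {"net_type": "vgg"},  # assumed torchmetrics constructor argument
    # "IS" and "KID" are omitted, so they stay at the default None (disabled).
}


def enabled_metrics(evaluate: dict) -> dict:
    """Map each enabled metric name to the kwargs given in the config."""
    return {
        name: evaluate[name]
        for name in METRIC_KEYS
        if evaluate.get(name) is not None
    }


metrics = enabled_metrics(evaluate_config)
print(sorted(metrics))  # ['FID', 'LPIPS']
```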
Tracker:
Selects which tracker to use and configures it.
| Option | Required | Default | Description |
|---|---|---|---|
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each tracker.
Load:
Selects where to load a pretrained model from.
| Option | Required | Default | Description |
|---|---|---|---|
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker supports resuming the run, resume it. |
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the tracker file.
1. If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example, `pipe:s3cmd get s3://bucket/path/{}.tar -`.
2. This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard width is 5.
3. Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `00104` and index is `5945`). The `index_width` in this case is 4.