mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-01-10 03:54:19 +01:00
Added single GPU training script for decoder (#108)
Added config files for training. Changed example image generation to be more efficient. Added configuration description to README. Removed unused import.
This commit is contained in:
82
configs/decoder_defaults.py
Normal file
82
configs/decoder_defaults.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
Defines the default values for the decoder config.
"""
|
||||
|
||||
from enum import Enum
|
||||
class ConfigField(Enum):
    """Marker values for ``default_config`` entries that have no default.

    ``default_config`` stores ``ConfigField.REQUIRED`` in place of a concrete
    value for such keys.  NOTE(review): presumably the code that merges a user
    config with the defaults rejects any key still set to ``REQUIRED``; that
    code is not in this file, so confirm against the loader.
    """

    # This had more options. It's a bit unnecessary now, but I can't think of
    # a better way to do it.
    REQUIRED = 0
|
||||
|
||||
# Default values for every section of the decoder config.  Entries set to
# ConfigField.REQUIRED have no default value in this table.
default_config = {
    "unets": ConfigField.REQUIRED,
    "decoder": {
        "image_sizes": ConfigField.REQUIRED,  # The side lengths of the upsampled image at the end of each unet
        "image_size": ConfigField.REQUIRED,  # Usually the same as image_sizes[-1] I think
        "channels": 3,
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": "cosine",
        "learned_variance": True,
    },
    "data": {
        "webdataset_base_url": ConfigField.REQUIRED,  # Path to a webdataset with jpg images
        "embeddings_url": ConfigField.REQUIRED,  # Path to .npy files with embeddings
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,
        "end_shard": 9999999,
        "shard_width": 6,
        "index_width": 4,
        # NOTE(review): the three fractions add up to 1.0; presumably the
        # consumer expects that -- confirm before changing one of them.
        "splits": {
            "train": 0.75,
            "val": 0.15,
            "test": 0.1,
        },
        "shuffle_train": True,
        "resample_train": False,
        "preprocessing": {
            "ToTensor": True,
        },
    },
    "train": {
        "epochs": 20,
        "lr": 1e-4,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100000,
        "n_sample_images": 6,  # The number of example images to produce when sampling the train and test dataset
        "device": "cuda:0",
        # Limits the number of samples per epoch. None means no limit.
        # Required if resample_train is true as otherwise the number of
        # samples per epoch is infinite.
        "epoch_samples": None,
        "validation_samples": None,  # Same as above but for validation.
        "use_ema": True,
        "ema_beta": 0.99,
        "amp": False,
        "save_all": False,  # Whether to preserve all checkpoints
        "save_latest": True,  # Whether to always save the latest checkpoint
        "save_best": True,  # Whether to save the best checkpoint
        "unet_training_mask": None,  # If None, use all unets
    },
    "evaluate": {
        # NOTE(review): the key is spelled "n_evalation_samples" (sic).  It is
        # kept as-is because any code that reads this config would look it up
        # under this exact spelling; rename only together with its readers.
        "n_evalation_samples": 1000,
        "FID": None,
        "IS": None,
        "KID": None,
        "LPIPS": None,
    },
    "tracker": {
        "tracker_type": "console",  # Decoder currently supports console and wandb
        "data_path": "./models",  # The path where files will be saved locally

        "wandb_entity": "",  # Only needs to be set if tracker_type is wandb
        "wandb_project": "",

        "verbose": False,  # Whether to print console logging for non-console trackers
    },
    "load": {
        "source": None,  # Supports file and wandb

        "run_path": "",  # Used only if source is wandb
        # The local filepath if source is file. If source is wandb, the
        # relative path to the model file in wandb.
        "file_path": "",

        "resume": False,  # If using wandb, whether to resume the run
    },
}
|
||||
Reference in New Issue
Block a user