mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Added single GPU training script for decoder (#108)
Added config files for training Changed example image generation to be more efficient Added configuration description to README Removed unused import
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,6 +1,12 @@
|
||||
# default experiment tracker data
|
||||
.tracker-data/
|
||||
|
||||
# Configuration Files
|
||||
configs/*
|
||||
!configs/*.example
|
||||
!configs/*_defaults.py
|
||||
!configs/README.md
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
109
configs/README.md
Normal file
109
configs/README.md
Normal file
@@ -0,0 +1,109 @@
|
||||
## DALLE2 Training Configurations
|
||||
|
||||
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
|
||||
|
||||
### Decoder Trainer
|
||||
|
||||
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
|
||||
|
||||
**<ins>Unets</ins>:**
|
||||
|
||||
Each member of this array defines a single unet that will be added to the decoder.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `dim` | Yes | N/A | The starting channels of the unet. |
|
||||
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
|
||||
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
|
||||
|
||||
Any parameter from the `Unet` constructor can also be given here.
|
||||
|
||||
**<ins>Decoder</ins>:**
|
||||
|
||||
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
||||
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
||||
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
||||
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
|
||||
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
|
||||
| `learned_variance` | No | `True` | Whether to learn the variance. |
|
||||
|
||||
Any parameter from the `Decoder` constructor can also be given here.
|
||||
|
||||
**<ins>Data</ins>:**
|
||||
|
||||
Settings for creation of the dataloaders.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
|
||||
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
|
||||
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
|
||||
| `batch_size` | No | `64` | The batch size. |
|
||||
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
|
||||
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
|
||||
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
|
||||
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
|
||||
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
|
||||
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
|
||||
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
|
||||
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
|
||||
|
||||
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
|
||||
|
||||
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
|
||||
|
||||
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
|
||||
|
||||
**<ins>Train</ins>:**
|
||||
|
||||
Settings for controlling the training hyperparameters.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `epochs` | No | `20` | The number of epochs in the training run. |
|
||||
| `lr` | No | `1e-4` | The learning rate. |
|
||||
| `wd` | No | `0.01` | The weight decay. |
|
||||
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
|
||||
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
|
||||
| `device` | No | `cuda:0` | The device to train on. |
|
||||
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
|
||||
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
|
||||
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
|
||||
| `ema_beta` | No | `0.99` | The ema coefficient. |
|
||||
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
|
||||
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
|
||||
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
|
||||
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
|
||||
|
||||
**<ins>Evaluate</ins>:**
|
||||
|
||||
Defines which evaluation metrics will be used to test the model.
|
||||
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
|
||||
|
||||
**<ins>Tracker</ins>:**
|
||||
|
||||
Selects which tracker to use and configures it.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
|
||||
| `data_path` | No | `./models` | Where the tracker will store local data. |
|
||||
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
|
||||
|
||||
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
|
||||
|
||||
**<ins>Load</ins>:**
|
||||
|
||||
Selects where to load a pretrained model from.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `source` | No | `None` | Supports `file` or `wandb`. |
|
||||
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
|
||||
|
||||
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
|
||||
82
configs/decoder_defaults.py
Normal file
82
configs/decoder_defaults.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Defines the default values for the decoder config
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
class ConfigField(Enum):
|
||||
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
|
||||
|
||||
default_config = {
|
||||
"unets": ConfigField.REQUIRED,
|
||||
"decoder": {
|
||||
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
|
||||
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
|
||||
"channels": 3,
|
||||
"timesteps": 1000,
|
||||
"loss_type": "l2",
|
||||
"beta_schedule": "cosine",
|
||||
"learned_variance": True
|
||||
},
|
||||
"data": {
|
||||
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
|
||||
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
|
||||
"num_workers": 4,
|
||||
"batch_size": 64,
|
||||
"start_shard": 0,
|
||||
"end_shard": 9999999,
|
||||
"shard_width": 6,
|
||||
"index_width": 4,
|
||||
"splits": {
|
||||
"train": 0.75,
|
||||
"val": 0.15,
|
||||
"test": 0.1
|
||||
},
|
||||
"shuffle_train": True,
|
||||
"resample_train": False,
|
||||
"preprocessing": {
|
||||
"ToTensor": True
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 20,
|
||||
"lr": 1e-4,
|
||||
"wd": 0.01,
|
||||
"max_grad_norm": 0.5,
|
||||
"save_every_n_samples": 100000,
|
||||
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
|
||||
"device": "cuda:0",
|
||||
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
"validation_samples": None, # Same as above but for validation.
|
||||
"use_ema": True,
|
||||
"ema_beta": 0.99,
|
||||
"amp": False,
|
||||
"save_all": False, # Whether to preserve all checkpoints
|
||||
"save_latest": True, # Whether to always save the latest checkpoint
|
||||
"save_best": True, # Whether to save the best checkpoint
|
||||
"unet_training_mask": None # If None, use all unets
|
||||
},
|
||||
"evaluate": {
|
||||
"n_evalation_samples": 1000,
|
||||
"FID": None,
|
||||
"IS": None,
|
||||
"KID": None,
|
||||
"LPIPS": None
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "console", # Decoder currently supports console and wandb
|
||||
"data_path": "./models", # The path where files will be saved locally
|
||||
|
||||
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
|
||||
"wandb_project": "",
|
||||
|
||||
"verbose": False # Whether to print console logging for non-console trackers
|
||||
},
|
||||
"load": {
|
||||
"source": None, # Supports file and wandb
|
||||
|
||||
"run_path": "", # Used only if source is wandb
|
||||
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
||||
|
||||
"resume": False # If using wandb, whether to resume the run
|
||||
}
|
||||
}
|
||||
100
configs/train_decoder_config.json.example
Normal file
100
configs/train_decoder_config.json.example
Normal file
@@ -0,0 +1,100 @@
|
||||
{
|
||||
"unets": [
|
||||
{
|
||||
"dim": 128,
|
||||
"image_embed_dim": 768,
|
||||
"cond_dim": 64,
|
||||
"channels": 3,
|
||||
"dim_mults": [1, 2, 4, 8],
|
||||
"attn_dim_head": 32,
|
||||
"attn_heads": 16
|
||||
}
|
||||
],
|
||||
"decoder": {
|
||||
"image_sizes": [64],
|
||||
"image_size": [64],
|
||||
"channels": 3,
|
||||
"timesteps": 1000,
|
||||
"loss_type": "l2",
|
||||
"beta_schedule": "cosine",
|
||||
"learned_variance": true
|
||||
},
|
||||
"data": {
|
||||
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
|
||||
"embeddings_url": "s3://bucket/embeddings/path/",
|
||||
"num_workers": 4,
|
||||
"batch_size": 64,
|
||||
"start_shard": 0,
|
||||
"end_shard": 9999999,
|
||||
"shard_width": 6,
|
||||
"index_width": 4,
|
||||
"splits": {
|
||||
"train": 0.75,
|
||||
"val": 0.15,
|
||||
"test": 0.1
|
||||
},
|
||||
"shuffle_train": true,
|
||||
"resample_train": false,
|
||||
"preprocessing": {
|
||||
"RandomResizedCrop": {
|
||||
"size": [128, 128],
|
||||
"scale": [0.75, 1.0],
|
||||
"ratio": [1.0, 1.0]
|
||||
},
|
||||
"ToTensor": true
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 20,
|
||||
"lr": 1e-4,
|
||||
"wd": 0.01,
|
||||
"max_grad_norm": 0.5,
|
||||
"save_every_n_samples": 100000,
|
||||
"n_sample_images": 6,
|
||||
"device": "cuda:0",
|
||||
"epoch_samples": null,
|
||||
"validation_samples": null,
|
||||
"use_ema": true,
|
||||
"ema_beta": 0.99,
|
||||
"amp": false,
|
||||
"save_all": false,
|
||||
"save_latest": true,
|
||||
"save_best": true,
|
||||
"unet_training_mask": [true]
|
||||
},
|
||||
"evaluate": {
|
||||
"n_evalation_samples": 1000,
|
||||
"FID": {
|
||||
"feature": 64
|
||||
},
|
||||
"IS": {
|
||||
"feature": 64,
|
||||
"splits": 10
|
||||
},
|
||||
"KID": {
|
||||
"feature": 64,
|
||||
"subset_size": 10
|
||||
},
|
||||
"LPIPS": {
|
||||
"net_type": "vgg",
|
||||
"reduction": "mean"
|
||||
}
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "console",
|
||||
"data_path": "./models",
|
||||
|
||||
"wandb_entity": "",
|
||||
"wandb_project": "",
|
||||
|
||||
"verbose": false
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
|
||||
"run_path": "",
|
||||
"file_path": "",
|
||||
|
||||
"resume": false
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,8 @@ def get_optimizer(
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False
|
||||
filter_by_requires_grad = False,
|
||||
**kwargs
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
3
setup.py
3
setup.py
@@ -41,7 +41,8 @@ setup(
|
||||
'x-clip>=0.4.4',
|
||||
'youtokentome',
|
||||
'webdataset>=0.2.5',
|
||||
'fsspec>=2022.1.0'
|
||||
'fsspec>=2022.1.0',
|
||||
'torchmetrics[image]>=0.8.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
500
train_decoder.py
Normal file
500
train_decoder.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from dalle2_pytorch import Unet, Decoder
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from configs.decoder_defaults import default_config, ConfigField
|
||||
import time
|
||||
import json
|
||||
import torchvision
|
||||
from torchvision import transforms as T
|
||||
import torch
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
from torchmetrics.image.kid import KernelInceptionDistance
|
||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
import webdataset as wds
|
||||
import click
|
||||
|
||||
|
||||
def create_dataloaders(
|
||||
available_shards,
|
||||
webdataset_base_url,
|
||||
embeddings_url,
|
||||
shard_width=6,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
n_sample_images=6,
|
||||
shuffle_train=True,
|
||||
resample_train=False,
|
||||
img_preproc = None,
|
||||
index_width=4,
|
||||
train_prop = 0.75,
|
||||
val_prop = 0.15,
|
||||
test_prop = 0.10,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
|
||||
"""
|
||||
assert train_prop + test_prop + val_prop == 1
|
||||
num_train = round(train_prop*len(available_shards))
|
||||
num_test = round(test_prop*len(available_shards))
|
||||
num_val = len(available_shards) - num_train - num_test
|
||||
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
|
||||
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
|
||||
|
||||
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
|
||||
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
|
||||
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
||||
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
||||
|
||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
|
||||
tar_url=tar_urls,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size if not for_sampling else n_sample_images,
|
||||
embeddings_url=embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_num = None,
|
||||
extra_keys= ["txt"] if with_text else [],
|
||||
shuffle_shards = shuffle,
|
||||
resample_shards = resample,
|
||||
img_preproc=img_preproc,
|
||||
handler=wds.handlers.warn_and_continue
|
||||
)
|
||||
|
||||
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
||||
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
||||
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
|
||||
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
|
||||
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
||||
return {
|
||||
"train": train_dataloader,
|
||||
"train_sampling": train_sampling_dataloader,
|
||||
"val": val_dataloader,
|
||||
"test": test_dataloader,
|
||||
"test_sampling": test_sampling_dataloader
|
||||
}
|
||||
|
||||
|
||||
def create_decoder(device, decoder_config, unets_config):
|
||||
"""Creates a sample decoder"""
|
||||
unets = []
|
||||
for i in range(0, len(unets_config)):
|
||||
unets.append(Unet(
|
||||
**unets_config[i]
|
||||
))
|
||||
|
||||
decoder = Decoder(
|
||||
unet=tuple(unets), # Must be tuple because of cast_tuple
|
||||
**decoder_config
|
||||
)
|
||||
decoder.to(device=device)
|
||||
|
||||
return decoder
|
||||
|
||||
def get_dataset_keys(dataloader):
|
||||
"""
|
||||
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||
"""
|
||||
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
|
||||
if isinstance(dataloader, wds.WebLoader):
|
||||
dataloader = dataloader.pipeline[0]
|
||||
return dataloader.dataset.key_map
|
||||
|
||||
def get_example_data(dataloader, device, n=5):
|
||||
"""
|
||||
Samples the dataloader and returns a zipped list of examples
|
||||
"""
|
||||
images = []
|
||||
embeddings = []
|
||||
captions = []
|
||||
dataset_keys = get_dataset_keys(dataloader)
|
||||
has_caption = "txt" in dataset_keys
|
||||
for data in dataloader:
|
||||
if has_caption:
|
||||
img, emb, txt = data
|
||||
else:
|
||||
img, emb = data
|
||||
txt = [""] * emb.shape[0]
|
||||
img = img.to(device=device, dtype=torch.float)
|
||||
emb = emb.to(device=device, dtype=torch.float)
|
||||
images.extend(list(img))
|
||||
embeddings.extend(list(emb))
|
||||
captions.extend(list(txt))
|
||||
if len(images) >= n:
|
||||
break
|
||||
print("Generated {} examples".format(len(images)))
|
||||
return list(zip(images[:n], embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, text_prepend=""):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
"""
|
||||
real_images, embeddings, txts = zip(*example_data)
|
||||
embeddings_tensor = torch.stack(embeddings)
|
||||
samples = trainer.sample(embeddings_tensor)
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
metrics = {}
|
||||
# Prepare the data
|
||||
examples = get_example_data(dataloader, device, n_evalation_samples)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
if FID is not None:
|
||||
fid = FrechetInceptionDistance(**FID)
|
||||
fid.to(device=device)
|
||||
fid.update(int_real_images, real=True)
|
||||
fid.update(int_generated_images, real=False)
|
||||
metrics["FID"] = fid.compute().item()
|
||||
if IS is not None:
|
||||
inception = InceptionScore(**IS)
|
||||
inception.to(device=device)
|
||||
inception.update(int_real_images)
|
||||
is_mean, is_std = inception.compute()
|
||||
metrics["IS_mean"] = is_mean.item()
|
||||
metrics["IS_std"] = is_std.item()
|
||||
if KID is not None:
|
||||
kernel_inception = KernelInceptionDistance(**KID)
|
||||
kernel_inception.to(device=device)
|
||||
kernel_inception.update(int_real_images, real=True)
|
||||
kernel_inception.update(int_generated_images, real=False)
|
||||
kid_mean, kid_std = kernel_inception.compute()
|
||||
metrics["KID_mean"] = kid_mean.item()
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if LPIPS is not None:
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
metrics["LPIPS"] = lpips.compute().item()
|
||||
return metrics
|
||||
|
||||
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
if isinstance(relative_paths, str):
|
||||
relative_paths = [relative_paths]
|
||||
trainer_state_dict = {}
|
||||
trainer_state_dict["trainer"] = trainer.state_dict()
|
||||
trainer_state_dict['epoch'] = epoch
|
||||
trainer_state_dict['step'] = step
|
||||
trainer_state_dict['validation_losses'] = validation_losses
|
||||
for relative_path in relative_paths:
|
||||
tracker.save_state_dict(trainer_state_dict, relative_path)
|
||||
|
||||
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
state_dict = tracker.recall_state_dict(recall_source, **load_config)
|
||||
trainer.load_state_dict(state_dict["trainer"])
|
||||
print("Model loaded")
|
||||
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||
|
||||
def train(
|
||||
dataloaders,
|
||||
decoder,
|
||||
tracker,
|
||||
inference_device,
|
||||
load_config=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
epochs = 20,
|
||||
n_sample_images = 5,
|
||||
save_every_n_samples = 100000,
|
||||
save_all=False,
|
||||
save_latest=True,
|
||||
save_best=True,
|
||||
unet_training_mask=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Trains a decoder on a dataset.
|
||||
"""
|
||||
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
|
||||
decoder,
|
||||
**kwargs
|
||||
)
|
||||
# Set up starting model and parameters based on a recalled state dict
|
||||
start_step = 0
|
||||
start_epoch = 0
|
||||
validation_losses = []
|
||||
|
||||
if load_config is not None and load_config["source"] is not None:
|
||||
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
if unet_training_mask is None:
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
|
||||
print(print_ribbon("Generating Example Data", repeat=40))
|
||||
print("This can take a while to load the shard lists...")
|
||||
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
|
||||
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
|
||||
|
||||
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
||||
step = start_step
|
||||
for epoch in range(start_epoch, epochs):
|
||||
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||
trainer.train()
|
||||
|
||||
sample = 0
|
||||
last_sample = 0
|
||||
last_snapshot = 0
|
||||
last_time = time.time()
|
||||
losses = []
|
||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||
step += 1
|
||||
sample += img.shape[0]
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
# Check if this is a unet we are training
|
||||
if unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
losses.append(loss)
|
||||
|
||||
samples_per_sec = (sample - last_sample) / (time.time() - last_time)
|
||||
last_time = time.time()
|
||||
last_sample = sample
|
||||
|
||||
if i % 10 == 0:
|
||||
average_loss = sum(losses) / len(losses)
|
||||
log_data = {
|
||||
"Training loss": average_loss,
|
||||
"Epoch": epoch,
|
||||
"Sample": sample,
|
||||
"Step": i,
|
||||
"Samples per second": samples_per_sec
|
||||
}
|
||||
tracker.log(log_data, step=step, verbose=True)
|
||||
losses = []
|
||||
|
||||
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
last_snapshot = sample
|
||||
# We need to know where the model should be saved
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||
if n_sample_images is not None and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
trainer.train()
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
break
|
||||
|
||||
trainer.eval()
|
||||
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
|
||||
with torch.no_grad():
|
||||
sample = 0
|
||||
average_loss = 0
|
||||
start_time = time.time()
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
sample += img.shape[0]
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
for unet in range(1, len(decoder.unets)+1):
|
||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||
average_loss += loss
|
||||
|
||||
if i % 10 == 0:
|
||||
print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
|
||||
print(f"Loss: {average_loss / (i+1)}")
|
||||
print("")
|
||||
|
||||
if validation_samples is not None and sample >= validation_samples:
|
||||
break
|
||||
average_loss /= i+1
|
||||
log_data = {
|
||||
"Validation loss": average_loss
|
||||
}
|
||||
tracker.log(log_data, step=step, verbose=True)
|
||||
|
||||
# Compute evaluation metrics
|
||||
trainer.eval()
|
||||
if evaluate_config is not None:
|
||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||
tracker.log(evaluation, step=step, verbose=True)
|
||||
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
# Get the same paths
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||
save_paths.append("best.pth")
|
||||
validation_losses.append(average_loss)
|
||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||
|
||||
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
|
||||
"""
|
||||
Creates a tracker of the specified type and initializes special features based on the full config
|
||||
"""
|
||||
tracker_config = config["tracker"]
|
||||
init_config = {}
|
||||
init_config["config"] = config.config
|
||||
if tracker_type == "console":
|
||||
tracker = ConsoleTracker(**init_config)
|
||||
elif tracker_type == "wandb":
|
||||
# We need to initialize the resume state here
|
||||
load_config = config["load"]
|
||||
if load_config["source"] == "wandb" and load_config["resume"]:
|
||||
# Then we are resuming the run load_config["run_path"]
|
||||
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
|
||||
init_config["id"] = run_id
|
||||
init_config["resume"] = "must"
|
||||
init_config["entity"] = tracker_config["wandb_entity"]
|
||||
init_config["project"] = tracker_config["wandb_project"]
|
||||
tracker = WandbTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
else:
|
||||
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
|
||||
return tracker
|
||||
|
||||
def initialize_training(config):
|
||||
# Create the save path
|
||||
if "cuda" in config["train"]["device"]:
|
||||
assert torch.cuda.is_available(), "CUDA is not available"
|
||||
device = torch.device(config["train"]["device"])
|
||||
torch.cuda.set_device(device)
|
||||
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
|
||||
|
||||
dataloaders = create_dataloaders (
|
||||
available_shards=all_shards,
|
||||
img_preproc = config.get_preprocessing(),
|
||||
train_prop = config["data"]["splits"]["train"],
|
||||
val_prop = config["data"]["splits"]["val"],
|
||||
test_prop = config["data"]["splits"]["test"],
|
||||
n_sample_images=config["train"]["n_sample_images"],
|
||||
**config["data"]
|
||||
)
|
||||
|
||||
decoder = create_decoder(device, config["decoder"], config["unets"])
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
print(print_ribbon("Loaded Config", repeat=40))
|
||||
print(f"Number of parameters: {num_parameters}")
|
||||
|
||||
tracker = create_tracker(config, **config["tracker"])
|
||||
|
||||
train(dataloaders, decoder,
|
||||
tracker=tracker,
|
||||
inference_device=device,
|
||||
load_config=config["load"],
|
||||
evaluate_config=config["evaluate"],
|
||||
**config["train"],
|
||||
)
|
||||
|
||||
|
||||
class TrainDecoderConfig:
|
||||
def __init__(self, config):
|
||||
self.config = self.map_config(config, default_config)
|
||||
|
||||
def map_config(self, config, defaults):
|
||||
"""
|
||||
Returns a dictionary containing all config options in the union of config and defaults.
|
||||
If the config value is an array, apply the default value to each element.
|
||||
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
|
||||
"""
|
||||
def _check_option(option, option_config, option_defaults):
|
||||
for key, value in option_defaults.items():
|
||||
if key not in option_config:
|
||||
if value == ConfigField.REQUIRED:
|
||||
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
|
||||
option_config[key] = value
|
||||
|
||||
for key, value in defaults.items():
|
||||
if key not in config:
|
||||
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
|
||||
if value == ConfigField.REQUIRED:
|
||||
raise RuntimeError("Required config value '{}' not supplied".format(key))
|
||||
elif isinstance(value, dict):
|
||||
config[key] = {}
|
||||
elif isinstance(value, list):
|
||||
config[key] = [{}]
|
||||
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
|
||||
# If it is a list, then we need to check each element
|
||||
if isinstance(value, list):
|
||||
assert isinstance(config[key], list)
|
||||
for element in config[key]:
|
||||
_check_option(key, element, value[0])
|
||||
elif isinstance(value, dict):
|
||||
_check_option(key, config[key], value)
|
||||
# This object does not support checking
|
||||
return config
|
||||
|
||||
def get_preprocessing(self):
|
||||
"""
|
||||
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
|
||||
"""
|
||||
def _get_transformation(transformation_name, **kwargs):
|
||||
if transformation_name == "RandomResizedCrop":
|
||||
return T.RandomResizedCrop(**kwargs)
|
||||
elif transformation_name == "RandomHorizontalFlip":
|
||||
return T.RandomHorizontalFlip()
|
||||
elif transformation_name == "ToTensor":
|
||||
return T.ToTensor()
|
||||
|
||||
transformations = []
|
||||
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
|
||||
if isinstance(transformation_kwargs, dict):
|
||||
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
|
||||
else:
|
||||
transformations.append(_get_transformation(transformation_name))
|
||||
return T.Compose(transformations)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.config[key]
|
||||
|
||||
# Create a simple click command line interface to load the config and start the training
|
||||
@click.command()
|
||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||
def main(config_file):
|
||||
print("Recalling config from {}".format(config_file))
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
config = TrainDecoderConfig(config)
|
||||
initialize_training(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user