Added single GPU training script for decoder (#108)

Added config files for training

Changed example image generation to be more efficient

Added configuration description to README

Removed unused import
This commit is contained in:
Aidan Dempster
2022-05-20 22:46:19 -04:00
committed by GitHub
parent 430961cb97
commit 022c94e443
7 changed files with 801 additions and 2 deletions

6
.gitignore vendored
View File

@@ -1,6 +1,12 @@
# default experiment tracker data # default experiment tracker data
.tracker-data/ .tracker-data/
# Configuration Files
configs/*
!configs/*.example
!configs/*_defaults.py
!configs/README.md
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]

109
configs/README.md Normal file
View File

@@ -0,0 +1,109 @@
## DALLE2 Training Configurations
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
**<ins>Unets</ins>:**
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
Any parameter from the `Unet` constructor can also be given here.
**<ins>Decoder</ins>:**
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
Any parameter from the `Decoder` constructor can also be given here.
**<ins>Data</ins>:**
Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
**<ins>Train</ins>:**
Settings for controlling the training hyperparameters.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `epochs` | No | `20` | The number of epochs in the training run. |
| `lr` | No | `1e-4` | The learning rate. |
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
**<ins>Tracker</ins>:**
Selects which tracker to use and configures it.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
**<ins>Load</ins>:**
Selects where to load a pretrained model from.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).

View File

@@ -0,0 +1,82 @@
"""
Defines the default values for the decoder config
"""
from enum import Enum
class ConfigField(Enum):
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
default_config = {
"unets": ConfigField.REQUIRED,
"decoder": {
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": True
},
"data": {
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": True,
"resample_train": False,
"preprocessing": {
"ToTensor": True
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
"device": "cuda:0",
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
"validation_samples": None, # Same as above but for validation.
"use_ema": True,
"ema_beta": 0.99,
"amp": False,
"save_all": False, # Whether to preserve all checkpoints
"save_latest": True, # Whether to always save the latest checkpoint
"save_best": True, # Whether to save the best checkpoint
"unet_training_mask": None # If None, use all unets
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": None,
"IS": None,
"KID": None,
"LPIPS": None
},
"tracker": {
"tracker_type": "console", # Decoder currently supports console and wandb
"data_path": "./models", # The path where files will be saved locally
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
"wandb_project": "",
"verbose": False # Whether to print console logging for non-console trackers
},
"load": {
"source": None, # Supports file and wandb
"run_path": "", # Used only if source is wandb
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
"resume": False # If using wandb, whether to resume the run
}
}

View File

@@ -0,0 +1,100 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"image_sizes": [64],
"image_size": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": true
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": true,
"resample_train": false,
"preprocessing": {
"RandomResizedCrop": {
"size": [128, 128],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6,
"device": "cuda:0",
"epoch_samples": null,
"validation_samples": null,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 10
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"tracker_type": "console",
"data_path": "./models",
"wandb_entity": "",
"wandb_project": "",
"verbose": false
},
"load": {
"source": null,
"run_path": "",
"file_path": "",
"resume": false
}
}

View File

@@ -11,7 +11,8 @@ def get_optimizer(
wd = 1e-2, wd = 1e-2,
betas = (0.9, 0.999), betas = (0.9, 0.999),
eps = 1e-8, eps = 1e-8,
filter_by_requires_grad = False filter_by_requires_grad = False,
**kwargs
): ):
if filter_by_requires_grad: if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params)) params = list(filter(lambda t: t.requires_grad, params))

View File

@@ -41,7 +41,8 @@ setup(
'x-clip>=0.4.4', 'x-clip>=0.4.4',
'youtokentome', 'youtokentome',
'webdataset>=0.2.5', 'webdataset>=0.2.5',
'fsspec>=2022.1.0' 'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'
], ],
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',

500
train_decoder.py Normal file
View File

@@ -0,0 +1,500 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from configs.decoder_defaults import default_config, ConfigField
import time
import json
import torchvision
from torchvision import transforms as T
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click
def create_dataloaders(
available_shards,
webdataset_base_url,
embeddings_url,
shard_width=6,
num_workers=4,
batch_size=32,
n_sample_images=6,
shuffle_train=True,
resample_train=False,
img_preproc = None,
index_width=4,
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
**kwargs
):
"""
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
"""
assert train_prop + test_prop + val_prop == 1
num_train = round(train_prop*len(available_shards))
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
embeddings_url=embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"] if with_text else [],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
handler=wds.handlers.warn_and_continue
)
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
"train_sampling": train_sampling_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = []
for i in range(0, len(unets_config)):
unets.append(Unet(
**unets_config[i]
))
decoder = Decoder(
unet=tuple(unets), # Must be tuple because of cast_tuple
**decoder_config
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
"""
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
if isinstance(dataloader, wds.WebLoader):
dataloader = dataloader.pipeline[0]
return dataloader.dataset.key_map
def get_example_data(dataloader, device, n=5):
"""
Samples the dataloader and returns a zipped list of examples
"""
images = []
embeddings = []
captions = []
dataset_keys = get_dataset_keys(dataloader)
has_caption = "txt" in dataset_keys
for data in dataloader:
if has_caption:
img, emb, txt = data
else:
img, emb = data
txt = [""] * emb.shape[0]
img = img.to(device=device, dtype=torch.float)
emb = emb.to(device=device, dtype=torch.float)
images.extend(list(img))
embeddings.extend(list(emb))
captions.extend(list(txt))
if len(images) >= n:
break
print("Generated {} examples".format(len(images)))
return list(zip(images[:n], embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, embeddings, txts = zip(*example_data)
embeddings_tensor = torch.stack(embeddings)
samples = trainer.sample(embeddings_tensor)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evalation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
if FID is not None:
fid = FrechetInceptionDistance(**FID)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if IS is not None:
inception = InceptionScore(**IS)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if KID is not None:
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if LPIPS is not None:
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
return metrics
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
trainer_state_dict = {}
trainer_state_dict["trainer"] = trainer.state_dict()
trainer_state_dict['epoch'] = epoch
trainer_state_dict['step'] = step
trainer_state_dict['validation_losses'] = validation_losses
for relative_path in relative_paths:
tracker.save_state_dict(trainer_state_dict, relative_path)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
"""
Loads the model with an appropriate method depending on the tracker
"""
print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config)
trainer.load_state_dict(state_dict["trainer"])
print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
def train(
dataloaders,
decoder,
tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
**kwargs
):
"""
Trains a decoder on a dataset.
"""
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
decoder,
**kwargs
)
# Set up starting model and parameters based on a recalled state dict
start_step = 0
start_epoch = 0
validation_losses = []
if load_config is not None and load_config["source"] is not None:
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
trainer.to(device=inference_device)
if unet_training_mask is None:
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
print(print_ribbon("Generating Example Data", repeat=40))
print("This can take a while to load the shard lists...")
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
trainer.train()
sample = 0
last_sample = 0
last_snapshot = 0
last_time = time.time()
losses = []
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if unet_training_mask[unet-1]: # Unet index is the unet number - 1
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
samples_per_sec = (sample - last_sample) / (time.time() - last_time)
last_time = time.time()
last_sample = sample
if i % 10 == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec
}
tracker.log(log_data, step=step, verbose=True)
losses = []
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if n_sample_images is not None and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
trainer.train()
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if epoch_samples is not None and sample >= epoch_samples:
break
trainer.eval()
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
with torch.no_grad():
sample = 0
average_loss = 0
start_time = time.time()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, len(decoder.unets)+1):
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % 10 == 0:
print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if validation_samples is not None and sample >= validation_samples:
break
average_loss /= i+1
log_data = {
"Validation loss": average_loss
}
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
trainer.eval()
if evaluate_config is not None:
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
tracker.log(evaluation, step=step, verbose=True)
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config["tracker"]
init_config = {}
init_config["config"] = config.config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config["load"]
if load_config["source"] == "wandb" and load_config["resume"]:
# Then we are resuming the run load_config["run_path"]
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"]
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
return tracker
def initialize_training(config):
# Create the save path
if "cuda" in config["train"]["device"]:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config["train"]["device"])
torch.cuda.set_device(device)
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.get_preprocessing(),
train_prop = config["data"]["splits"]["train"],
val_prop = config["data"]["splits"]["val"],
test_prop = config["data"]["splits"]["test"],
n_sample_images=config["train"]["n_sample_images"],
**config["data"]
)
decoder = create_decoder(device, config["decoder"], config["unets"])
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config["tracker"])
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config["load"],
evaluate_config=config["evaluate"],
**config["train"],
)
class TrainDecoderConfig:
def __init__(self, config):
self.config = self.map_config(config, default_config)
def map_config(self, config, defaults):
"""
Returns a dictionary containing all config options in the union of config and defaults.
If the config value is an array, apply the default value to each element.
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
"""
def _check_option(option, option_config, option_defaults):
for key, value in option_defaults.items():
if key not in option_config:
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
option_config[key] = value
for key, value in defaults.items():
if key not in config:
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' not supplied".format(key))
elif isinstance(value, dict):
config[key] = {}
elif isinstance(value, list):
config[key] = [{}]
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
# If it is a list, then we need to check each element
if isinstance(value, list):
assert isinstance(config[key], list)
for element in config[key]:
_check_option(key, element, value[0])
elif isinstance(value, dict):
_check_option(key, config[key], value)
# This object does not support checking
return config
def get_preprocessing(self):
"""
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
"""
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transformations = []
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
if isinstance(transformation_kwargs, dict):
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
else:
transformations.append(_get_transformation(transformation_name))
return T.Compose(transformations)
def __getitem__(self, key):
return self.config[key]
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(config)
initialize_training(config)
if __name__ == "__main__":
main()