diff --git a/.gitignore b/.gitignore
index 55301b1..41f11cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,12 @@
# default experiment tracker data
.tracker-data/
+# Configuration Files
+configs/*
+!configs/*.example
+!configs/*_defaults.py
+!configs/README.md
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/configs/README.md b/configs/README.md
new file mode 100644
index 0000000..80a942e
--- /dev/null
+++ b/configs/README.md
@@ -0,0 +1,109 @@
+## DALLE2 Training Configurations
+
+For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
+
+### Decoder Trainer
+
+The decoder trainer has seven main configuration sections. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
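+
+Assuming the example has been copied to `configs/train_decoder_config.json` (the `--config_file` option defaults to `./train_decoder_config.json`), training can be started with `python train_decoder.py --config_file configs/train_decoder_config.json`.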
+
+**Unets:**
+
+Each member of this array defines a single unet that will be added to the decoder.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `dim` | Yes | N/A | The starting channels of the unet. |
+| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
+| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
+
+Any parameter from the `Unet` constructor can also be given here.
+
+**Decoder:**
+
+Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
+| `image_size` | Yes | N/A | Not used. Can be any number. |
+| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
+| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
+| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
+| `learned_variance` | No | `True` | Whether to learn the variance. |
+
+Any parameter from the `Decoder` constructor can also be given here.
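+
+As a rough sketch of how the `unets` and `decoder` sections fit together (shown as a Python dict in the same shape as `configs/decoder_defaults.py`; the values are illustrative), a single-unet decoder generating 64x64 images might look like:
+
+```python
+config_fragment = {
+    "unets": [
+        {"dim": 128, "image_embed_dim": 768}  # one entry per unet in the decoder
+    ],
+    "decoder": {
+        "image_sizes": [64],  # one output resolution per unet defined above
+        "image_size": 64,     # not used by the trainer
+        "timesteps": 1000
+    }
+}
+```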
+
+**Data:**
+
+Settings for creation of the dataloaders.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard number replaced with `{}`[^1]. |
+| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
+| `num_workers` | No | `4` | The number of workers used in the dataloader. |
+| `batch_size` | No | `64` | The batch size. |
+| `start_shard` | No | `0` | Defines the first shard number the dataset will read. |
+| `end_shard` | No | `9999999` | Defines the last shard number the dataset will read. |
+| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
+| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
+| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
+| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
+| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the dataset, making the epoch length infinite unless a limit is set. Cannot be enabled if `shuffle_train` is enabled. |
+| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
+
+[^1]: If your shard files have paths like `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example, `pipe:s3cmd get s3://bucket/path/{}.tar -`.
+
+[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard width is 5.
+
+[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file within the shard (the shard is `00104` and the index is `5945`). The `index_width` in this case is 4.
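+
+As a small sketch of how these widths are used (the training script zero-pads the shard number to `shard_width` digits and substitutes it into `webdataset_base_url`; the values here are illustrative):
+
+```python
+webdataset_base_url = "pipe:s3cmd get s3://bucket/path/{}.tar -"
+shard_width = 5
+index_width = 4
+
+# A shard number is zero-padded and substituted into the base url
+shard_url = webdataset_base_url.format(str(104).zfill(shard_width))
+# -> "pipe:s3cmd get s3://bucket/path/00104.tar -"
+
+# The stem of a filename inside that shard splits into shard number and index
+filename = "001045945"  # stem of 001045945.jpg
+shard_part, index_part = filename[:shard_width], filename[shard_width:shard_width + index_width]
+# shard_part == "00104", index_part == "5945"
+```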
+
+**Train:**
+
+Settings for controlling the training hyperparameters.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `epochs` | No | `20` | The number of epochs in the training run. |
+| `lr` | No | `1e-4` | The learning rate. |
+| `wd` | No | `0.01` | The weight decay. |
+| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
+| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
+| `device` | No | `cuda:0` | The device to train on. |
+| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
+| `validation_samples` | No | `None` | The number of samples to use for validation. `None` means the entire validation set is used. |
+| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
+| `ema_beta` | No | `0.99` | The ema coefficient. |
+| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
+| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
+| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
+| `unet_training_mask` | No | `None` | A boolean array with one entry per unet; if an entry is false, that unet is frozen. A value of `None` trains all unets. |
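+
+For example, when fine-tuning only the second unet of a two-unet decoder, a mask of `[false, true]` would freeze the first unet while training the second (illustrative values).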
+
+**Evaluate:**
+
+Defines which evaluation metrics will be used to test the model.
+Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors linked in the table below; see the sketch after the table.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
+| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. |
+| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. |
+| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
+| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
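+
+Each metric's object is passed directly to the corresponding torchmetrics constructor as keyword arguments. As a sketch (the `feature` value is taken from the example config), setting `"FID": {"feature": 64}` is roughly equivalent to:
+
+```python
+from torchmetrics.image.fid import FrechetInceptionDistance
+
+# The keys of the FID object become constructor keyword arguments
+fid = FrechetInceptionDistance(feature=64)
+```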
+
+**Tracker:**
+
+Selects which tracker to use and configures it.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
+| `data_path` | No | `./models` | Where the tracker will store local data. |
+| `verbose` | No | `False` | Enables console logging for non-console trackers. |
+
+Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
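+
+For example, a wandb tracker section might look like the following (shown as a Python dict; the entity and project names are placeholders):
+
+```python
+tracker_config = {
+    "tracker_type": "wandb",
+    "data_path": "./models",
+    "wandb_entity": "my-entity",   # placeholder
+    "wandb_project": "my-project"  # placeholder
+}
+```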
+
+**Load:**
+
+Selects where to load a pretrained model from.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `source` | No | `None` | Supports `file` or `wandb`. |
+| `resume` | No | `False` | If the tracker supports resuming the run, resume it. |
+
+Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
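+
+For example, resuming from a locally saved checkpoint might look like the following (shown as a Python dict; the file path is a placeholder pointing at a checkpoint previously written by the trainer):
+
+```python
+load_config = {
+    "source": "file",
+    "file_path": "./models/latest.pth"  # placeholder path to a saved checkpoint
+}
+```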
diff --git a/configs/decoder_defaults.py b/configs/decoder_defaults.py
new file mode 100644
index 0000000..e36cd41
--- /dev/null
+++ b/configs/decoder_defaults.py
@@ -0,0 +1,82 @@
+"""
+Defines the default values for the decoder config
+"""
+
+from enum import Enum
+class ConfigField(Enum):
+ REQUIRED = 0 # Sentinel marking a config value that must be supplied by the user config
+
+default_config = {
+ "unets": ConfigField.REQUIRED,
+ "decoder": {
+ "image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
+ "image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
+ "channels": 3,
+ "timesteps": 1000,
+ "loss_type": "l2",
+ "beta_schedule": "cosine",
+ "learned_variance": True
+ },
+ "data": {
+ "webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
+ "embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
+ "num_workers": 4,
+ "batch_size": 64,
+ "start_shard": 0,
+ "end_shard": 9999999,
+ "shard_width": 6,
+ "index_width": 4,
+ "splits": {
+ "train": 0.75,
+ "val": 0.15,
+ "test": 0.1
+ },
+ "shuffle_train": True,
+ "resample_train": False,
+ "preprocessing": {
+ "ToTensor": True
+ }
+ },
+ "train": {
+ "epochs": 20,
+ "lr": 1e-4,
+ "wd": 0.01,
+ "max_grad_norm": 0.5,
+ "save_every_n_samples": 100000,
+ "n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
+ "device": "cuda:0",
+ "epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+ "validation_samples": None, # Same as above but for validation.
+ "use_ema": True,
+ "ema_beta": 0.99,
+ "amp": False,
+ "save_all": False, # Whether to preserve all checkpoints
+ "save_latest": True, # Whether to always save the latest checkpoint
+ "save_best": True, # Whether to save the best checkpoint
+ "unet_training_mask": None # If None, use all unets
+ },
+ "evaluate": {
+ "n_evalation_samples": 1000,
+ "FID": None,
+ "IS": None,
+ "KID": None,
+ "LPIPS": None
+ },
+ "tracker": {
+ "tracker_type": "console", # Decoder currently supports console and wandb
+ "data_path": "./models", # The path where files will be saved locally
+
+ "wandb_entity": "", # Only needs to be set if tracker_type is wandb
+ "wandb_project": "",
+
+ "verbose": False # Whether to print console logging for non-console trackers
+ },
+ "load": {
+ "source": None, # Supports file and wandb
+
+ "run_path": "", # Used only if source is wandb
+ "file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
+
+ "resume": False # If using wandb, whether to resume the run
+ }
+}
diff --git a/configs/train_decoder_config.json.example b/configs/train_decoder_config.json.example
new file mode 100644
index 0000000..d2645c1
--- /dev/null
+++ b/configs/train_decoder_config.json.example
@@ -0,0 +1,100 @@
+{
+ "unets": [
+ {
+ "dim": 128,
+ "image_embed_dim": 768,
+ "cond_dim": 64,
+ "channels": 3,
+ "dim_mults": [1, 2, 4, 8],
+ "attn_dim_head": 32,
+ "attn_heads": 16
+ }
+ ],
+ "decoder": {
+ "image_sizes": [64],
+ "image_size": [64],
+ "channels": 3,
+ "timesteps": 1000,
+ "loss_type": "l2",
+ "beta_schedule": "cosine",
+ "learned_variance": true
+ },
+ "data": {
+ "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
+ "embeddings_url": "s3://bucket/embeddings/path/",
+ "num_workers": 4,
+ "batch_size": 64,
+ "start_shard": 0,
+ "end_shard": 9999999,
+ "shard_width": 6,
+ "index_width": 4,
+ "splits": {
+ "train": 0.75,
+ "val": 0.15,
+ "test": 0.1
+ },
+ "shuffle_train": true,
+ "resample_train": false,
+ "preprocessing": {
+ "RandomResizedCrop": {
+ "size": [128, 128],
+ "scale": [0.75, 1.0],
+ "ratio": [1.0, 1.0]
+ },
+ "ToTensor": true
+ }
+ },
+ "train": {
+ "epochs": 20,
+ "lr": 1e-4,
+ "wd": 0.01,
+ "max_grad_norm": 0.5,
+ "save_every_n_samples": 100000,
+ "n_sample_images": 6,
+ "device": "cuda:0",
+ "epoch_samples": null,
+ "validation_samples": null,
+ "use_ema": true,
+ "ema_beta": 0.99,
+ "amp": false,
+ "save_all": false,
+ "save_latest": true,
+ "save_best": true,
+ "unet_training_mask": [true]
+ },
+ "evaluate": {
+ "n_evalation_samples": 1000,
+ "FID": {
+ "feature": 64
+ },
+ "IS": {
+ "feature": 64,
+ "splits": 10
+ },
+ "KID": {
+ "feature": 64,
+ "subset_size": 10
+ },
+ "LPIPS": {
+ "net_type": "vgg",
+ "reduction": "mean"
+ }
+ },
+ "tracker": {
+ "tracker_type": "console",
+ "data_path": "./models",
+
+ "wandb_entity": "",
+ "wandb_project": "",
+
+ "verbose": false
+ },
+ "load": {
+ "source": null,
+
+ "run_path": "",
+ "file_path": "",
+
+ "resume": false
+ }
+}
diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py
index ad1431d..ee366d8 100644
--- a/dalle2_pytorch/optimizer.py
+++ b/dalle2_pytorch/optimizer.py
@@ -11,7 +11,8 @@ def get_optimizer(
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
- filter_by_requires_grad = False
+ filter_by_requires_grad = False,
+ **kwargs
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
diff --git a/setup.py b/setup.py
index 4a191d7..aa02390 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,8 @@ setup(
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
- 'fsspec>=2022.1.0'
+ 'fsspec>=2022.1.0',
+ 'torchmetrics[image]>=0.8.0'
],
classifiers=[
'Development Status :: 4 - Beta',
diff --git a/train_decoder.py b/train_decoder.py
new file mode 100644
index 0000000..3c91c36
--- /dev/null
+++ b/train_decoder.py
@@ -0,0 +1,500 @@
+from dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
+from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
+from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
+from configs.decoder_defaults import default_config, ConfigField
+import time
+import json
+import torchvision
+from torchvision import transforms as T
+import torch
+from torchmetrics.image.fid import FrechetInceptionDistance
+from torchmetrics.image.inception import InceptionScore
+from torchmetrics.image.kid import KernelInceptionDistance
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+import webdataset as wds
+import click
+
+
+def create_dataloaders(
+ available_shards,
+ webdataset_base_url,
+ embeddings_url,
+ shard_width=6,
+ num_workers=4,
+ batch_size=32,
+ n_sample_images=6,
+ shuffle_train=True,
+ resample_train=False,
+ img_preproc = None,
+ index_width=4,
+ train_prop = 0.75,
+ val_prop = 0.15,
+ test_prop = 0.10,
+ **kwargs
+):
+ """
+ Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
+ """
+ assert train_prop + test_prop + val_prop == 1
+ num_train = round(train_prop*len(available_shards))
+ num_test = round(test_prop*len(available_shards))
+ num_val = len(available_shards) - num_train - num_test
+ assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
+ train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
+
+ # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
+ train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
+ test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
+ val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
+
+ create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
+ tar_url=tar_urls,
+ num_workers=num_workers,
+ batch_size=batch_size if not for_sampling else n_sample_images,
+ embeddings_url=embeddings_url,
+ index_width=index_width,
+ shuffle_num = None,
+ extra_keys= ["txt"] if with_text else [],
+ shuffle_shards = shuffle,
+ resample_shards = resample,
+ img_preproc=img_preproc,
+ handler=wds.handlers.warn_and_continue
+ )
+
+ train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
+ train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
+ val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
+ test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
+ test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
+ return {
+ "train": train_dataloader,
+ "train_sampling": train_sampling_dataloader,
+ "val": val_dataloader,
+ "test": test_dataloader,
+ "test_sampling": test_sampling_dataloader
+ }
+
+
+def create_decoder(device, decoder_config, unets_config):
+ """Creates a sample decoder"""
+ unets = []
+ for i in range(0, len(unets_config)):
+ unets.append(Unet(
+ **unets_config[i]
+ ))
+
+ decoder = Decoder(
+ unet=tuple(unets), # Must be tuple because of cast_tuple
+ **decoder_config
+ )
+ decoder.to(device=device)
+
+ return decoder
+
+def get_dataset_keys(dataloader):
+ """
+ It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried inside the dataloader, we need to unwrap it to recover them.
+ """
+ # If the dataloader is actually a WebLoader, we need to extract the real dataloader
+ if isinstance(dataloader, wds.WebLoader):
+ dataloader = dataloader.pipeline[0]
+ return dataloader.dataset.key_map
+
+def get_example_data(dataloader, device, n=5):
+ """
+ Samples the dataloader and returns a zipped list of examples
+ """
+ images = []
+ embeddings = []
+ captions = []
+ dataset_keys = get_dataset_keys(dataloader)
+ has_caption = "txt" in dataset_keys
+ for data in dataloader:
+ if has_caption:
+ img, emb, txt = data
+ else:
+ img, emb = data
+ txt = [""] * emb.shape[0]
+ img = img.to(device=device, dtype=torch.float)
+ emb = emb.to(device=device, dtype=torch.float)
+ images.extend(list(img))
+ embeddings.extend(list(emb))
+ captions.extend(list(txt))
+ if len(images) >= n:
+ break
+ print("Generated {} examples".format(len(images)))
+ return list(zip(images[:n], embeddings[:n], captions[:n]))
+
+def generate_samples(trainer, example_data, text_prepend=""):
+ """
+ Takes example data and generates images from the embeddings
+ Returns three lists: real images, generated images, and captions
+ """
+ real_images, embeddings, txts = zip(*example_data)
+ embeddings_tensor = torch.stack(embeddings)
+ samples = trainer.sample(embeddings_tensor)
+ generated_images = list(samples)
+ captions = [text_prepend + txt for txt in txts]
+ return real_images, generated_images, captions
+
+def generate_grid_samples(trainer, examples, text_prepend=""):
+ """
+ Generates samples and uses torchvision to put them in a side by side grid for easy viewing
+ """
+ real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
+ grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
+ return grid_images, captions
+
+def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+ """
+ Computes evaluation metrics for the decoder
+ """
+ metrics = {}
+ # Prepare the data
+ examples = get_example_data(dataloader, device, n_evalation_samples)
+ real_images, generated_images, captions = generate_samples(trainer, examples)
+ real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
+ generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
+ # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
+ int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
+ int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
+ if FID is not None:
+ fid = FrechetInceptionDistance(**FID)
+ fid.to(device=device)
+ fid.update(int_real_images, real=True)
+ fid.update(int_generated_images, real=False)
+ metrics["FID"] = fid.compute().item()
+ if IS is not None:
+ inception = InceptionScore(**IS)
+ inception.to(device=device)
+ inception.update(int_generated_images)
+ is_mean, is_std = inception.compute()
+ metrics["IS_mean"] = is_mean.item()
+ metrics["IS_std"] = is_std.item()
+ if KID is not None:
+ kernel_inception = KernelInceptionDistance(**KID)
+ kernel_inception.to(device=device)
+ kernel_inception.update(int_real_images, real=True)
+ kernel_inception.update(int_generated_images, real=False)
+ kid_mean, kid_std = kernel_inception.compute()
+ metrics["KID_mean"] = kid_mean.item()
+ metrics["KID_std"] = kid_std.item()
+ if LPIPS is not None:
+ # Convert from [0, 1] to [-1, 1]
+ renorm_real_images = real_images.mul(2).sub(1)
+ renorm_generated_images = generated_images.mul(2).sub(1)
+ lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
+ lpips.to(device=device)
+ lpips.update(renorm_real_images, renorm_generated_images)
+ metrics["LPIPS"] = lpips.compute().item()
+ return metrics
+
+def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
+ """
+ Logs the model with an appropriate method depending on the tracker
+ """
+ if isinstance(relative_paths, str):
+ relative_paths = [relative_paths]
+ trainer_state_dict = {}
+ trainer_state_dict["trainer"] = trainer.state_dict()
+ trainer_state_dict['epoch'] = epoch
+ trainer_state_dict['step'] = step
+ trainer_state_dict['validation_losses'] = validation_losses
+ for relative_path in relative_paths:
+ tracker.save_state_dict(trainer_state_dict, relative_path)
+
+def recall_trainer(tracker, trainer, recall_source=None, **load_config):
+ """
+ Loads the model with an appropriate method depending on the tracker
+ """
+ print(print_ribbon(f"Loading model from {recall_source}"))
+ state_dict = tracker.recall_state_dict(recall_source, **load_config)
+ trainer.load_state_dict(state_dict["trainer"])
+ print("Model loaded")
+ return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
+
+def train(
+ dataloaders,
+ decoder,
+ tracker,
+ inference_device,
+ load_config=None,
+ evaluate_config=None,
+ epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
+ validation_samples = None,
+ epochs = 20,
+ n_sample_images = 5,
+ save_every_n_samples = 100000,
+ save_all=False,
+ save_latest=True,
+ save_best=True,
+ unet_training_mask=None,
+ **kwargs
+):
+ """
+ Trains a decoder on a dataset.
+ """
+ trainer = DecoderTrainer(
+ decoder,
+ **kwargs
+ )
+ # Set up starting model and parameters based on a recalled state dict
+ start_step = 0
+ start_epoch = 0
+ validation_losses = []
+
+ if load_config is not None and load_config["source"] is not None:
+ start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
+ trainer.to(device=inference_device)
+
+ if unet_training_mask is None:
+ # Then the unet mask should be true for all unets in the decoder
+ unet_training_mask = [True] * trainer.num_unets
+ assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
+
+ print(print_ribbon("Generating Example Data", repeat=40))
+ print("This can take a while to load the shard lists...")
+ train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
+ test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
+
+ send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
+ step = start_step
+ for epoch in range(start_epoch, epochs):
+ print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
+ trainer.train()
+
+ sample = 0
+ last_sample = 0
+ last_snapshot = 0
+ last_time = time.time()
+ losses = []
+ for i, (img, emb) in enumerate(dataloaders["train"]):
+ step += 1
+ sample += img.shape[0]
+ img, emb = send_to_device((img, emb))
+
+ for unet in range(1, trainer.num_unets+1):
+ # Check if this is a unet we are training
+ if unet_training_mask[unet-1]: # Unet index is the unet number - 1
+ loss = trainer.forward(img, image_embed=emb, unet_number=unet)
+ trainer.update(unet_number=unet)
+ losses.append(loss)
+
+ samples_per_sec = (sample - last_sample) / (time.time() - last_time)
+ last_time = time.time()
+ last_sample = sample
+
+ if i % 10 == 0:
+ average_loss = sum(losses) / len(losses)
+ log_data = {
+ "Training loss": average_loss,
+ "Epoch": epoch,
+ "Sample": sample,
+ "Step": i,
+ "Samples per second": samples_per_sec
+ }
+ tracker.log(log_data, step=step, verbose=True)
+ losses = []
+
+ if last_snapshot + save_every_n_samples < sample: # This will overshoot the save interval slightly each time, which is acceptable
+ last_snapshot = sample
+ # We need to know where the model should be saved
+ save_paths = []
+ if save_latest:
+ save_paths.append("latest.pth")
+ if save_all:
+ save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
+ save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
+ if n_sample_images is not None and n_sample_images > 0:
+ trainer.eval()
+ train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+ trainer.train()
+ tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
+
+ if epoch_samples is not None and sample >= epoch_samples:
+ break
+
+ trainer.eval()
+ print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
+ with torch.no_grad():
+ sample = 0
+ average_loss = 0
+ start_time = time.time()
+ for i, (img, emb, txt) in enumerate(dataloaders["val"]):
+ sample += img.shape[0]
+ img, emb = send_to_device((img, emb))
+
+ for unet in range(1, len(decoder.unets)+1):
+ loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
+ average_loss += loss
+
+ if i % 10 == 0:
+ print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
+ print(f"Loss: {average_loss / (i+1)}")
+ print("")
+
+ if validation_samples is not None and sample >= validation_samples:
+ break
+ average_loss /= i+1
+ log_data = {
+ "Validation loss": average_loss
+ }
+ tracker.log(log_data, step=step, verbose=True)
+
+ # Compute evaluation metrics
+ trainer.eval()
+ if evaluate_config is not None:
+ print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
+ evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
+ tracker.log(evaluation, step=step, verbose=True)
+
+ # Generate sample images
+ print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
+ test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
+ train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+ tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
+ tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
+
+ print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
+ # Get the save paths
+ save_paths = []
+ if save_latest:
+ save_paths.append("latest.pth")
+ if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
+ save_paths.append("best.pth")
+ validation_losses.append(average_loss)
+ save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
+
+def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
+ """
+ Creates a tracker of the specified type and initializes special features based on the full config
+ """
+ tracker_config = config["tracker"]
+ init_config = {}
+ init_config["config"] = config.config
+ if tracker_type == "console":
+ tracker = ConsoleTracker(**init_config)
+ elif tracker_type == "wandb":
+ # We need to initialize the resume state here
+ load_config = config["load"]
+ if load_config["source"] == "wandb" and load_config["resume"]:
+ # Then we are resuming the run load_config["run_path"]
+ run_id = config["resume"]["wandb_run_path"].split("/")[-1]
+ init_config["id"] = run_id
+ init_config["resume"] = "must"
+ init_config["entity"] = tracker_config["wandb_entity"]
+ init_config["project"] = tracker_config["wandb_project"]
+ tracker = WandbTracker(data_path)
+ tracker.init(**init_config)
+ else:
+ raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
+ return tracker
+
+def initialize_training(config):
+ # Create the save path
+ if "cuda" in config["train"]["device"]:
+ assert torch.cuda.is_available(), "CUDA is not available"
+ device = torch.device(config["train"]["device"])
+ torch.cuda.set_device(device)
+ all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
+
+ dataloaders = create_dataloaders (
+ available_shards=all_shards,
+ img_preproc = config.get_preprocessing(),
+ train_prop = config["data"]["splits"]["train"],
+ val_prop = config["data"]["splits"]["val"],
+ test_prop = config["data"]["splits"]["test"],
+ n_sample_images=config["train"]["n_sample_images"],
+ **config["data"]
+ )
+
+ decoder = create_decoder(device, config["decoder"], config["unets"])
+ num_parameters = sum(p.numel() for p in decoder.parameters())
+ print(print_ribbon("Loaded Config", repeat=40))
+ print(f"Number of parameters: {num_parameters}")
+
+ tracker = create_tracker(config, **config["tracker"])
+
+ train(dataloaders, decoder,
+ tracker=tracker,
+ inference_device=device,
+ load_config=config["load"],
+ evaluate_config=config["evaluate"],
+ **config["train"],
+ )
+
+
+class TrainDecoderConfig:
+ def __init__(self, config):
+ self.config = self.map_config(config, default_config)
+
+ def map_config(self, config, defaults):
+ """
+ Returns a dictionary containing all config options in the union of config and defaults.
+ If the config value is an array, apply the default value to each element.
+ If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
+ """
+ def _check_option(option, option_config, option_defaults):
+ for key, value in option_defaults.items():
+ if key not in option_config:
+ if value == ConfigField.REQUIRED:
+ raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
+ option_config[key] = value
+
+ for key, value in defaults.items():
+ if key not in config:
+ # Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If it is required, we must error
+ if value == ConfigField.REQUIRED:
+ raise RuntimeError("Required config value '{}' not supplied".format(key))
+ elif isinstance(value, dict):
+ config[key] = {}
+ elif isinstance(value, list):
+ config[key] = [{}]
+ # Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
+ # If it is a list, then we need to check each element
+ if isinstance(value, list):
+ assert isinstance(config[key], list)
+ for element in config[key]:
+ _check_option(key, element, value[0])
+ elif isinstance(value, dict):
+ _check_option(key, config[key], value)
+ # This object does not support checking
+ return config
+
+ def get_preprocessing(self):
+ """
+ Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
+ """
+ def _get_transformation(transformation_name, **kwargs):
+ if transformation_name == "RandomResizedCrop":
+ return T.RandomResizedCrop(**kwargs)
+ elif transformation_name == "RandomHorizontalFlip":
+ return T.RandomHorizontalFlip()
+ elif transformation_name == "ToTensor":
+ return T.ToTensor()
+ else:
+ raise ValueError(f"Unsupported preprocessing transformation: {transformation_name}")
+
+ transformations = []
+ for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
+ if isinstance(transformation_kwargs, dict):
+ transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
+ else:
+ transformations.append(_get_transformation(transformation_name))
+ return T.Compose(transformations)
+
+ def __getitem__(self, key):
+ return self.config[key]
+
+# Create a simple click command line interface to load the config and start the training
+@click.command()
+@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
+def main(config_file):
+ print("Recalling config from {}".format(config_file))
+ with open(config_file) as f:
+ config = json.load(f)
+ config = TrainDecoderConfig(config)
+ initialize_training(config)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file