diff --git a/.gitignore b/.gitignore index 55301b1..41f11cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,12 @@ # default experiment tracker data .tracker-data/ +# Configuration Files +configs/* +!configs/*.example +!configs/*_defaults.py +!configs/README.md + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/configs/README.md b/configs/README.md new file mode 100644 index 0000000..80a942e --- /dev/null +++ b/configs/README.md @@ -0,0 +1,109 @@ +## DALLE2 Training Configurations + +For more complex configuration, we provide the option of using a configuration file instead of command line arguments. + +### Decoder Trainer + +The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example). + +**Unets:** + +Each member of this array defines a single unet that will be added to the decoder. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `dim` | Yes | N/A | The starting channels of the unet. | +| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. | +| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. | + +Any parameter from the `Unet` constructor can also be given here. + +**Decoder:** + +Defines the configuration options for the decoder model. The unets defined above will automatically be inserted. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. | +| `image_size` | Yes | N/A | Not used. Can be any number. | +| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. | +| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. | +| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. | +| `learned_variance` | No | `True` | Whether to learn the variance. | + +Any parameter from the `Decoder` constructor can also be given here. + +**Data:** + +Settings for creation of the dataloaders. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. | +| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. | +| `num_workers` | No | `4` | The number of workers used in the dataloader. | +| `batch_size` | No | `64` | The batch size. | +| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. | +| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. | +| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. | +| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. | +| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. | +| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. | +| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. 
Cannot be enabled if `shuffle_train` is enabled. | +| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. | + +[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`. + +[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard width is 5. + +[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `00104` and index is `5945`). The `index_width` in this case is 4. + +**Train:** + +Settings for controlling the training hyperparameters. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `epochs` | No | `20` | The number of epochs in the training run. | +| `lr` | No | `1e-4` | The learning rate. | +| `wd` | No | `0.01` | The weight decay. | +| `max_grad_norm` | No | `0.5` | The gradient norm clipping threshold. | +| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. | +| `device` | No | `cuda:0` | The device to train on. | +| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. | +| `validation_samples` | No | `None` | The number of samples to use for validation. None means the entire validation set. | +| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. | +| `ema_beta` | No | `0.99` | The EMA coefficient. | +| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. | +| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. | +| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. | +| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If an entry is false, the corresponding unet is frozen. A value of `None` trains all unets. | + +**Evaluate:** + +Defines which evaluation metrics will be used to test the model. +Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the linked torchmetrics constructors. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. | +| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. | +| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. | +| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. | +| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric.
| + +**Tracker:** + +Selects which tracker to use and configures it. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. | +| `data_path` | No | `./models` | Where the tracker will store local data. | +| `verbose` | No | `False` | Enables console logging for non-console trackers. | + +Other configuration options are required by specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py). + +**Load:** + +Selects where to load a pretrained model from. +| Option | Required | Default | Description | +| ------ | -------- | ------- | ----------- | +| `source` | No | `None` | Supports `file` or `wandb`. | +| `resume` | No | `False` | If the tracker supports resuming the run, resume it. | + +Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py). diff --git a/configs/decoder_defaults.py b/configs/decoder_defaults.py new file mode 100644 index 0000000..e36cd41 --- /dev/null +++ b/configs/decoder_defaults.py @@ -0,0 +1,82 @@ +""" +Defines the default values for the decoder config +""" + +from enum import Enum +class ConfigField(Enum): + REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it. + +default_config = { + "unets": ConfigField.REQUIRED, + "decoder": { + "image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet + "image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think + "channels": 3, + "timesteps": 1000, + "loss_type": "l2", + "beta_schedule": "cosine", + "learned_variance": True + }, + "data": { + "webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images + "embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings + "num_workers": 4, + "batch_size": 64, + "start_shard": 0, + "end_shard": 9999999, + "shard_width": 6, + "index_width": 4, + "splits": { + "train": 0.75, + "val": 0.15, + "test": 0.1 + }, + "shuffle_train": True, + "resample_train": False, + "preprocessing": { + "ToTensor": True + } + }, + "train": { + "epochs": 20, + "lr": 1e-4, + "wd": 0.01, + "max_grad_norm": 0.5, + "save_every_n_samples": 100000, + "n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset + "device": "cuda:0", + "epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true, as otherwise the number of samples per epoch is infinite. + "validation_samples": None, # Same as above, but for validation.
+ "use_ema": True, + "ema_beta": 0.99, + "amp": False, + "save_all": False, # Whether to preserve all checkpoints + "save_latest": True, # Whether to always save the latest checkpoint + "save_best": True, # Whether to save the best checkpoint + "unet_training_mask": None # If None, use all unets + }, + "evaluate": { + "n_evalation_samples": 1000, + "FID": None, + "IS": None, + "KID": None, + "LPIPS": None + }, + "tracker": { + "tracker_type": "console", # Decoder currently supports console and wandb + "data_path": "./models", # The path where files will be saved locally + + "wandb_entity": "", # Only needs to be set if tracker_type is wandb + "wandb_project": "", + + "verbose": False # Whether to print console logging for non-console trackers + }, + "load": { + "source": None, # Supports file and wandb + + "run_path": "", # Used only if source is wandb + "file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb. + + "resume": False # If using wandb, whether to resume the run + } +} diff --git a/configs/train_decoder_config.json.example b/configs/train_decoder_config.json.example new file mode 100644 index 0000000..d2645c1 --- /dev/null +++ b/configs/train_decoder_config.json.example @@ -0,0 +1,100 @@ +{ + "unets": [ + { + "dim": 128, + "image_embed_dim": 768, + "cond_dim": 64, + "channels": 3, + "dim_mults": [1, 2, 4, 8], + "attn_dim_head": 32, + "attn_heads": 16 + } + ], + "decoder": { + "image_sizes": [64], + "image_size": [64], + "channels": 3, + "timesteps": 1000, + "loss_type": "l2", + "beta_schedule": "cosine", + "learned_variance": true + }, + "data": { + "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -", + "embeddings_url": "s3://bucket/embeddings/path/", + "num_workers": 4, + "batch_size": 64, + "start_shard": 0, + "end_shard": 9999999, + "shard_width": 6, + "index_width": 4, + "splits": { + "train": 0.75, + "val": 0.15, + "test": 0.1 + }, + "shuffle_train": true, + "resample_train": false, + "preprocessing": { + "RandomResizedCrop": { + "size": [128, 128], + "scale": [0.75, 1.0], + "ratio": [1.0, 1.0] + }, + "ToTensor": true + } + }, + "train": { + "epochs": 20, + "lr": 1e-4, + "wd": 0.01, + "max_grad_norm": 0.5, + "save_every_n_samples": 100000, + "n_sample_images": 6, + "device": "cuda:0", + "epoch_samples": null, + "validation_samples": null, + "use_ema": true, + "ema_beta": 0.99, + "amp": false, + "save_all": false, + "save_latest": true, + "save_best": true, + "unet_training_mask": [true] + }, + "evaluate": { + "n_evalation_samples": 1000, + "FID": { + "feature": 64 + }, + "IS": { + "feature": 64, + "splits": 10 + }, + "KID": { + "feature": 64, + "subset_size": 10 + }, + "LPIPS": { + "net_type": "vgg", + "reduction": "mean" + } + }, + "tracker": { + "tracker_type": "console", + "data_path": "./models", + + "wandb_entity": "", + "wandb_project": "", + + "verbose": false + }, + "load": { + "source": null, + + "run_path": "", + "file_path": "", + + "resume": false + } +} diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py index ad1431d..ee366d8 100644 --- a/dalle2_pytorch/optimizer.py +++ b/dalle2_pytorch/optimizer.py @@ -11,7 +11,8 @@ def get_optimizer( wd = 1e-2, betas = (0.9, 0.999), eps = 1e-8, - filter_by_requires_grad = False + filter_by_requires_grad = False, + **kwargs ): if filter_by_requires_grad: params = list(filter(lambda t: t.requires_grad, params)) diff --git a/setup.py b/setup.py index 4a191d7..aa02390 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,8 @@ setup( 
'x-clip>=0.4.4', 'youtokentome', 'webdataset>=0.2.5', - 'fsspec>=2022.1.0' + 'fsspec>=2022.1.0', + 'torchmetrics[image]>=0.8.0' ], classifiers=[ 'Development Status :: 4 - Beta', diff --git a/train_decoder.py b/train_decoder.py new file mode 100644 index 0000000..3c91c36 --- /dev/null +++ b/train_decoder.py @@ -0,0 +1,500 @@ +from dalle2_pytorch import Unet, Decoder +from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon +from dalle2_pytorch.dataloaders import create_image_embedding_dataloader +from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker +from configs.decoder_defaults import default_config, ConfigField +import time +import json +import torchvision +from torchvision import transforms as T +import torch +from torchmetrics.image.fid import FrechetInceptionDistance +from torchmetrics.image.inception import InceptionScore +from torchmetrics.image.kid import KernelInceptionDistance +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +import webdataset as wds +import click + + +def create_dataloaders( + available_shards, + webdataset_base_url, + embeddings_url, + shard_width=6, + num_workers=4, + batch_size=32, + n_sample_images=6, + shuffle_train=True, + resample_train=False, + img_preproc = None, + index_width=4, + train_prop = 0.75, + val_prop = 0.15, + test_prop = 0.10, + **kwargs +): + """ + Randomly splits the available shards into train, val, and test sets and returns a dataloader for each + """ + assert train_prop + test_prop + val_prop == 1 + num_train = round(train_prop*len(available_shards)) + num_test = round(test_prop*len(available_shards)) + num_val = len(available_shards) - num_train - num_test + assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}" + train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0)) + + # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename. 
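+ # For example, with shard_width = 6 the shard number 104 is zero padded to "000104", so a base url like "pipe:s3cmd get s3://bucket/path/{}.tar -" expands to "pipe:s3cmd get s3://bucket/path/000104.tar -".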
+ train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split] + test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split] + val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split] + + create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader( + tar_url=tar_urls, + num_workers=num_workers, + batch_size=batch_size if not for_sampling else n_sample_images, + embeddings_url=embeddings_url, + index_width=index_width, + shuffle_num = None, + extra_keys= ["txt"] if with_text else [], + shuffle_shards = shuffle, + resample_shards = resample, + img_preproc=img_preproc, + handler=wds.handlers.warn_and_continue + ) + + train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train) + train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True) + val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True) + test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True) + test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True) + return { + "train": train_dataloader, + "train_sampling": train_sampling_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "test_sampling": test_sampling_dataloader + } + + +def create_decoder(device, decoder_config, unets_config): + """Creates a decoder from the unet and decoder configs""" + unets = [] + for i in range(0, len(unets_config)): + unets.append(Unet( + **unets_config[i] + )) + + decoder = Decoder( + unet=tuple(unets), # Must be tuple because of cast_tuple + **decoder_config + ) + decoder.to(device=device) + + return decoder + +def get_dataset_keys(dataloader): + """ + It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to unwrap it to recover them.
+ """ + # If the dataloader is actually a WebLoader, we need to extract the real dataloader + if isinstance(dataloader, wds.WebLoader): + dataloader = dataloader.pipeline[0] + return dataloader.dataset.key_map + +def get_example_data(dataloader, device, n=5): + """ + Samples the dataloader and returns a zipped list of examples + """ + images = [] + embeddings = [] + captions = [] + dataset_keys = get_dataset_keys(dataloader) + has_caption = "txt" in dataset_keys + for data in dataloader: + if has_caption: + img, emb, txt = data + else: + img, emb = data + txt = [""] * emb.shape[0] + img = img.to(device=device, dtype=torch.float) + emb = emb.to(device=device, dtype=torch.float) + images.extend(list(img)) + embeddings.extend(list(emb)) + captions.extend(list(txt)) + if len(images) >= n: + break + print("Generated {} examples".format(len(images))) + return list(zip(images[:n], embeddings[:n], captions[:n])) + +def generate_samples(trainer, example_data, text_prepend=""): + """ + Takes example data and generates images from the embeddings + Returns three lists: real images, generated images, and captions + """ + real_images, embeddings, txts = zip(*example_data) + embeddings_tensor = torch.stack(embeddings) + samples = trainer.sample(embeddings_tensor) + generated_images = list(samples) + captions = [text_prepend + txt for txt in txts] + return real_images, generated_images, captions + +def generate_grid_samples(trainer, examples, text_prepend=""): + """ + Generates samples and uses torchvision to put them in a side by side grid for easy viewing + """ + real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend) + grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)] + return grid_images, captions + +def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None): + """ + Computes evaluation metrics for the decoder + """ + metrics = {} + # Prepare the data + examples = get_example_data(dataloader, device, n_evalation_samples) + real_images, generated_images, captions = generate_samples(trainer, examples) + real_images = torch.stack(real_images).to(device=device, dtype=torch.float) + generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float) + # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8 + int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8) + int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8) + if FID is not None: + fid = FrechetInceptionDistance(**FID) + fid.to(device=device) + fid.update(int_real_images, real=True) + fid.update(int_generated_images, real=False) + metrics["FID"] = fid.compute().item() + if IS is not None: + inception = InceptionScore(**IS) + inception.to(device=device) + inception.update(int_real_images) + is_mean, is_std = inception.compute() + metrics["IS_mean"] = is_mean.item() + metrics["IS_std"] = is_std.item() + if KID is not None: + kernel_inception = KernelInceptionDistance(**KID) + kernel_inception.to(device=device) + kernel_inception.update(int_real_images, real=True) + kernel_inception.update(int_generated_images, real=False) + kid_mean, kid_std = kernel_inception.compute() + metrics["KID_mean"] = kid_mean.item() + metrics["KID_std"] = kid_std.item() + if LPIPS is not None: + # Convert from [0, 1] to [-1, 1] + renorm_real_images = real_images.mul(2).sub(1) + 
renorm_generated_images = generated_images.mul(2).sub(1) + lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS) + lpips.to(device=device) + lpips.update(renorm_real_images, renorm_generated_images) + metrics["LPIPS"] = lpips.compute().item() + return metrics + +def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths): + """ + Logs the model with an appropriate method depending on the tracker + """ + if isinstance(relative_paths, str): + relative_paths = [relative_paths] + trainer_state_dict = {} + trainer_state_dict["trainer"] = trainer.state_dict() + trainer_state_dict['epoch'] = epoch + trainer_state_dict['step'] = step + trainer_state_dict['validation_losses'] = validation_losses + for relative_path in relative_paths: + tracker.save_state_dict(trainer_state_dict, relative_path) + +def recall_trainer(tracker, trainer, recall_source=None, **load_config): + """ + Loads the model with an appropriate method depending on the tracker + """ + print(print_ribbon(f"Loading model from {recall_source}")) + state_dict = tracker.recall_state_dict(recall_source, **load_config) + trainer.load_state_dict(state_dict["trainer"]) + print("Model loaded") + return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"] + +def train( + dataloaders, + decoder, + tracker, + inference_device, + load_config=None, + evaluate_config=None, + epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch + validation_samples = None, + epochs = 20, + n_sample_images = 5, + save_every_n_samples = 100000, + save_all=False, + save_latest=True, + save_best=True, + unet_training_mask=None, + **kwargs +): + """ + Trains a decoder on a dataset. + """ + trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here + decoder, + **kwargs + ) + # Set up starting model and parameters based on a recalled state dict + start_step = 0 + start_epoch = 0 + validation_losses = [] + + if load_config is not None and load_config["source"] is not None: + start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config) + trainer.to(device=inference_device) + + if unet_training_mask is None: + # Then the unet mask should be true for all unets in the decoder + unet_training_mask = [True] * trainer.num_unets + assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. 
Got {len(unet_training_mask)} and {trainer.num_unets}" + + print(print_ribbon("Generating Example Data", repeat=40)) + print("This can take a while to load the shard lists...") + train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images) + test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images) + + send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr] + step = start_step + for epoch in range(start_epoch, epochs): + print(print_ribbon(f"Starting epoch {epoch}", repeat=40)) + trainer.train() + + sample = 0 + last_sample = 0 + last_snapshot = 0 + last_time = time.time() + losses = [] + for i, (img, emb) in enumerate(dataloaders["train"]): + step += 1 + sample += img.shape[0] + img, emb = send_to_device((img, emb)) + + for unet in range(1, trainer.num_unets+1): + # Check if this is a unet we are training + if unet_training_mask[unet-1]: # Unet index is the unet number - 1 + loss = trainer.forward(img, image_embed=emb, unet_number=unet) + trainer.update(unet_number=unet) + losses.append(loss) + + samples_per_sec = (sample - last_sample) / (time.time() - last_time) + last_time = time.time() + last_sample = sample + + if i % 10 == 0: + average_loss = sum(losses) / len(losses) + log_data = { + "Training loss": average_loss, + "Epoch": epoch, + "Sample": sample, + "Step": i, + "Samples per second": samples_per_sec + } + tracker.log(log_data, step=step, verbose=True) + losses = [] + + if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope + last_snapshot = sample + # We need to know where the model should be saved + save_paths = [] + if save_latest: + save_paths.append("latest.pth") + if save_all: + save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth") + save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths) + if n_sample_images is not None and n_sample_images > 0: + trainer.eval() + train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") + trainer.train() + tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step) + + if epoch_samples is not None and sample >= epoch_samples: + break + + trainer.eval() + print(print_ribbon(f"Starting Validation {epoch}", repeat=40)) + with torch.no_grad(): + sample = 0 + average_loss = 0 + start_time = time.time() + for i, (img, emb, txt) in enumerate(dataloaders["val"]): + sample += img.shape[0] + img, emb = send_to_device((img, emb)) + + for unet in range(1, len(decoder.unets)+1): + loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet) + average_loss += loss + + if i % 10 == 0: + print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec") + print(f"Loss: {average_loss / (i+1)}") + print("") + + if validation_samples is not None and sample >= validation_samples: + break + average_loss /= i+1 + log_data = { + "Validation loss": average_loss + } + tracker.log(log_data, step=step, verbose=True) + + # Compute evaluation metrics + trainer.eval() + if evaluate_config is not None: + print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) + evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config) + tracker.log(evaluation, step=step, verbose=True) + + # Generate sample images + print(print_ribbon(f"Sampling Set {epoch}", repeat=40)) + test_images, test_captions = 
generate_grid_samples(trainer, test_example_data, "Test: ") + train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") + tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step) + tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step) + + print(print_ribbon(f"Starting Saving {epoch}", repeat=40)) + # Get the save paths + save_paths = [] + if save_latest: + save_paths.append("latest.pth") + if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)): + save_paths.append("best.pth") + validation_losses.append(average_loss) + save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths) + +def create_tracker(config, tracker_type=None, data_path=None, **kwargs): + """ + Creates a tracker of the specified type and initializes special features based on the full config + """ + tracker_config = config["tracker"] + init_config = {} + init_config["config"] = config.config + if tracker_type == "console": + tracker = ConsoleTracker(**init_config) + elif tracker_type == "wandb": + # We need to initialize the resume state here + load_config = config["load"] + if load_config["source"] == "wandb" and load_config["resume"]: + # Then we are resuming the run load_config["run_path"] + run_id = load_config["run_path"].split("/")[-1] + init_config["id"] = run_id + init_config["resume"] = "must" + init_config["entity"] = tracker_config["wandb_entity"] + init_config["project"] = tracker_config["wandb_project"] + tracker = WandbTracker(data_path) + tracker.init(**init_config) + else: + raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer") + return tracker + +def initialize_training(config): + # Set up the training device + if "cuda" in config["train"]["device"]: + assert torch.cuda.is_available(), "CUDA is not available" + device = torch.device(config["train"]["device"]) + torch.cuda.set_device(device) + all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1)) + + dataloaders = create_dataloaders ( + available_shards=all_shards, + img_preproc = config.get_preprocessing(), + train_prop = config["data"]["splits"]["train"], + val_prop = config["data"]["splits"]["val"], + test_prop = config["data"]["splits"]["test"], + n_sample_images=config["train"]["n_sample_images"], + **config["data"] + ) + + decoder = create_decoder(device, config["decoder"], config["unets"]) + num_parameters = sum(p.numel() for p in decoder.parameters()) + print(print_ribbon("Loaded Config", repeat=40)) + print(f"Number of parameters: {num_parameters}") + + tracker = create_tracker(config, **config["tracker"]) + + train(dataloaders, decoder, + tracker=tracker, + inference_device=device, + load_config=config["load"], + evaluate_config=config["evaluate"], + **config["train"], + ) + + +class TrainDecoderConfig: + def __init__(self, config): + self.config = self.map_config(config, default_config) + + def map_config(self, config, defaults): + """ + Returns a dictionary containing all config options in the union of config and defaults. + If the config value is an array, apply the default value to each element.
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required, and a RuntimeError is raised if a value is not supplied in the config + """ + def _check_option(option, option_config, option_defaults): + for key, value in option_defaults.items(): + if key not in option_config: + if value == ConfigField.REQUIRED: + raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option)) + option_config[key] = value + + for key, value in defaults.items(): + if key not in config: + # Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If it is a required value, we must error + if value == ConfigField.REQUIRED: + raise RuntimeError("Required config value '{}' not supplied".format(key)) + elif isinstance(value, dict): + config[key] = {} + elif isinstance(value, list): + config[key] = [{}] + # Config[key] is now either a dict, list of dicts, or an object that cannot be checked. + # If it is a list, then we need to check each element + if isinstance(value, list): + assert isinstance(config[key], list) + for element in config[key]: + _check_option(key, element, value[0]) + elif isinstance(value, dict): + _check_option(key, config[key], value) + # Any other type of value cannot be checked + return config + + def get_preprocessing(self): + """ + Takes the preprocessing dictionary and converts it to a composition of torchvision transforms + """ + def _get_transformation(transformation_name, **kwargs): + if transformation_name == "RandomResizedCrop": + return T.RandomResizedCrop(**kwargs) + elif transformation_name == "RandomHorizontalFlip": + return T.RandomHorizontalFlip() + elif transformation_name == "ToTensor": + return T.ToTensor() + + transformations = [] + for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items(): + if isinstance(transformation_kwargs, dict): + transformations.append(_get_transformation(transformation_name, **transformation_kwargs)) + else: + transformations.append(_get_transformation(transformation_name)) + return T.Compose(transformations) + + def __getitem__(self, key): + return self.config[key] + +# Create a simple click command line interface to load the config and start the training +@click.command() +@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file") +def main(config_file): + print("Recalling config from {}".format(config_file)) + with open(config_file) as f: + config = json.load(f) + config = TrainDecoderConfig(config) + initialize_training(config) + + +if __name__ == "__main__": + main() \ No newline at end of file
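As a quick usage sketch (assuming the repository root is the working directory and `configs/train_decoder_config.json.example` has been copied to a real `.json` file), `TrainDecoderConfig` fills unspecified options from `decoder_defaults` and raises a `RuntimeError` for missing required values:

```python
# Minimal sketch of how TrainDecoderConfig merges a user config with the defaults.
# The values below are taken from the example config above.
from train_decoder import TrainDecoderConfig

config = TrainDecoderConfig({
    "unets": [{"dim": 128, "image_embed_dim": 768}],
    "decoder": {"image_sizes": [64], "image_size": 64},
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
        "embeddings_url": "s3://bucket/embeddings/path/"
    }
})
print(config["train"]["lr"])         # 0.0001 -- filled in from decoder_defaults
print(config["data"]["batch_size"])  # 64     -- filled in from decoder_defaults

# Omitting a required key such as data.webdataset_base_url raises:
# RuntimeError: Required config value 'webdataset_base_url' of option 'data' not supplied
```

Training itself is then launched with `python train_decoder.py --config_file configs/train_decoder_config.json`.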