diff --git a/README.md b/README.md
index 265a7bd..bff3f21 100644
--- a/README.md
+++ b/README.md
@@ -1076,6 +1076,7 @@ This library would not have gotten to this working state without the help of
 - [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done here
+- [x] use pydantic for config-driven training
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1092,7 +1093,6 @@ This library would not have gotten to this working state without the help of
 - [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
 - [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
-- [ ] use json schemas to manage config fields, start with decoder and move into diffusion prior - think about whether json schema can allow for both config-driven as well as CLI driven training (by constructing the click decorators from the schema)
 
 ## Citations
 
diff --git a/configs/README.md b/configs/README.md
index 80a942e..1586469 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -4,7 +4,7 @@ For more complex configuration, we provide the option of using a configuration f
 
 ### Decoder Trainer
 
-The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
+The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
 
 **Unets:**
 
diff --git a/configs/decoder_defaults.py b/configs/decoder_defaults.py
deleted file mode 100644
index e36cd41..0000000
--- a/configs/decoder_defaults.py
+++ /dev/null
@@ -1,82 +0,0 @@
-"""
-Defines the default values for the decoder config
-"""
-
-from enum import Enum
-class ConfigField(Enum):
-    REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
-
-default_config = {
-    "unets": ConfigField.REQUIRED,
-    "decoder": {
-        "image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
-        "image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
-        "channels": 3,
-        "timesteps": 1000,
-        "loss_type": "l2",
-        "beta_schedule": "cosine",
-        "learned_variance": True
-    },
-    "data": {
-        "webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
-        "embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
-        "num_workers": 4,
-        "batch_size": 64,
-        "start_shard": 0,
-        "end_shard": 9999999,
-        "shard_width": 6,
-        "index_width": 4,
-        "splits": {
-            "train": 0.75,
-            "val": 0.15,
-            "test": 0.1
-        },
-        "shuffle_train": True,
-        "resample_train": False,
-        "preprocessing": {
-            "ToTensor": True
-        }
-    },
-    "train": {
-        "epochs": 20,
-        "lr": 1e-4,
-        "wd": 0.01,
-        "max_grad_norm": 0.5,
-        "save_every_n_samples": 100000,
-        "n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
-        "device": "cuda:0",
-        "epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-        "validation_samples": None, # Same as above but for validation.
-        "use_ema": True,
-        "ema_beta": 0.99,
-        "amp": False,
-        "save_all": False, # Whether to preserve all checkpoints
-        "save_latest": True, # Whether to always save the latest checkpoint
-        "save_best": True, # Whether to save the best checkpoint
-        "unet_training_mask": None # If None, use all unets
-    },
-    "evaluate": {
-        "n_evalation_samples": 1000,
-        "FID": None,
-        "IS": None,
-        "KID": None,
-        "LPIPS": None
-    },
-    "tracker": {
-        "tracker_type": "console", # Decoder currently supports console and wandb
-        "data_path": "./models", # The path where files will be saved locally
-
-        "wandb_entity": "", # Only needs to be set if tracker_type is wandb
-        "wandb_project": "",
-
-        "verbose": False # Whether to print console logging for non-console trackers
-    },
-    "load": {
-        "source": None, # Supports file and wandb
-
-        "run_path": "", # Used only if source is wandb
-        "file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
-
-        "resume": False # If using wandb, whether to resume the run
-    }
-}
diff --git a/configs/train_decoder_config.json.example b/configs/train_decoder_config.json.example
deleted file mode 100644
index d2645c1..0000000
--- a/configs/train_decoder_config.json.example
+++ /dev/null
@@ -1,100 +0,0 @@
-{
-    "unets": [
-        {
-            "dim": 128,
-            "image_embed_dim": 768,
-            "cond_dim": 64,
-            "channels": 3,
-            "dim_mults": [1, 2, 4, 8],
-            "attn_dim_head": 32,
-            "attn_heads": 16
-        }
-    ],
-    "decoder": {
-        "image_sizes": [64],
-        "image_size": [64],
-        "channels": 3,
-        "timesteps": 1000,
-        "loss_type": "l2",
-        "beta_schedule": "cosine",
-        "learned_variance": true
-    },
-    "data": {
-        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
-        "embeddings_url": "s3://bucket/embeddings/path/",
-        "num_workers": 4,
-        "batch_size": 64,
-        "start_shard": 0,
-        "end_shard": 9999999,
-        "shard_width": 6,
-        "index_width": 4,
-        "splits": {
-            "train": 0.75,
-            "val": 0.15,
-            "test": 0.1
-        },
-        "shuffle_train": true,
-        "resample_train": false,
-        "preprocessing": {
-            "RandomResizedCrop": {
-                "size": [128, 128],
-                "scale": [0.75, 1.0],
-                "ratio": [1.0, 1.0]
-            },
-            "ToTensor": true
-        }
-    },
-    "train": {
-        "epochs": 20,
-        "lr": 1e-4,
-        "wd": 0.01,
-        "max_grad_norm": 0.5,
-        "save_every_n_samples": 100000,
-        "n_sample_images": 6,
-        "device": "cuda:0",
-        "epoch_samples": null,
-        "validation_samples": null,
-        "use_ema": true,
-        "ema_beta": 0.99,
-        "amp": false,
-        "save_all": false,
-        "save_latest": true,
-        "save_best": true,
-        "unet_training_mask": [true]
-    },
-    "evaluate": {
-        "n_evalation_samples": 1000,
-        "FID": {
-            "feature": 64
-        },
-        "IS": {
-            "feature": 64,
-            "splits": 10
-        },
-        "KID": {
-            "feature": 64,
-            "subset_size": 10
-        },
-        "LPIPS": {
-            "net_type": "vgg",
-            "reduction": "mean"
-        }
-    },
-    "tracker": {
-        "tracker_type": "console",
-        "data_path": "./models",
-
-        "wandb_entity": "",
-        "wandb_project": "",
-
-        "verbose": false
-    },
-    "load": {
-        "source": null,
-
-        "run_path": "",
-        "file_path": "",
-
-        "resume": false
-    }
-}
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index c2757eb..4d86342 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -1,47 +1,111 @@
 from torchvision import transforms as T
-from configs.decoder_defaults import default_config, ConfigField
+from pydantic import BaseModel, validator
+from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
-class TrainDecoderConfig:
-    def __init__(self, config):
-        self.config = self.map_config(config, default_config)
+def exists(val):
+    return val is not None
 
-    def map_config(self, config, defaults):
-        """
-        Returns a dictionary containing all config options in the union of config and defaults.
-        If the config value is an array, apply the default value to each element.
-        If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
-        """
-        def _check_option(option, option_config, option_defaults):
-            for key, value in option_defaults.items():
-                if key not in option_config:
-                    if value == ConfigField.REQUIRED:
-                        raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
-                    option_config[key] = value
-
-        for key, value in defaults.items():
-            if key not in config:
-                # Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
-                if value == ConfigField.REQUIRED:
-                    raise RuntimeError("Required config value '{}' not supplied".format(key))
-                elif isinstance(value, dict):
-                    config[key] = {}
-                elif isinstance(value, list):
-                    config[key] = [{}]
-            # Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
-            # If it is a list, then we need to check each element
-            if isinstance(value, list):
-                assert isinstance(config[key], list)
-                for element in config[key]:
-                    _check_option(key, element, value[0])
-            elif isinstance(value, dict):
-                _check_option(key, config[key], value)
-            # This object does not support checking
-        return config
+def default(val, d):
+    return val if exists(val) else d
 
-    def get_preprocessing(self):
-        """
-        Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
-        """
+class UnetConfig(BaseModel):
+    dim: int
+    dim_mults: List[int]
+    image_embed_dim: int = None
+    cond_dim: int = None
+    channels: int = 3
+    attn_dim_head: int = 32
+    attn_heads: int = 16
+
+    class Config:
+        extra = "allow"
+
+class DecoderConfig(BaseModel):
+    image_size: int = None
+    image_sizes: Union[List[int], Tuple[int]] = None
+    channels: int = 3
+    timesteps: int = 1000
+    loss_type: str = 'l2'
+    beta_schedule: str = 'cosine'
+    learned_variance: bool = True
+
+    @validator('image_sizes')
+    def check_image_sizes(cls, image_sizes, values):
+        if exists(values.get('image_size')) ^ exists(image_sizes):
+            return image_sizes
+        raise ValueError('either image_size or image_sizes is required, but not both')
+
+    class Config:
+        extra = "allow"
+
+class DecoderDataConfig(BaseModel):
+    webdataset_base_url: str # path to a webdataset with jpg images
+    embeddings_url: str # path to .npy files with embeddings
+    num_workers: int = 4
+    batch_size: int = 64
+    start_shard: int = 0
+    end_shard: int = 9999999
+    shard_width: int = 6
+    index_width: int = 4
+    splits: Dict[str, float] = {
+        'train': 0.75,
+        'val': 0.15,
+        'test': 0.1
+    }
+    shuffle_train: bool = True
+    resample_train: bool = False
+    preprocessing: Dict[str, Any] = {'ToTensor': True}
+
+class DecoderTrainConfig(BaseModel):
+    epochs: int = 20
+    lr: float = 1e-4
+    wd: float = 0.01
+    max_grad_norm: float = 0.5
+    save_every_n_samples: int = 100000
+    n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+    device: str = 'cuda:0'
+    epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: int = None # Same as above but for validation.
+    use_ema: bool = True
+    ema_beta: float = 0.99
+    amp: bool = False
+    save_all: bool = False # Whether to preserve all checkpoints
+    save_latest: bool = True # Whether to always save the latest checkpoint
+    save_best: bool = True # Whether to save the best checkpoint
+    unet_training_mask: List[bool] = None # If None, use all unets
+
+class DecoderEvaluateConfig(BaseModel):
+    n_evaluation_samples: int = 1000
+    FID: Dict[str, Any] = None
+    IS: Dict[str, Any] = None
+    KID: Dict[str, Any] = None
+    LPIPS: Dict[str, Any] = None
+
+class TrackerConfig(BaseModel):
+    tracker_type: str = 'console' # Decoder currently supports console and wandb
+    data_path: str = './models' # The path where files will be saved locally
+    init_config: Dict[str, Any] = None
+    wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
+    wandb_project: str = ''
+    verbose: bool = False # Whether to print console logging for non-console trackers
+
+class DecoderLoadConfig(BaseModel):
+    source: str = None # Supports file and wandb
+    run_path: str = '' # Used only if source is wandb
+    file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
+    resume: bool = False # If using wandb, whether to resume the run
+
+class TrainDecoderConfig(BaseModel):
+    unets: List[UnetConfig]
+    decoder: DecoderConfig
+    data: DecoderDataConfig
+    train: DecoderTrainConfig
+    evaluate: DecoderEvaluateConfig
+    tracker: TrackerConfig
+    load: DecoderLoadConfig
+
+    @property
+    def img_preproc(self):
         def _get_transformation(transformation_name, **kwargs):
             if transformation_name == "RandomResizedCrop":
                 return T.RandomResizedCrop(**kwargs)
@@ -50,13 +114,8 @@ class TrainDecoderConfig:
             elif transformation_name == "ToTensor":
                 return T.ToTensor()
 
-        transformations = []
-        for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
-            if isinstance(transformation_kwargs, dict):
-                transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
-            else:
-                transformations.append(_get_transformation(transformation_name))
-        return T.Compose(transformations)
-
-    def __getitem__(self, key):
-        return self.config[key]
+        transforms = []
+        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
+            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
+            transforms.append(_get_transformation(transform_name, **transform_kwargs))
+        return T.Compose(transforms)
diff --git a/setup.py b/setup.py
index 1376534..9d27e9c 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.3.9',
+  version = '0.4.0',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -29,10 +29,10 @@ setup(
     'einops>=0.4',
     'einops-exts>=0.0.3',
     'embedding-reader',
-    'jsonschema>=4.5.1',
     'kornia>=0.5.4',
     'numpy',
     'pillow',
+    'pydantic',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',
     'torch>=1.10',
diff --git a/train_decoder.py b/train_decoder.py
index 1648e9f..3e4f7d4 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -90,11 +90,11 @@ def create_dataloaders(
 
 def create_decoder(device, decoder_config, unets_config):
     """Creates a sample decoder"""
-    unets = [Unet(**config) for config in unets_config]
+    unets = [Unet(**config.dict()) for config in unets_config]
 
     decoder = Decoder(
         unet=unets,
-        **decoder_config
+        **decoder_config.dict()
     )
     decoder.to(device=device)
 
@@ -154,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
     metrics = {}
     # Prepare the data
-    examples = get_example_data(dataloader, device, n_evalation_samples)
+    examples = get_example_data(dataloader, device, n_evaluation_samples)
     real_images, generated_images, captions = generate_samples(trainer, examples)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -252,8 +252,8 @@ def train(
     start_epoch = 0
     validation_losses = []
 
-    if exists(load_config) and exists(load_config["source"]):
-        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
+    if exists(load_config) and exists(load_config.source):
+        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
     trainer.to(device=inference_device)
 
     if not exists(unet_training_mask):
@@ -386,21 +386,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
     """
     Creates a tracker of the specified type and initializes special features based on the full config
    """
-    tracker_config = config["tracker"]
+    tracker_config = config.tracker
     init_config = {}
-    init_config["config"] = config.config
+
+    if exists(tracker_config.init_config):
+        init_config["config"] = tracker_config.init_config
+
     if tracker_type == "console":
         tracker = ConsoleTracker(**init_config)
     elif tracker_type == "wandb":
         # We need to initialize the resume state here
-        load_config = config["load"]
-        if load_config["source"] == "wandb" and load_config["resume"]:
+        load_config = config.load
+        if load_config.source == "wandb" and load_config.resume:
             # Then we are resuming the run load_config["run_path"]
-            run_id = config["resume"]["wandb_run_path"].split("/")[-1]
+            run_id = load_config.run_path.split("/")[-1]
             init_config["id"] = run_id
             init_config["resume"] = "must"
-        init_config["entity"] = tracker_config["wandb_entity"]
-        init_config["project"] = tracker_config["wandb_project"]
+
+        init_config["entity"] = tracker_config.wandb_entity
+        init_config["project"] = tracker_config.wandb_project
         tracker = WandbTracker(data_path)
         tracker.init(**init_config)
     else:
@@ -409,35 +413,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
 
 def initialize_training(config):
     # Create the save path
-    if "cuda" in config["train"]["device"]:
+    if "cuda" in config.train.device:
         assert torch.cuda.is_available(), "CUDA is not available"
-    device = torch.device(config["train"]["device"])
+    device = torch.device(config.train.device)
     torch.cuda.set_device(device)
 
-    all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
+    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
 
     dataloaders = create_dataloaders (
         available_shards=all_shards,
-        img_preproc = config.get_preprocessing(),
-        train_prop = config["data"]["splits"]["train"],
-        val_prop = config["data"]["splits"]["val"],
-        test_prop = config["data"]["splits"]["test"],
-        n_sample_images=config["train"]["n_sample_images"],
-        **config["data"]
+        img_preproc = config.img_preproc,
+        train_prop = config.data.splits["train"],
+        val_prop = config.data.splits["val"],
+        test_prop = config.data.splits["test"],
+        n_sample_images=config.train.n_sample_images,
+        **config.data.dict()
     )
-    decoder = create_decoder(device, config["decoder"], config["unets"])
+    decoder = create_decoder(device, config.decoder, config.unets)
     num_parameters = sum(p.numel() for p in decoder.parameters())
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")
 
-    tracker = create_tracker(config, **config["tracker"])
+    tracker = create_tracker(config, **config.tracker.dict())
 
     train(dataloaders, decoder,
         tracker=tracker,
         inference_device=device,
-        load_config=config["load"],
-        evaluate_config=config["evaluate"],
-        **config["train"],
+        load_config=config.load,
+        evaluate_config=config.evaluate,
+        **config.train.dict(),
     )
 
 # Create a simple click command line interface to load the config and start the training
@@ -447,7 +451,7 @@ def main(config_file):
     print("Recalling config from {}".format(config_file))
     with open(config_file) as f:
         config = json.load(f)
-    config = TrainDecoderConfig(config)
+    config = TrainDecoderConfig(**config)
     initialize_training(config)
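
A minimal sketch of how the new pydantic-backed config is consumed, mirroring what `main` in `train_decoder.py` does after `json.load`. The field values below are placeholders, not a recommended configuration; only the fields the validators require are filled in, everything else falls back to the defaults defined in `dalle2_pytorch/train_configs.py`.

# Sketch only: build TrainDecoderConfig from a plain dict, the same way
# train_decoder.py constructs it from the loaded JSON config file.
# All values here are placeholders for illustration.
from dalle2_pytorch.train_configs import TrainDecoderConfig

example_config = {
    "unets": [{"dim": 128, "image_embed_dim": 768, "dim_mults": [1, 2, 4, 8]}],
    "decoder": {"image_sizes": [64]},  # exactly one of image_size / image_sizes may be set
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",  # required field
        "embeddings_url": "s3://bucket/embeddings/path/",                   # required field
    },
    "train": {},     # every field has a default, so empty sections validate
    "evaluate": {},
    "tracker": {},
    "load": {},
}

config = TrainDecoderConfig(**example_config)

print(config.train.lr)            # 0.0001, filled in from DecoderTrainConfig defaults
print(config.data.splits)         # {'train': 0.75, 'val': 0.15, 'test': 0.1}
img_preproc = config.img_preproc  # torchvision Compose built from data.preprocessing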