diff --git a/README.md b/README.md
index 265a7bd..bff3f21 100644
--- a/README.md
+++ b/README.md
@@ -1076,6 +1076,7 @@ This library would not have gotten to this working state without the help of
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done here
+- [x] use pydantic for config-driven training
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1092,7 +1093,6 @@ This library would not have gotten to this working state without the help of
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
-- [ ] use json schemas to manage config fields, start with decoder and move into diffusion prior - think about whether json schema can allow for both config-driven as well as CLI driven training (by constructing the click decorators from the schema)
## Citations
diff --git a/configs/README.md b/configs/README.md
index 80a942e..1586469 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -4,7 +4,7 @@ For more complex configuration, we provide the option of using a configuration f
### Decoder Trainer
-The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
+The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
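+
+As a quick orientation, here is a heavily abbreviated sketch of the expected shape (placeholder values only; see the example file for a complete, working config):
+
+```json
+{
+    "unets": [{"dim": 128, "image_embed_dim": 768, "dim_mults": [1, 2, 4, 8]}],
+    "decoder": {"image_sizes": [64], "timesteps": 1000},
+    "data": {"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -", "embeddings_url": "s3://bucket/embeddings/path/"},
+    "train": {"epochs": 20, "lr": 1e-4},
+    "evaluate": {"n_evaluation_samples": 1000},
+    "tracker": {"tracker_type": "console", "data_path": "./models"},
+    "load": {"source": null}
+}
+```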
**Unets:**
diff --git a/configs/decoder_defaults.py b/configs/decoder_defaults.py
deleted file mode 100644
index e36cd41..0000000
--- a/configs/decoder_defaults.py
+++ /dev/null
@@ -1,82 +0,0 @@
-"""
-Defines the default values for the decoder config
-"""
-
-from enum import Enum
-class ConfigField(Enum):
- REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
-
-default_config = {
- "unets": ConfigField.REQUIRED,
- "decoder": {
- "image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
- "image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
- "channels": 3,
- "timesteps": 1000,
- "loss_type": "l2",
- "beta_schedule": "cosine",
- "learned_variance": True
- },
- "data": {
- "webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
- "embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
- "num_workers": 4,
- "batch_size": 64,
- "start_shard": 0,
- "end_shard": 9999999,
- "shard_width": 6,
- "index_width": 4,
- "splits": {
- "train": 0.75,
- "val": 0.15,
- "test": 0.1
- },
- "shuffle_train": True,
- "resample_train": False,
- "preprocessing": {
- "ToTensor": True
- }
- },
- "train": {
- "epochs": 20,
- "lr": 1e-4,
- "wd": 0.01,
- "max_grad_norm": 0.5,
- "save_every_n_samples": 100000,
- "n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
- "device": "cuda:0",
- "epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
- "validation_samples": None, # Same as above but for validation.
- "use_ema": True,
- "ema_beta": 0.99,
- "amp": False,
- "save_all": False, # Whether to preserve all checkpoints
- "save_latest": True, # Whether to always save the latest checkpoint
- "save_best": True, # Whether to save the best checkpoint
- "unet_training_mask": None # If None, use all unets
- },
- "evaluate": {
- "n_evalation_samples": 1000,
- "FID": None,
- "IS": None,
- "KID": None,
- "LPIPS": None
- },
- "tracker": {
- "tracker_type": "console", # Decoder currently supports console and wandb
- "data_path": "./models", # The path where files will be saved locally
-
- "wandb_entity": "", # Only needs to be set if tracker_type is wandb
- "wandb_project": "",
-
- "verbose": False # Whether to print console logging for non-console trackers
- },
- "load": {
- "source": None, # Supports file and wandb
-
- "run_path": "", # Used only if source is wandb
- "file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
-
- "resume": False # If using wandb, whether to resume the run
- }
-}
diff --git a/configs/train_decoder_config.json.example b/configs/train_decoder_config.json.example
deleted file mode 100644
index d2645c1..0000000
--- a/configs/train_decoder_config.json.example
+++ /dev/null
@@ -1,100 +0,0 @@
-{
- "unets": [
- {
- "dim": 128,
- "image_embed_dim": 768,
- "cond_dim": 64,
- "channels": 3,
- "dim_mults": [1, 2, 4, 8],
- "attn_dim_head": 32,
- "attn_heads": 16
- }
- ],
- "decoder": {
- "image_sizes": [64],
- "image_size": [64],
- "channels": 3,
- "timesteps": 1000,
- "loss_type": "l2",
- "beta_schedule": "cosine",
- "learned_variance": true
- },
- "data": {
- "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
- "embeddings_url": "s3://bucket/embeddings/path/",
- "num_workers": 4,
- "batch_size": 64,
- "start_shard": 0,
- "end_shard": 9999999,
- "shard_width": 6,
- "index_width": 4,
- "splits": {
- "train": 0.75,
- "val": 0.15,
- "test": 0.1
- },
- "shuffle_train": true,
- "resample_train": false,
- "preprocessing": {
- "RandomResizedCrop": {
- "size": [128, 128],
- "scale": [0.75, 1.0],
- "ratio": [1.0, 1.0]
- },
- "ToTensor": true
- }
- },
- "train": {
- "epochs": 20,
- "lr": 1e-4,
- "wd": 0.01,
- "max_grad_norm": 0.5,
- "save_every_n_samples": 100000,
- "n_sample_images": 6,
- "device": "cuda:0",
- "epoch_samples": null,
- "validation_samples": null,
- "use_ema": true,
- "ema_beta": 0.99,
- "amp": false,
- "save_all": false,
- "save_latest": true,
- "save_best": true,
- "unet_training_mask": [true]
- },
- "evaluate": {
- "n_evalation_samples": 1000,
- "FID": {
- "feature": 64
- },
- "IS": {
- "feature": 64,
- "splits": 10
- },
- "KID": {
- "feature": 64,
- "subset_size": 10
- },
- "LPIPS": {
- "net_type": "vgg",
- "reduction": "mean"
- }
- },
- "tracker": {
- "tracker_type": "console",
- "data_path": "./models",
-
- "wandb_entity": "",
- "wandb_project": "",
-
- "verbose": false
- },
- "load": {
- "source": null,
-
- "run_path": "",
- "file_path": "",
-
- "resume": false
- }
-}
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index c2757eb..4d86342 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -1,47 +1,111 @@
from torchvision import transforms as T
-from configs.decoder_defaults import default_config, ConfigField
+from pydantic import BaseModel, validator
+from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
-class TrainDecoderConfig:
- def __init__(self, config):
- self.config = self.map_config(config, default_config)
+def exists(val):
+ return val is not None
- def map_config(self, config, defaults):
- """
- Returns a dictionary containing all config options in the union of config and defaults.
- If the config value is an array, apply the default value to each element.
- If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
- """
- def _check_option(option, option_config, option_defaults):
- for key, value in option_defaults.items():
- if key not in option_config:
- if value == ConfigField.REQUIRED:
- raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
- option_config[key] = value
-
- for key, value in defaults.items():
- if key not in config:
- # Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
- if value == ConfigField.REQUIRED:
- raise RuntimeError("Required config value '{}' not supplied".format(key))
- elif isinstance(value, dict):
- config[key] = {}
- elif isinstance(value, list):
- config[key] = [{}]
- # Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
- # If it is a list, then we need to check each element
- if isinstance(value, list):
- assert isinstance(config[key], list)
- for element in config[key]:
- _check_option(key, element, value[0])
- elif isinstance(value, dict):
- _check_option(key, config[key], value)
- # This object does not support checking
- return config
+def default(val, d):
+ return val if exists(val) else d
- def get_preprocessing(self):
- """
- Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
- """
+class UnetConfig(BaseModel):
+ dim: int
+ dim_mults: List[int]
+    image_embed_dim: Optional[int] = None
+    cond_dim: Optional[int] = None
+ channels: int = 3
+ attn_dim_head: int = 32
+ attn_heads: int = 16
+
+ class Config:
+ extra = "allow"
+
+class DecoderConfig(BaseModel):
+    image_size: Optional[int] = None
+    image_sizes: Optional[Union[List[int], Tuple[int, ...]]] = None
+ channels: int = 3
+ timesteps: int = 1000
+ loss_type: str = 'l2'
+ beta_schedule: str = 'cosine'
+ learned_variance: bool = True
+
+    @validator('image_sizes', always=True)  # always=True so the check also runs when image_sizes is omitted
+ def check_image_sizes(cls, image_sizes, values):
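+        # '^' acts as a boolean XOR: exactly one of image_size / image_sizes may be supplied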
+ if exists(values.get('image_size')) ^ exists(image_sizes):
+ return image_sizes
+ raise ValueError('either image_size or image_sizes is required, but not both')
+
+ class Config:
+ extra = "allow"
+
+class DecoderDataConfig(BaseModel):
+ webdataset_base_url: str # path to a webdataset with jpg images
+ embeddings_url: str # path to .npy files with embeddings
+ num_workers: int = 4
+ batch_size: int = 64
+ start_shard: int = 0
+ end_shard: int = 9999999
+ shard_width: int = 6
+ index_width: int = 4
+ splits: Dict[str, float] = {
+ 'train': 0.75,
+ 'val': 0.15,
+ 'test': 0.1
+ }
+ shuffle_train: bool = True
+ resample_train: bool = False
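+    # keys are torchvision transform names; values are kwarg dicts, or a bare flag (e.g. True) for transforms that take no arguments; see img_preproc below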
+ preprocessing: Dict[str, Any] = {'ToTensor': True}
+
+class DecoderTrainConfig(BaseModel):
+ epochs: int = 20
+ lr: float = 1e-4
+ wd: float = 0.01
+ max_grad_norm: float = 0.5
+ save_every_n_samples: int = 100000
+ n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+ device: str = 'cuda:0'
+    epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: Optional[int] = None # Same as above but for validation.
+ use_ema: bool = True
+ ema_beta: float = 0.99
+ amp: bool = False
+ save_all: bool = False # Whether to preserve all checkpoints
+ save_latest: bool = True # Whether to always save the latest checkpoint
+ save_best: bool = True # Whether to save the best checkpoint
+    unet_training_mask: Optional[List[bool]] = None # If None, use all unets
+
+class DecoderEvaluateConfig(BaseModel):
+ n_evaluation_samples: int = 1000
+    FID: Optional[Dict[str, Any]] = None
+    IS: Optional[Dict[str, Any]] = None
+    KID: Optional[Dict[str, Any]] = None
+    LPIPS: Optional[Dict[str, Any]] = None
+
+class TrackerConfig(BaseModel):
+ tracker_type: str = 'console' # Decoder currently supports console and wandb
+ data_path: str = './models' # The path where files will be saved locally
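+    # optional dict forwarded to the tracker's init call (stored under the "config" key in create_tracker)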
+    init_config: Optional[Dict[str, Any]] = None
+ wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
+ wandb_project: str = ''
+ verbose: bool = False # Whether to print console logging for non-console trackers
+
+class DecoderLoadConfig(BaseModel):
+    source: Optional[str] = None # Supports file and wandb
+ run_path: str = '' # Used only if source is wandb
+ file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
+ resume: bool = False # If using wandb, whether to resume the run
+
+class TrainDecoderConfig(BaseModel):
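+    # top-level training config; its seven sections mirror configs/README.md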
+ unets: List[UnetConfig]
+ decoder: DecoderConfig
+ data: DecoderDataConfig
+ train: DecoderTrainConfig
+ evaluate: DecoderEvaluateConfig
+ tracker: TrackerConfig
+ load: DecoderLoadConfig
+
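+    # composes the torchvision preprocessing pipeline described by data.preprocessing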
+ @property
+ def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
@@ -50,13 +114,8 @@ class TrainDecoderConfig:
elif transformation_name == "ToTensor":
return T.ToTensor()
- transformations = []
- for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
- if isinstance(transformation_kwargs, dict):
- transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
- else:
- transformations.append(_get_transformation(transformation_name))
- return T.Compose(transformations)
-
- def __getitem__(self, key):
- return self.config[key]
+ transforms = []
+ for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
+ transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
+ transforms.append(_get_transformation(transform_name, **transform_kwargs))
+ return T.Compose(transforms)
diff --git a/setup.py b/setup.py
index 1376534..9d27e9c 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.3.9',
+ version = '0.4.0',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -29,10 +29,10 @@ setup(
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
- 'jsonschema>=4.5.1',
'kornia>=0.5.4',
'numpy',
'pillow',
+ 'pydantic',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
diff --git a/train_decoder.py b/train_decoder.py
index 1648e9f..3e4f7d4 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -90,11 +90,11 @@ def create_dataloaders(
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
- unets = [Unet(**config) for config in unets_config]
+ unets = [Unet(**config.dict()) for config in unets_config]
decoder = Decoder(
unet=unets,
- **decoder_config
+ **decoder_config.dict()
)
decoder.to(device=device)
@@ -154,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
-def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
- examples = get_example_data(dataloader, device, n_evalation_samples)
+ examples = get_example_data(dataloader, device, n_evaluation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -252,8 +252,8 @@ def train(
start_epoch = 0
validation_losses = []
- if exists(load_config) and exists(load_config["source"]):
- start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
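+    # restore trainer state (epoch, step, validation losses) when a load source is configured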
+ if exists(load_config) and exists(load_config.source):
+        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
trainer.to(device=inference_device)
if not exists(unet_training_mask):
@@ -386,21 +386,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
- tracker_config = config["tracker"]
+ tracker_config = config.tracker
init_config = {}
- init_config["config"] = config.config
+
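+    # only attach a run config when one was supplied via tracker.init_config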
+ if exists(tracker_config.init_config):
+ init_config["config"] = tracker_config.init_config
+
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
- load_config = config["load"]
- if load_config["source"] == "wandb" and load_config["resume"]:
+ load_config = config.load
+ if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
- run_id = config["resume"]["wandb_run_path"].split("/")[-1]
+ run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
- init_config["entity"] = tracker_config["wandb_entity"]
- init_config["project"] = tracker_config["wandb_project"]
+
+ init_config["entity"] = tracker_config.wandb_entity
+ init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
@@ -409,35 +413,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def initialize_training(config):
# Create the save path
- if "cuda" in config["train"]["device"]:
+ if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
- device = torch.device(config["train"]["device"])
+ device = torch.device(config.train.device)
torch.cuda.set_device(device)
- all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
+ all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
- img_preproc = config.get_preprocessing(),
- train_prop = config["data"]["splits"]["train"],
- val_prop = config["data"]["splits"]["val"],
- test_prop = config["data"]["splits"]["test"],
- n_sample_images=config["train"]["n_sample_images"],
- **config["data"]
+ img_preproc = config.img_preproc,
+        train_prop = config.data.splits["train"],
+        val_prop = config.data.splits["val"],
+        test_prop = config.data.splits["test"],
+ n_sample_images=config.train.n_sample_images,
+ **config.data.dict()
)
- decoder = create_decoder(device, config["decoder"], config["unets"])
+ decoder = create_decoder(device, config.decoder, config.unets)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
- tracker = create_tracker(config, **config["tracker"])
+ tracker = create_tracker(config, **config.tracker.dict())
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
- load_config=config["load"],
- evaluate_config=config["evaluate"],
- **config["train"],
+ load_config=config.load,
+ evaluate_config=config.evaluate,
+ **config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@@ -447,7 +451,7 @@ def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
- config = TrainDecoderConfig(config)
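+    # pydantic validates the parsed JSON here, raising a descriptive error if required fields are missing or malformed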
+ config = TrainDecoderConfig(**config)
initialize_training(config)