use pydantic to manage decoder training configs + defaults and refactor training script

This commit is contained in:
Phil Wang
2022-05-22 14:27:40 -07:00
parent d49eca62fa
commit a1ef023193
7 changed files with 145 additions and 264 deletions

View File

@@ -1076,6 +1076,7 @@ This library would not have gotten to this working state without the help of
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1092,7 +1093,6 @@ This library would not have gotten to this working state without the help of
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
- [ ] use json schemas to manage config fields, start with decoder and move into diffusion prior - think about whether json schema can allow for both config-driven as well as CLI driven training (by constructing the click decorators from the schema)
## Citations

View File

@@ -4,7 +4,7 @@ For more complex configuration, we provide the option of using a configuration f
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
**<ins>Unets</ins>:**

View File

@@ -1,82 +0,0 @@
"""
Defines the default values for the decoder config
"""
from enum import Enum
class ConfigField(Enum):
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
default_config = {
"unets": ConfigField.REQUIRED,
"decoder": {
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": True
},
"data": {
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": True,
"resample_train": False,
"preprocessing": {
"ToTensor": True
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
"device": "cuda:0",
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
"validation_samples": None, # Same as above but for validation.
"use_ema": True,
"ema_beta": 0.99,
"amp": False,
"save_all": False, # Whether to preserve all checkpoints
"save_latest": True, # Whether to always save the latest checkpoint
"save_best": True, # Whether to save the best checkpoint
"unet_training_mask": None # If None, use all unets
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": None,
"IS": None,
"KID": None,
"LPIPS": None
},
"tracker": {
"tracker_type": "console", # Decoder currently supports console and wandb
"data_path": "./models", # The path where files will be saved locally
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
"wandb_project": "",
"verbose": False # Whether to print console logging for non-console trackers
},
"load": {
"source": None, # Supports file and wandb
"run_path": "", # Used only if source is wandb
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
"resume": False # If using wandb, whether to resume the run
}
}

View File

@@ -1,100 +0,0 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"image_sizes": [64],
"image_size": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": true
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": true,
"resample_train": false,
"preprocessing": {
"RandomResizedCrop": {
"size": [128, 128],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6,
"device": "cuda:0",
"epoch_samples": null,
"validation_samples": null,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 10
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"tracker_type": "console",
"data_path": "./models",
"wandb_entity": "",
"wandb_project": "",
"verbose": false
},
"load": {
"source": null,
"run_path": "",
"file_path": "",
"resume": false
}
}

View File

@@ -1,47 +1,111 @@
from torchvision import transforms as T
from configs.decoder_defaults import default_config, ConfigField
from pydantic import BaseModel, validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
class TrainDecoderConfig:
def __init__(self, config):
self.config = self.map_config(config, default_config)
def exists(val):
return val is not None
def map_config(self, config, defaults):
"""
Returns a dictionary containing all config options in the union of config and defaults.
If the config value is an array, apply the default value to each element.
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
"""
def _check_option(option, option_config, option_defaults):
for key, value in option_defaults.items():
if key not in option_config:
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
option_config[key] = value
for key, value in defaults.items():
if key not in config:
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' not supplied".format(key))
elif isinstance(value, dict):
config[key] = {}
elif isinstance(value, list):
config[key] = [{}]
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
# If it is a list, then we need to check each element
if isinstance(value, list):
assert isinstance(config[key], list)
for element in config[key]:
_check_option(key, element, value[0])
elif isinstance(value, dict):
_check_option(key, config[key], value)
# This object does not support checking
return config
def default(val, d):
return val if exists(val) else d
def get_preprocessing(self):
"""
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
"""
class UnetConfig(BaseModel):
dim: int
dim_mults: List[int]
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
image_size: int = None
image_sizes: Union[List[int], Tuple[int]] = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
return image_sizes
raise ValueError('either image_size or image_sizes is required, but not both')
class Config:
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
end_shard: int = 9999999
shard_width: int = 6
index_width: int = 4
splits: Dict[str, float] = {
'train': 0.75,
'val': 0.15,
'test': 0.1
}
shuffle_train: bool = True
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: List[bool] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
unets: List[UnetConfig]
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
@@ -50,13 +114,8 @@ class TrainDecoderConfig:
elif transformation_name == "ToTensor":
return T.ToTensor()
transformations = []
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
if isinstance(transformation_kwargs, dict):
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
else:
transformations.append(_get_transformation(transformation_name))
return T.Compose(transformations)
def __getitem__(self, key):
return self.config[key]
transforms = []
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.3.9',
version = '0.4.0',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -29,10 +29,10 @@ setup(
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'jsonschema>=4.5.1',
'kornia>=0.5.4',
'numpy',
'pillow',
'pydantic',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',

View File

@@ -90,11 +90,11 @@ def create_dataloaders(
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config) for config in unets_config]
unets = [Unet(**config.dict()) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config
**decoder_config.dict()
)
decoder.to(device=device)
@@ -154,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evalation_samples)
examples = get_example_data(dataloader, device, n_evaluation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -252,8 +252,8 @@ def train(
start_epoch = 0
validation_losses = []
if exists(load_config) and exists(load_config["source"]):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
trainer.to(device=inference_device)
if not exists(unet_training_mask):
@@ -386,21 +386,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config["tracker"]
tracker_config = config.tracker
init_config = {}
init_config["config"] = config.config
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config["load"]
if load_config["source"] == "wandb" and load_config["resume"]:
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"]
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
@@ -409,35 +413,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def initialize_training(config):
# Create the save path
if "cuda" in config["train"]["device"]:
if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config["train"]["device"])
device = torch.device(config.train.device)
torch.cuda.set_device(device)
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.get_preprocessing(),
train_prop = config["data"]["splits"]["train"],
val_prop = config["data"]["splits"]["val"],
test_prop = config["data"]["splits"]["test"],
n_sample_images=config["train"]["n_sample_images"],
**config["data"]
img_preproc = config.img_preproc,
train_prop = config.data["splits"]["train"],
val_prop = config.data["splits"]["val"],
test_prop = config.data["splits"]["test"],
n_sample_images=config.train.n_sample_images,
**config.data.dict()
)
decoder = create_decoder(device, config["decoder"], config["unets"])
decoder = create_decoder(device, config.decoder, config.unets)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config["tracker"])
tracker = create_tracker(config, **config.tracker.dict())
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config["load"],
evaluate_config=config["evaluate"],
**config["train"],
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@@ -447,7 +451,7 @@ def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(config)
config = TrainDecoderConfig(**config)
initialize_training(config)