use pydantic to manage decoder training configs + defaults and refactor training script

Author: Phil Wang
Date: 2022-05-22 14:27:40 -07:00
parent d49eca62fa
commit a1ef023193
7 changed files with 145 additions and 264 deletions


@@ -1,47 +1,111 @@
 from torchvision import transforms as T
-from configs.decoder_defaults import default_config, ConfigField
+from pydantic import BaseModel, validator
+from typing import List, Iterable, Optional, Union, Tuple, Dict, Any

-class TrainDecoderConfig:
-    def __init__(self, config):
-        self.config = self.map_config(config, default_config)
+def exists(val):
+    return val is not None

-    def map_config(self, config, defaults):
-        """
-        Returns a dictionary containing all config options in the union of config and defaults.
-        If the config value is an array, apply the default value to each element.
-        If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
-        """
-        def _check_option(option, option_config, option_defaults):
-            for key, value in option_defaults.items():
-                if key not in option_config:
-                    if value == ConfigField.REQUIRED:
-                        raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
-                    option_config[key] = value
-        for key, value in defaults.items():
-            if key not in config:
-                # Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
-                if value == ConfigField.REQUIRED:
-                    raise RuntimeError("Required config value '{}' not supplied".format(key))
-                elif isinstance(value, dict):
-                    config[key] = {}
-                elif isinstance(value, list):
-                    config[key] = [{}]
-            # Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
-            # If it is a list, then we need to check each element
-            if isinstance(value, list):
-                assert isinstance(config[key], list)
-                for element in config[key]:
-                    _check_option(key, element, value[0])
-            elif isinstance(value, dict):
-                _check_option(key, config[key], value)
-            # This object does not support checking
-        return config
+def default(val, d):
+    return val if exists(val) else d

-    def get_preprocessing(self):
-        """
-        Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
-        """
+class UnetConfig(BaseModel):
+    dim: int
+    dim_mults: List[int]
+    image_embed_dim: int = None
+    cond_dim: int = None
+    channels: int = 3
+    attn_dim_head: int = 32
+    attn_heads: int = 16
+
+    class Config:
+        extra = "allow"
+
+class DecoderConfig(BaseModel):
+    image_size: int = None
+    image_sizes: Union[List[int], Tuple[int]] = None
+    channels: int = 3
+    timesteps: int = 1000
+    loss_type: str = 'l2'
+    beta_schedule: str = 'cosine'
+    learned_variance: bool = True
+
+    @validator('image_sizes')
+    def check_image_sizes(cls, image_sizes, values):
+        if exists(values.get('image_size')) ^ exists(image_sizes):
+            return image_sizes
+        raise ValueError('either image_size or image_sizes is required, but not both')
+
+    class Config:
+        extra = "allow"
+
+class DecoderDataConfig(BaseModel):
+    webdataset_base_url: str  # path to a webdataset with jpg images
+    embeddings_url: str  # path to .npy files with embeddings
+    num_workers: int = 4
+    batch_size: int = 64
+    start_shard: int = 0
+    end_shard: int = 9999999
+    shard_width: int = 6
+    index_width: int = 4
+    splits: Dict[str, float] = {
+        'train': 0.75,
+        'val': 0.15,
+        'test': 0.1
+    }
+    shuffle_train: bool = True
+    resample_train: bool = False
+    preprocessing: Dict[str, Any] = {'ToTensor': True}
+
+class DecoderTrainConfig(BaseModel):
+    epochs: int = 20
+    lr: float = 1e-4
+    wd: float = 0.01
+    max_grad_norm: float = 0.5
+    save_every_n_samples: int = 100000
+    n_sample_images: int = 6  # The number of example images to produce when sampling the train and test dataset
+    device: str = 'cuda:0'
+    epoch_samples: int = None  # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: int = None  # Same as above but for validation.
+    use_ema: bool = True
+    ema_beta: float = 0.99
+    amp: bool = False
+    save_all: bool = False  # Whether to preserve all checkpoints
+    save_latest: bool = True  # Whether to always save the latest checkpoint
+    save_best: bool = True  # Whether to save the best checkpoint
+    unet_training_mask: List[bool] = None  # If None, use all unets
+
+class DecoderEvaluateConfig(BaseModel):
+    n_evaluation_samples: int = 1000
+    FID: Dict[str, Any] = None
+    IS: Dict[str, Any] = None
+    KID: Dict[str, Any] = None
+    LPIPS: Dict[str, Any] = None
+
+class TrackerConfig(BaseModel):
+    tracker_type: str = 'console'  # Decoder currently supports console and wandb
+    data_path: str = './models'  # The path where files will be saved locally
+    init_config: Dict[str, Any] = None
+    wandb_entity: str = ''  # Only needs to be set if tracker_type is wandb
+    wandb_project: str = ''
+    verbose: bool = False  # Whether to print console logging for non-console trackers
+
+class DecoderLoadConfig(BaseModel):
+    source: str = None  # Supports file and wandb
+    run_path: str = ''  # Used only if source is wandb
+    file_path: str = ''  # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
+    resume: bool = False  # If using wandb, whether to resume the run
+
+class TrainDecoderConfig(BaseModel):
+    unets: List[UnetConfig]
+    decoder: DecoderConfig
+    data: DecoderDataConfig
+    train: DecoderTrainConfig
+    evaluate: DecoderEvaluateConfig
+    tracker: TrackerConfig
+    load: DecoderLoadConfig
+
+    @property
+    def img_preproc(self):
         def _get_transformation(transformation_name, **kwargs):
             if transformation_name == "RandomResizedCrop":
                 return T.RandomResizedCrop(**kwargs)
@@ -50,13 +114,8 @@ class TrainDecoderConfig:
             elif transformation_name == "ToTensor":
                 return T.ToTensor()

-        transformations = []
-        for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
-            if isinstance(transformation_kwargs, dict):
-                transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
-            else:
-                transformations.append(_get_transformation(transformation_name))
-        return T.Compose(transformations)
-
-    def __getitem__(self, key):
-        return self.config[key]
+        transforms = []
+        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
+            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
+            transforms.append(_get_transformation(transform_name, **transform_kwargs))
+        return T.Compose(transforms)
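
For reference, here is a minimal sketch of how the new pydantic models can be driven from a plain dict (for example, one parsed from a JSON config file). The import path, the example_config name, and all values below are illustrative assumptions and are not part of the commit; only the field names, defaults, and validator behaviour come from the models above.

# Illustrative sketch only -- values and import path are assumptions, not from the commit.
from dalle2_pytorch.train_configs import TrainDecoderConfig  # assumed module path

example_config = {
    'unets': [
        {'dim': 128, 'dim_mults': [1, 2, 4, 8], 'image_embed_dim': 512},
    ],
    'decoder': {
        'image_sizes': [64],  # the validator accepts image_size OR image_sizes, but not both
        'timesteps': 1000,
    },
    'data': {
        'webdataset_base_url': 's3://my-bucket/shards/',  # placeholder path to a webdataset of jpg images
        'embeddings_url': 's3://my-bucket/embeddings/',   # placeholder path to the .npy embeddings
        'preprocessing': {
            'RandomResizedCrop': {'size': 64, 'scale': [0.75, 1.0]},
            'ToTensor': True,
        },
    },
    'train': {'epochs': 5, 'lr': 1e-4},
    'evaluate': {},  # every field has a default
    'tracker': {},
    'load': {},
}

config = TrainDecoderConfig(**example_config)  # pydantic validates, coerces nested dicts, and fills in defaults
print(config.data.batch_size)   # 64, from DecoderDataConfig's default
print(config.train.use_ema)     # True
img_preproc = config.img_preproc  # torchvision Compose built from data.preprocessing

Because UnetConfig and DecoderConfig set extra = "allow", any additional keys in those sections are kept on the parsed models, presumably so the refactored training script can forward arbitrary Unet/Decoder keyword arguments without the schema having to enumerate them.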