diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 4d86342..d6b9edb 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -1,5 +1,5 @@
 from torchvision import transforms as T
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
 def exists(val):
@@ -38,6 +38,17 @@ class DecoderConfig(BaseModel):
     class Config:
         extra = "allow"
 
+class TrainSplitConfig(BaseModel):
+    train: float = 0.75
+    val: float = 0.15
+    test: float = 0.1
+
+    @root_validator
+    def validate_all(cls, fields):
+        if sum([*fields.values()]) != 1.:
+            raise ValueError(f'{fields.keys()} must sum to 1.0')
+        return fields
+
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str        # path to a webdataset with jpg images
     embeddings_url: str             # path to .npy files with embeddings
@@ -47,11 +58,7 @@ class DecoderDataConfig(BaseModel):
     end_shard: int = 9999999
     shard_width: int = 6
     index_width: int = 4
-    splits: Dict[str, float] = {
-        'train': 0.75,
-        'val': 0.15,
-        'test': 0.1
-    }
+    splits: TrainSplitConfig
     shuffle_train: bool = True
     resample_train: bool = False
     preprocessing: Dict[str, Any] = {'ToTensor': True}
diff --git a/setup.py b/setup.py
index 9d27e9c..3ce651a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.0',
+  version = '0.4.1',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
diff --git a/train_decoder.py b/train_decoder.py
index 3e4f7d4..9e872e5 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -422,9 +422,9 @@ def initialize_training(config):
     dataloaders = create_dataloaders (
         available_shards=all_shards,
         img_preproc = config.img_preproc,
-        train_prop = config.data["splits"]["train"],
-        val_prop = config.data["splits"]["val"],
-        test_prop = config.data["splits"]["test"],
+        train_prop = config.data.splits.train,
+        val_prop = config.data.splits.val,
+        test_prop = config.data.splits.test,
         n_sample_images=config.train.n_sample_images,
         **config.data.dict()
     )
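
For reference, a minimal usage sketch (not part of the patch) of how the new TrainSplitConfig is expected to behave under pydantic v1; the import path simply mirrors the module edited above:

    # Hypothetical usage sketch -- not included in this diff.
    from pydantic import ValidationError
    from dalle2_pytorch.train_configs import TrainSplitConfig

    # The defaults (0.75 / 0.15 / 0.1) sum to 1.0, so construction succeeds.
    splits = TrainSplitConfig()
    print(splits.train, splits.val, splits.test)

    # Proportions that do not sum to 1.0 are rejected by the root_validator
    # and surface as a pydantic ValidationError.
    try:
        TrainSplitConfig(train = 0.8, val = 0.3, test = 0.1)
    except ValidationError as err:
        print(err)

Downstream code now reads the split proportions as attributes (config.data.splits.train) rather than dictionary keys, which is what the train_decoder.py hunk switches to.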