diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 36be714..28b8f89 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -27,6 +27,9 @@ def default(val, d):
 def ListOrTuple(inner_type):
     return Union[List[inner_type], Tuple[inner_type]]
 
+def SingularOrIterable(inner_type):
+    return Union[inner_type, ListOrTuple(inner_type)]
+
 # general pydantic classes
 
 class TrainSplitConfig(BaseModel):
@@ -88,7 +91,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
         return DiffusionPriorNetwork(**kwargs)
 
 class DiffusionPriorConfig(BaseModel):
-    clip: AdapterConfig
+    clip: AdapterConfig = None
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int
@@ -105,9 +108,16 @@ class DiffusionPriorConfig(BaseModel):
 
     def create(self):
         kwargs = self.dict()
-        clip = AdapterConfig(**kwargs.pop('clip')).create()
-        diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
-        return DiffusionPrior(net = diffusion_prior_network, clip=clip, **kwargs)
+
+        has_clip = exists(kwargs.pop('clip'))
+        kwargs.pop('net')
+
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        diffusion_prior_network = self.net.create()
+        return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
 
 class DiffusionPriorTrainConfig(BaseModel):
     epochs: int = 1
@@ -215,16 +225,16 @@ class DecoderDataConfig(BaseModel):
 
 class DecoderTrainConfig(BaseModel):
     epochs: int = 20
-    lr: float = 1e-4
-    wd: float = 0.01
-    max_grad_norm: float = 0.5
+    lr: SingularOrIterable(float) = 1e-4
+    wd: SingularOrIterable(float) = 0.01
+    max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
     device: str = 'cuda:0'
     epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
     validation_samples: int = None # Same as above but for validation.
     use_ema: bool = True
-    ema_beta: float = 0.99
+    ema_beta: float = 0.999
     amp: bool = False
     save_all: bool = False # Whether to preserve all checkpoints
     save_latest: bool = True # Whether to always save the latest checkpoint
diff --git a/setup.py b/setup.py
index bfed8a4..2156d2e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.5',
+  version = '0.5.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',