make sure diffusion prior can be instantiated from pydantic class without clip

This commit is contained in:
Phil Wang
2022-05-26 08:47:30 -07:00
parent f4fe6c570d
commit b8af2210df
2 changed files with 19 additions and 9 deletions

View File

@@ -27,6 +27,9 @@ def default(val, d):
def ListOrTuple(inner_type):
    return Union[List[inner_type], Tuple[inner_type]]
def SingularOrIterable(inner_type):
    # Type-alias factory: a field annotated with this accepts either a bare
    # value of `inner_type` or a list/tuple of such values.
    # Union flattens nested unions, so spelling out the list/tuple members
    # here is identical to wrapping the sibling ListOrTuple helper.
    return Union[inner_type, List[inner_type], Tuple[inner_type]]
# general pydantic classes

class TrainSplitConfig(BaseModel):
@@ -88,7 +91,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
        return DiffusionPriorNetwork(**kwargs)

class DiffusionPriorConfig(BaseModel):
clip: AdapterConfig clip: AdapterConfig = None
    net: DiffusionPriorNetworkConfig
    image_embed_dim: int
    image_size: int
@@ -105,8 +108,15 @@ class DiffusionPriorConfig(BaseModel):
    def create(self):
        kwargs = self.dict()
clip = AdapterConfig(**kwargs.pop('clip')).create()
diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create() has_clip = exists(kwargs.pop('clip'))
kwargs.pop('net')
clip = None
if has_clip:
clip = self.clip.create()
diffusion_prior_network = self.net.create()
        return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)

class DiffusionPriorTrainConfig(BaseModel):
@@ -215,16 +225,16 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel):
    epochs: int = 20
lr: float = 1e-4 lr: SingularOrIterable(float) = 1e-4
wd: float = 0.01 wd: SingularOrIterable(float) = 0.01
max_grad_norm: float = 0.5 max_grad_norm: SingularOrIterable(float) = 0.5
    save_every_n_samples: int = 100000
    n_sample_images: int = 6                # The number of example images to produce when sampling the train and test dataset
    device: str = 'cuda:0'
    epoch_samples: int = None               # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
    validation_samples: int = None          # Same as above but for validation.
    use_ema: bool = True
ema_beta: float = 0.99 ema_beta: float = 0.999
    amp: bool = False
    save_all: bool = False                  # Whether to preserve all checkpoints
    save_latest: bool = True                # Whether to always save the latest checkpoint

View File

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
version = '0.5.5', version = '0.5.6',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',