make sure diffusion prior can be instantiated from pydantic class without clip

Author: Phil Wang
Date:   2022-05-26 08:47:30 -07:00
parent f4fe6c570d
commit b8af2210df
2 changed files with 19 additions and 9 deletions

--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py

@@ -27,6 +27,9 @@ def default(val, d):
 def ListOrTuple(inner_type):
     return Union[List[inner_type], Tuple[inner_type]]
 
+def SingularOrIterable(inner_type):
+    return Union[inner_type, ListOrTuple(inner_type)]
+
 # general pydantic classes
 
 class TrainSplitConfig(BaseModel):
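
The new SingularOrIterable helper builds a Union annotation on the fly, so a single config field can accept either one value or a list/tuple of values. A minimal sketch of the behavior under pydantic v1 (the version this repo targets); the Example model and its lr field are illustrative, not part of the repo:

    from typing import List, Tuple, Union
    from pydantic import BaseModel

    def ListOrTuple(inner_type):
        return Union[List[inner_type], Tuple[inner_type]]

    def SingularOrIterable(inner_type):
        return Union[inner_type, ListOrTuple(inner_type)]

    class Example(BaseModel):
        # annotation evaluates to Union[float, List[float], Tuple[float]]
        lr: SingularOrIterable(float) = 1e-4

    print(Example(lr = 3e-4).lr)          # 0.0003 -- a single float validates
    print(Example(lr = [1e-4, 5e-5]).lr)  # [0.0001, 5e-05] -- so does a list of floats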
@@ -88,7 +91,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
         return DiffusionPriorNetwork(**kwargs)
 
 class DiffusionPriorConfig(BaseModel):
-    clip: AdapterConfig
+    clip: AdapterConfig = None
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int
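
Giving clip a default of None is enough in pydantic v1, which implicitly promotes a field whose default is None to Optional, so a prior config may omit the clip section entirely. A small sketch with stand-in names (not the repo's actual classes):

    from pydantic import BaseModel

    class SubConfig(BaseModel):       # stand-in for AdapterConfig
        make: str = 'openai'

    class PriorConfig(BaseModel):     # stand-in for DiffusionPriorConfig
        clip: SubConfig = None        # pydantic v1 reads this as Optional[SubConfig]
        image_embed_dim: int

    cfg = PriorConfig(image_embed_dim = 512)  # validates with no 'clip' key at all
    print(cfg.clip)                           # None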
@@ -105,8 +108,15 @@ class DiffusionPriorConfig(BaseModel):
 
     def create(self):
         kwargs = self.dict()
-        clip = AdapterConfig(**kwargs.pop('clip')).create()
-        diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
+
+        has_clip = exists(kwargs.pop('clip'))
+        kwargs.pop('net')
+
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        diffusion_prior_network = self.net.create()
         return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
 
 class DiffusionPriorTrainConfig(BaseModel):
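
The old create() rebuilt both sub-configs from the output of self.dict(), which serializes nested models back to plain dicts and a missing clip to None, so **-unpacking it blew up. The rewrite pops both keys out of kwargs and calls .create() on the typed sub-models directly, guarded by the repo's usual exists helper (val is not None). A toy reproduction of the old failure mode, with invented names:

    from pydantic import BaseModel

    class SubConfig(BaseModel):
        dim: int = 64

    class Config(BaseModel):
        sub: SubConfig = None

    kwargs = Config().dict()   # no 'sub' section was supplied
    print(kwargs['sub'])       # None
    try:
        SubConfig(**kwargs.pop('sub'))   # the old pattern
    except TypeError as e:
        print(e)               # argument after ** must be a mapping, not NoneType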
@@ -215,16 +225,16 @@ class DecoderDataConfig(BaseModel):
 
 class DecoderTrainConfig(BaseModel):
     epochs: int = 20
-    lr: float = 1e-4
-    wd: float = 0.01
-    max_grad_norm: float = 0.5
+    lr: SingularOrIterable(float) = 1e-4
+    wd: SingularOrIterable(float) = 0.01
+    max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6          # The number of example images to produce when sampling the train and test dataset
     device: str = 'cuda:0'
     epoch_samples: int = None         # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
     validation_samples: int = None    # Same as above but for validation.
     use_ema: bool = True
-    ema_beta: float = 0.99
+    ema_beta: float = 0.999
     amp: bool = False
     save_all: bool = False            # Whether to preserve all checkpoints
     save_latest: bool = True          # Whether to always save the latest checkpoint
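
With the widened annotations, a decoder that stacks several unets can take one hyperparameter per unet instead of a single shared value. A hedged usage sketch, assuming the class lives in dalle2_pytorch.train_configs and that the fields not shown above keep their defaults:

    from dalle2_pytorch.train_configs import DecoderTrainConfig

    shared   = DecoderTrainConfig(lr = 1e-4)          # one learning rate for every unet
    per_unet = DecoderTrainConfig(lr = [1e-4, 5e-5],  # or one entry per unet,
                                  wd = [0.01, 0.02])  # likewise for wd and max_grad_norm

The ema_beta bump from 0.99 to 0.999 also slows the EMA decoder's decay, averaging over roughly 1/(1 - beta) ≈ 1000 recent steps instead of 100.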

--- a/setup.py
+++ b/setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.5',
+  version = '0.5.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',