Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Make sure the diffusion prior can be instantiated from its pydantic config class without a CLIP adapter
This commit is contained in:
@@ -27,6 +27,9 @@ def default(val, d):
|
|||||||
def ListOrTuple(inner_type):
|
def ListOrTuple(inner_type):
|
||||||
return Union[List[inner_type], Tuple[inner_type]]
|
return Union[List[inner_type], Tuple[inner_type]]
|
||||||
|
|
||||||
|
def SingularOrIterable(inner_type):
|
||||||
|
return Union[inner_type, ListOrTuple(inner_type)]
|
||||||
|
|
||||||
# general pydantic classes
|
# general pydantic classes
|
||||||
|
|
||||||
class TrainSplitConfig(BaseModel):
|
class TrainSplitConfig(BaseModel):
|
||||||
@@ -88,7 +91,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
|||||||
return DiffusionPriorNetwork(**kwargs)
|
return DiffusionPriorNetwork(**kwargs)
|
||||||
|
|
||||||
class DiffusionPriorConfig(BaseModel):
|
class DiffusionPriorConfig(BaseModel):
|
||||||
clip: AdapterConfig
|
clip: AdapterConfig = None
|
||||||
net: DiffusionPriorNetworkConfig
|
net: DiffusionPriorNetworkConfig
|
||||||
image_embed_dim: int
|
image_embed_dim: int
|
||||||
image_size: int
|
image_size: int
|
||||||
@@ -105,8 +108,15 @@ class DiffusionPriorConfig(BaseModel):
|
|||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
kwargs = self.dict()
|
kwargs = self.dict()
|
||||||
clip = AdapterConfig(**kwargs.pop('clip')).create()
|
|
||||||
diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
|
has_clip = exists(kwargs.pop('clip'))
|
||||||
|
kwargs.pop('net')
|
||||||
|
|
||||||
|
clip = None
|
||||||
|
if has_clip:
|
||||||
|
clip = self.clip.create()
|
||||||
|
|
||||||
|
diffusion_prior_network = self.net.create()
|
||||||
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
|
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
|
||||||
|
|
||||||
class DiffusionPriorTrainConfig(BaseModel):
|
class DiffusionPriorTrainConfig(BaseModel):
|
||||||
@@ -215,16 +225,16 @@ class DecoderDataConfig(BaseModel):
|
|||||||
|
|
||||||
class DecoderTrainConfig(BaseModel):
|
class DecoderTrainConfig(BaseModel):
|
||||||
epochs: int = 20
|
epochs: int = 20
|
||||||
lr: float = 1e-4
|
lr: SingularOrIterable(float) = 1e-4
|
||||||
wd: float = 0.01
|
wd: SingularOrIterable(float) = 0.01
|
||||||
max_grad_norm: float = 0.5
|
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||||
save_every_n_samples: int = 100000
|
save_every_n_samples: int = 100000
|
||||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||||
device: str = 'cuda:0'
|
device: str = 'cuda:0'
|
||||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||||
validation_samples: int = None # Same as above but for validation.
|
validation_samples: int = None # Same as above but for validation.
|
||||||
use_ema: bool = True
|
use_ema: bool = True
|
||||||
ema_beta: float = 0.99
|
ema_beta: float = 0.999
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
save_all: bool = False # Whether to preserve all checkpoints
|
save_all: bool = False # Whether to preserve all checkpoints
|
||||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||||
|
|||||||
Reference in New Issue
Block a user