mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
move neural network creations off the configuration file into the pydantic classes
This commit is contained in:
@@ -3,15 +3,24 @@ from torchvision import transforms as T
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def ListOrTuple(inner_type):
|
||||
return Union[List[inner_type], Tuple[inner_type]]
|
||||
|
||||
# pydantic classes
|
||||
|
||||
class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: List[int]
|
||||
dim_mults: ListOrTuple(int)
|
||||
image_embed_dim: int = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
@@ -22,14 +31,21 @@ class UnetConfig(BaseModel):
|
||||
extra = "allow"
|
||||
|
||||
class DecoderConfig(BaseModel):
|
||||
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
|
||||
image_size: int = None
|
||||
image_sizes: Union[List[int], Tuple[int]] = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: str = 'cosine'
|
||||
learned_variance: bool = True
|
||||
|
||||
def create(self):
|
||||
decoder_kwargs = self.dict()
|
||||
unet_configs = decoder_kwargs.pop('unets')
|
||||
unets = [Unet(**config) for config in unet_configs]
|
||||
return Decoder(unets, **decoder_kwargs)
|
||||
|
||||
@validator('image_sizes')
|
||||
def check_image_sizes(cls, image_sizes, values):
|
||||
if exists(values.get('image_size')) ^ exists(image_sizes):
|
||||
@@ -86,17 +102,17 @@ class DecoderTrainConfig(BaseModel):
|
||||
wd: float = 0.01
|
||||
max_grad_norm: float = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
device: str = 'cuda:0'
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: List[bool] = None # If None, use all unets
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
n_evaluation_samples: int = 1000
|
||||
@@ -120,7 +136,6 @@ class DecoderLoadConfig(BaseModel):
|
||||
resume: bool = False # If using wandb, whether to resume the run
|
||||
|
||||
class TrainDecoderConfig(BaseModel):
|
||||
unets: List[UnetConfig]
|
||||
decoder: DecoderConfig
|
||||
data: DecoderDataConfig
|
||||
train: DecoderTrainConfig
|
||||
|
||||
Reference in New Issue
Block a user