move neural network creation out of the configuration file and into the pydantic classes

Phil Wang
2022-05-22 19:18:18 -07:00
parent 0f4edff214
commit 5c397c9d66
6 changed files with 44 additions and 41 deletions

Configuration README:

```diff
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration file
 The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
 
-**<ins>Unets</ins>:**
-Each member of this array defines a single unet that will be added to the decoder.
+**<ins>Unet</ins>:**
+This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
 
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
 Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
 
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
+| `unets` | Yes | N/A | A list of unets, using the configuration above |
 | `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
 | `image_size` | Yes | N/A | Not used. Can be any number. |
 | `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
```

Example decoder configuration (`train_decoder_config.example.json`):

```diff
@@ -1,16 +1,16 @@
 {
-    "unets": [
-        {
-            "dim": 128,
-            "image_embed_dim": 768,
-            "cond_dim": 64,
-            "channels": 3,
-            "dim_mults": [1, 2, 4, 8],
-            "attn_dim_head": 32,
-            "attn_heads": 16
-        }
-    ],
     "decoder": {
+        "unets": [
+            {
+                "dim": 128,
+                "image_embed_dim": 768,
+                "cond_dim": 64,
+                "channels": 3,
+                "dim_mults": [1, 2, 4, 8],
+                "attn_dim_head": 32,
+                "attn_heads": 16
+            }
+        ],
         "image_sizes": [64],
         "channels": 3,
         "timesteps": 1000,
```

Decoder model (`dalle2_pytorch/dalle2_pytorch.py`):

```diff
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
         self.unconditional = unconditional
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
         if exists(clip):
@@ -1728,7 +1728,7 @@ class Decoder(BaseGaussianDiffusion):
             self.clip_image_size = clip.image_size
             self.channels = clip.image_channels
         else:
-            self.clip_image_size = image_size
+            self.clip_image_size = default(image_size, lambda: image_sizes[-1])
             self.channels = channels
 
         self.condition_on_text_encodings = condition_on_text_encodings
```
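The loosened assert accepts either `image_size` or `image_sizes`, and the second hunk resolves the final size lazily. The `lambda` only makes sense if `default` invokes a callable fallback when the primary value is missing; a minimal sketch of the behavior being relied on (the callable check is assumed, but is implied by the hunk above):

```python
def exists(val):
    return val is not None

def default(val, d):
    # assumed shape of the helper in dalle2_pytorch.py: the fallback may be
    # a zero-argument callable, evaluated only when `val` is missing
    if exists(val):
        return val
    return d() if callable(d) else d

image_size, image_sizes = None, [64, 256]
assert default(image_size, lambda: image_sizes[-1]) == 256  # fallback invoked

image_size, image_sizes = 64, None
assert default(image_size, lambda: image_sizes[-1]) == 64   # lambda never runs, so image_sizes may be None
```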

Training configuration classes:

```diff
@@ -3,15 +3,24 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
+from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
+
+# helper functions
+
 def exists(val):
     return val is not None
 
 def default(val, d):
     return val if exists(val) else d
 
+def ListOrTuple(inner_type):
+    return Union[List[inner_type], Tuple[inner_type]]
+
+# pydantic classes
+
 class UnetConfig(BaseModel):
     dim: int
-    dim_mults: List[int]
+    dim_mults: ListOrTuple(int)
     image_embed_dim: int = None
     cond_dim: int = None
     channels: int = 3
```
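`ListOrTuple(int)` simply expands to `Union[List[int], Tuple[int]]`. Strictly speaking, `Tuple[int]` types a one-element tuple, but under pydantic v1's union validation the `List` arm is tried first and coerces tuple input, so both spellings validate in practice. A quick sketch:

```python
from typing import List, Tuple, Union
from pydantic import BaseModel  # pydantic v1, matching the imports above

def ListOrTuple(inner_type):
    return Union[List[inner_type], Tuple[inner_type]]

class Example(BaseModel):
    dim_mults: ListOrTuple(int)

print(Example(dim_mults = [1, 2, 4, 8]).dim_mults)  # [1, 2, 4, 8]
print(Example(dim_mults = (1, 2, 4, 8)).dim_mults)  # tuple input is accepted and coerced by the List arm
```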
```diff
@@ -22,14 +31,21 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
+    unets: Union[List[UnetConfig], Tuple[UnetConfig]]
     image_size: int = None
-    image_sizes: Union[List[int], Tuple[int]] = None
+    image_sizes: ListOrTuple(int) = None
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
     beta_schedule: str = 'cosine'
     learned_variance: bool = True
 
+    def create(self):
+        decoder_kwargs = self.dict()
+        unet_configs = decoder_kwargs.pop('unets')
+        unets = [Unet(**config) for config in unet_configs]
+        return Decoder(unets, **decoder_kwargs)
+
     @validator('image_sizes')
     def check_image_sizes(cls, image_sizes, values):
         if exists(values.get('image_size')) ^ exists(image_sizes):
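```
The new `create` method serializes the config (nested `UnetConfig`s become plain dicts via `.dict()`), pops the unet configs out, instantiates each `Unet`, and hands everything else to `Decoder` as keyword arguments. A sketch of the call path with illustrative values:

```python
from dalle2_pytorch.train_configs import DecoderConfig  # module path assumed

decoder_config = DecoderConfig(
    unets = [dict(dim = 128, image_embed_dim = 768, dim_mults = [1, 2, 4, 8])],
    image_sizes = [64],
)
decoder = decoder_config.create()  # builds each Unet, then wraps them all in a Decoder
```

Note that `self.dict()` also carries `image_size = None` through to `Decoder`, which the loosened assert in `dalle2_pytorch.py` above now tolerates as long as `image_sizes` is present.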
```diff
@@ -86,17 +102,17 @@ class DecoderTrainConfig(BaseModel):
     wd: float = 0.01
     max_grad_norm: float = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6        # The number of example images to produce when sampling the train and test dataset
     device: str = 'cuda:0'
     epoch_samples: int = None       # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
     validation_samples: int = None  # Same as above but for validation.
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
     save_all: bool = False          # Whether to preserve all checkpoints
     save_latest: bool = True        # Whether to always save the latest checkpoint
     save_best: bool = True          # Whether to save the best checkpoint
-    unet_training_mask: List[bool] = None  # If None, use all unets
+    unet_training_mask: ListOrTuple(bool) = None  # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
```
```diff
@@ -120,7 +136,6 @@ class DecoderLoadConfig(BaseModel):
     resume: bool = False  # If using wandb, whether to resume the run
 
 class TrainDecoderConfig(BaseModel):
-    unets: List[UnetConfig]
     decoder: DecoderConfig
     data: DecoderDataConfig
     train: DecoderTrainConfig
```

Package setup (`setup.py`):

```diff
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.7',
+  version = '0.4.8',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```

Decoder training script:

```diff
@@ -85,20 +85,6 @@ def create_dataloaders(
         "test_sampling": test_sampling_dataloader
     }
 
-def create_decoder(device, decoder_config, unets_config):
-    """Creates a sample decoder"""
-    unets = [Unet(**config.dict()) for config in unets_config]
-    decoder = Decoder(
-        unet=unets,
-        **decoder_config.dict()
-    )
-    decoder.to(device=device)
-    return decoder
-
 def get_dataset_keys(dataloader):
     """
     It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to go through a process to recover it.
```
```diff
@@ -428,7 +414,7 @@ def initialize_training(config):
         **config.data.dict()
     )
 
-    decoder = create_decoder(device, config.decoder, config.unets)
+    decoder = config.decoder.create().to(device = device)
 
     num_parameters = sum(p.numel() for p in decoder.parameters())
 
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")
```
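The one-liner works because `DecoderConfig.create` returns the `Decoder` module and `nn.Module.to` returns the module itself, so construction and device placement chain into a single expression; the deleted `create_decoder` helper did the same three steps (build unets, build decoder, move to device) by hand.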
print(f"Number of parameters: {num_parameters}") print(f"Number of parameters: {num_parameters}")