mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
move neural network creations off the configuration file into the pydantic classes
This commit is contained in:
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
|
|||||||
|
|
||||||
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
||||||
|
|
||||||
**<ins>Unets</ins>:**
|
**<ins>Unet</ins>:**
|
||||||
|
|
||||||
|
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
|
||||||
|
|
||||||
Each member of this array defines a single unet that will be added to the decoder.
|
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
| `dim` | Yes | N/A | The starting channels of the unet. |
|
| `dim` | Yes | N/A | The starting channels of the unet. |
|
||||||
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
|
|||||||
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `unets` | Yes | N/A | A list of unets, using the configuration above |
|
||||||
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
||||||
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
||||||
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
{
|
{
|
||||||
"unets": [
|
|
||||||
{
|
|
||||||
"dim": 128,
|
|
||||||
"image_embed_dim": 768,
|
|
||||||
"cond_dim": 64,
|
|
||||||
"channels": 3,
|
|
||||||
"dim_mults": [1, 2, 4, 8],
|
|
||||||
"attn_dim_head": 32,
|
|
||||||
"attn_heads": 16
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"decoder": {
|
"decoder": {
|
||||||
|
"unets": [
|
||||||
|
{
|
||||||
|
"dim": 128,
|
||||||
|
"image_embed_dim": 768,
|
||||||
|
"cond_dim": 64,
|
||||||
|
"channels": 3,
|
||||||
|
"dim_mults": [1, 2, 4, 8],
|
||||||
|
"attn_dim_head": 32,
|
||||||
|
"attn_heads": 16
|
||||||
|
}
|
||||||
|
],
|
||||||
"image_sizes": [64],
|
"image_sizes": [64],
|
||||||
"channels": 3,
|
"channels": 3,
|
||||||
"timesteps": 1000,
|
"timesteps": 1000,
|
||||||
|
|||||||
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||||
|
|
||||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||||
|
|
||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
@@ -1728,7 +1728,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_image_size = clip.image_size
|
self.clip_image_size = clip.image_size
|
||||||
self.channels = clip.image_channels
|
self.channels = clip.image_channels
|
||||||
else:
|
else:
|
||||||
self.clip_image_size = image_size
|
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
|
||||||
self.condition_on_text_encodings = condition_on_text_encodings
|
self.condition_on_text_encodings = condition_on_text_encodings
|
||||||
|
|||||||
@@ -3,15 +3,24 @@ from torchvision import transforms as T
|
|||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||||
|
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
def default(val, d):
|
def default(val, d):
|
||||||
return val if exists(val) else d
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
def ListOrTuple(inner_type):
|
||||||
|
return Union[List[inner_type], Tuple[inner_type]]
|
||||||
|
|
||||||
|
# pydantic classes
|
||||||
|
|
||||||
class UnetConfig(BaseModel):
|
class UnetConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
dim_mults: List[int]
|
dim_mults: ListOrTuple(int)
|
||||||
image_embed_dim: int = None
|
image_embed_dim: int = None
|
||||||
cond_dim: int = None
|
cond_dim: int = None
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
@@ -22,14 +31,21 @@ class UnetConfig(BaseModel):
|
|||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class DecoderConfig(BaseModel):
|
class DecoderConfig(BaseModel):
|
||||||
|
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
|
||||||
image_size: int = None
|
image_size: int = None
|
||||||
image_sizes: Union[List[int], Tuple[int]] = None
|
image_sizes: ListOrTuple(int) = None
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
timesteps: int = 1000
|
timesteps: int = 1000
|
||||||
loss_type: str = 'l2'
|
loss_type: str = 'l2'
|
||||||
beta_schedule: str = 'cosine'
|
beta_schedule: str = 'cosine'
|
||||||
learned_variance: bool = True
|
learned_variance: bool = True
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
decoder_kwargs = self.dict()
|
||||||
|
unet_configs = decoder_kwargs.pop('unets')
|
||||||
|
unets = [Unet(**config) for config in unet_configs]
|
||||||
|
return Decoder(unets, **decoder_kwargs)
|
||||||
|
|
||||||
@validator('image_sizes')
|
@validator('image_sizes')
|
||||||
def check_image_sizes(cls, image_sizes, values):
|
def check_image_sizes(cls, image_sizes, values):
|
||||||
if exists(values.get('image_size')) ^ exists(image_sizes):
|
if exists(values.get('image_size')) ^ exists(image_sizes):
|
||||||
@@ -86,17 +102,17 @@ class DecoderTrainConfig(BaseModel):
|
|||||||
wd: float = 0.01
|
wd: float = 0.01
|
||||||
max_grad_norm: float = 0.5
|
max_grad_norm: float = 0.5
|
||||||
save_every_n_samples: int = 100000
|
save_every_n_samples: int = 100000
|
||||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||||
device: str = 'cuda:0'
|
device: str = 'cuda:0'
|
||||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||||
validation_samples: int = None # Same as above but for validation.
|
validation_samples: int = None # Same as above but for validation.
|
||||||
use_ema: bool = True
|
use_ema: bool = True
|
||||||
ema_beta: float = 0.99
|
ema_beta: float = 0.99
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
save_all: bool = False # Whether to preserve all checkpoints
|
save_all: bool = False # Whether to preserve all checkpoints
|
||||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||||
save_best: bool = True # Whether to save the best checkpoint
|
save_best: bool = True # Whether to save the best checkpoint
|
||||||
unet_training_mask: List[bool] = None # If None, use all unets
|
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||||
|
|
||||||
class DecoderEvaluateConfig(BaseModel):
|
class DecoderEvaluateConfig(BaseModel):
|
||||||
n_evaluation_samples: int = 1000
|
n_evaluation_samples: int = 1000
|
||||||
@@ -120,7 +136,6 @@ class DecoderLoadConfig(BaseModel):
|
|||||||
resume: bool = False # If using wandb, whether to resume the run
|
resume: bool = False # If using wandb, whether to resume the run
|
||||||
|
|
||||||
class TrainDecoderConfig(BaseModel):
|
class TrainDecoderConfig(BaseModel):
|
||||||
unets: List[UnetConfig]
|
|
||||||
decoder: DecoderConfig
|
decoder: DecoderConfig
|
||||||
data: DecoderDataConfig
|
data: DecoderDataConfig
|
||||||
train: DecoderTrainConfig
|
train: DecoderTrainConfig
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.4.7',
|
version = '0.4.8',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -85,20 +85,6 @@ def create_dataloaders(
|
|||||||
"test_sampling": test_sampling_dataloader
|
"test_sampling": test_sampling_dataloader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_decoder(device, decoder_config, unets_config):
|
|
||||||
"""Creates a sample decoder"""
|
|
||||||
|
|
||||||
unets = [Unet(**config.dict()) for config in unets_config]
|
|
||||||
|
|
||||||
decoder = Decoder(
|
|
||||||
unet=unets,
|
|
||||||
**decoder_config.dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder.to(device=device)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
def get_dataset_keys(dataloader):
|
def get_dataset_keys(dataloader):
|
||||||
"""
|
"""
|
||||||
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||||
@@ -428,7 +414,7 @@ def initialize_training(config):
|
|||||||
**config.data.dict()
|
**config.data.dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = create_decoder(device, config.decoder, config.unets)
|
decoder = config.decoder.create().to(device = device)
|
||||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||||
print(print_ribbon("Loaded Config", repeat=40))
|
print(print_ribbon("Loaded Config", repeat=40))
|
||||||
print(f"Number of parameters: {num_parameters}")
|
print(f"Number of parameters: {num_parameters}")
|
||||||
|
|||||||
Reference in New Issue
Block a user