diff --git a/configs/README.md b/configs/README.md
index 1586469..36c783f 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
-The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
+The decoder trainer has 6 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
-**Unets:**
+**Unet:**
+
+This defines a single unet config. The unet configs are nested under the decoder config as the `unets` list.
-Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
-Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
+Defines the configuration options for the decoder model. The unets defined above are nested inside this config under the `unets` key.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
+| `unets` | Yes | N/A | A list of unets, using the configuration above. |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
-| `image_size` | Yes | N/A | Not used. Can be any number. |
+| `image_size` | No | `None` | An alternative to `image_sizes` for a single-resolution decoder. Exactly one of `image_size` and `image_sizes` may be set. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
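To make the new nesting concrete, here is a minimal sketch of the layout described above (values are illustrative, echoing the example config below):

```json
{
  "decoder": {
    "unets": [
      { "dim": 128, "dim_mults": [1, 2, 4, 8] }
    ],
    "image_sizes": [64]
  }
}
```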
diff --git a/configs/train_decoder_config.example.json b/configs/train_decoder_config.example.json
index dd387ed..f08758e 100644
--- a/configs/train_decoder_config.example.json
+++ b/configs/train_decoder_config.example.json
@@ -1,16 +1,16 @@
{
- "unets": [
- {
- "dim": 128,
- "image_embed_dim": 768,
- "cond_dim": 64,
- "channels": 3,
- "dim_mults": [1, 2, 4, 8],
- "attn_dim_head": 32,
- "attn_heads": 16
- }
- ],
"decoder": {
+ "unets": [
+ {
+ "dim": 128,
+ "image_embed_dim": 768,
+ "cond_dim": 64,
+ "channels": 3,
+ "dim_mults": [1, 2, 4, 8],
+ "attn_dim_head": 32,
+ "attn_heads": 16
+ }
+ ],
"image_sizes": [64],
"channels": 3,
"timesteps": 1000,
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index f932eab..33d37c0 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
- assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size (or image_sizes) and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -1728,7 +1728,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
- self.clip_image_size = image_size
+ self.clip_image_size = default(image_size, lambda: image_sizes[-1])
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings
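The switch to `default(image_size, lambda: image_sizes[-1])` leans on the helper evaluating a callable fallback lazily. A minimal sketch of that behavior, assuming `default` follows the usual `exists`/`default` pattern in this codebase:

```python
def exists(val):
    return val is not None

def default(val, d):
    # a callable fallback is only evaluated when `val` is missing,
    # so `image_sizes[-1]` is never indexed when `image_size` is given
    return val if exists(val) else (d() if callable(d) else d)

image_size, image_sizes = None, [64, 256]
assert default(image_size, lambda: image_sizes[-1]) == 256
assert default(512, lambda: image_sizes[-1]) == 512
```

The lambda matters because `image_sizes` may itself be `None` when `image_size` is given; an eager `image_sizes[-1]` would raise.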
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index e2a6161..2c8765a 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -3,15 +3,24 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
+from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
+
+# helper functions
+
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
+def ListOrTuple(inner_type):
+    # accept either a list or a variable-length tuple of inner_type
+    return Union[List[inner_type], Tuple[inner_type, ...]]
+
+# pydantic classes
+
class UnetConfig(BaseModel):
dim: int
- dim_mults: List[int]
+ dim_mults: ListOrTuple(int)
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
@@ -22,14 +31,21 @@ class UnetConfig(BaseModel):
extra = "allow"
class DecoderConfig(BaseModel):
+    unets: ListOrTuple(UnetConfig)
image_size: int = None
- image_sizes: Union[List[int], Tuple[int]] = None
+ image_sizes: ListOrTuple(int) = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
+    def create(self):
+        """Instantiate the configured Unets and wrap them in a Decoder."""
+        decoder_kwargs = self.dict()
+        unet_configs = decoder_kwargs.pop('unets')
+        unets = [Unet(**config) for config in unet_configs]
+        return Decoder(unets, **decoder_kwargs)
+
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
@@ -86,17 +102,17 @@ class DecoderTrainConfig(BaseModel):
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
- n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+ n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
- epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
- validation_samples: int = None # Same as above but for validation.
+ epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+ validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
- save_all: bool = False # Whether to preserve all checkpoints
- save_latest: bool = True # Whether to always save the latest checkpoint
- save_best: bool = True # Whether to save the best checkpoint
- unet_training_mask: List[bool] = None # If None, use all unets
+ save_all: bool = False # Whether to preserve all checkpoints
+ save_latest: bool = True # Whether to always save the latest checkpoint
+ save_best: bool = True # Whether to save the best checkpoint
+ unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
@@ -120,7 +136,6 @@ class DecoderLoadConfig(BaseModel):
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
- unets: List[UnetConfig]
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
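A quick sketch of how the reworked `DecoderConfig` validates and builds (the tiny hyperparameters are illustrative, not tuned settings):

```python
from dalle2_pytorch.train_configs import DecoderConfig

# tuples are accepted as well as lists, courtesy of ListOrTuple;
# the validator above requires exactly one of image_size / image_sizes
config = DecoderConfig(
    unets = (dict(dim = 32, dim_mults = (1, 2)),),
    image_sizes = (64,),
)

decoder = config.create()  # instantiates each Unet, then wraps them in a Decoder
```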
diff --git a/setup.py b/setup.py
index 4c41836..0b14740 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.4.7',
+ version = '0.4.8',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
diff --git a/train_decoder.py b/train_decoder.py
index 0949b9e..af72bea 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -85,20 +85,6 @@ def create_dataloaders(
"test_sampling": test_sampling_dataloader
}
-
-def create_decoder(device, decoder_config, unets_config):
- """Creates a sample decoder"""
-
- unets = [Unet(**config.dict()) for config in unets_config]
-
- decoder = Decoder(
- unet=unets,
- **decoder_config.dict()
- )
-
- decoder.to(device=device)
- return decoder
-
def get_dataset_keys(dataloader):
"""
     It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to go through a process to recover it.
@@ -428,7 +414,7 @@ def initialize_training(config):
**config.data.dict()
)
- decoder = create_decoder(device, config.decoder, config.unets)
+ decoder = config.decoder.create().to(device = device)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")