diff --git a/configs/README.md b/configs/README.md
index 1586469..36c783f 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
 
-The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
+The decoder trainer has 6 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
 
-**Unets:**
+**Unet:**
+
+This is the config for a single unet. Unet configs are nested under the decoder config as an array in its `unets` option.
 
-Each member of this array defines a single unet that will be added to the decoder.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
-Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
+Defines the configuration options for the decoder model, including the `unets` it is composed of.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
+| `unets` | Yes | N/A | A list of unet configs, each using the format described above. |
 | `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
 | `image_size` | Yes | N/A | Not used. Can be any number. |
 | `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
diff --git a/configs/train_decoder_config.example.json b/configs/train_decoder_config.example.json
index dd387ed..f08758e 100644
--- a/configs/train_decoder_config.example.json
+++ b/configs/train_decoder_config.example.json
@@ -1,16 +1,16 @@
 {
-    "unets": [
-        {
-            "dim": 128,
-            "image_embed_dim": 768,
-            "cond_dim": 64,
-            "channels": 3,
-            "dim_mults": [1, 2, 4, 8],
-            "attn_dim_head": 32,
-            "attn_heads": 16
-        }
-    ],
     "decoder": {
+        "unets": [
+            {
+                "dim": 128,
+                "image_embed_dim": 768,
+                "cond_dim": 64,
+                "channels": 3,
+                "dim_mults": [1, 2, 4, 8],
+                "attn_dim_head": 32,
+                "attn_heads": 16
+            }
+        ],
         "image_sizes": [64],
         "channels": 3,
         "timesteps": 1000,
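Because the `unets` array now lives inside the `decoder` config, existing config files written against the old layout need their unet list moved. A minimal migration sketch (not part of this PR; the file path is illustrative):

```python
# Hypothetical helper: move a top-level "unets" array under the "decoder" key,
# matching the new layout shown in train_decoder_config.example.json above.
import json

path = "configs/train_decoder_config.json"  # illustrative path, not from the PR

with open(path) as f:
    config = json.load(f)

if "unets" in config:  # old-style config with the unet list at the top level
    config["decoder"]["unets"] = config.pop("unets")

with open(path, "w") as f:
    json.dump(config, f, indent=4)
```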
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index f932eab..33d37c0 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
 
         self.unconditional = unconditional
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give image_size (or image_sizes) and channels (usually 3 for RGB)'
 
         self.clip = None
         if exists(clip):
@@ -1728,7 +1728,7 @@ class Decoder(BaseGaussianDiffusion):
             self.clip_image_size = clip.image_size
             self.channels = clip.image_channels
         else:
-            self.clip_image_size = image_size
+            self.clip_image_size = default(image_size, lambda: image_sizes[-1])
             self.channels = channels
 
         self.condition_on_text_encodings = condition_on_text_encodings
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index e2a6161..2c8765a 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -3,15 +3,24 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
+from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
+
+# helper functions
+
 def exists(val):
     return val is not None
 
 def default(val, d):
     return val if exists(val) else d
 
+def ListOrTuple(inner_type):
+    return Union[List[inner_type], Tuple[inner_type]]
+
+# pydantic classes
+
 class UnetConfig(BaseModel):
     dim: int
-    dim_mults: List[int]
+    dim_mults: ListOrTuple(int)
     image_embed_dim: int = None
     cond_dim: int = None
     channels: int = 3
@@ -22,14 +31,21 @@
         extra = "allow"
 
 class DecoderConfig(BaseModel):
+    unets: Union[List[UnetConfig], Tuple[UnetConfig]]
     image_size: int = None
-    image_sizes: Union[List[int], Tuple[int]] = None
+    image_sizes: ListOrTuple(int) = None
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
     beta_schedule: str = 'cosine'
     learned_variance: bool = True
 
+    def create(self):
+        decoder_kwargs = self.dict()
+        unet_configs = decoder_kwargs.pop('unets')
+        unets = [Unet(**config) for config in unet_configs]
+        return Decoder(unets, **decoder_kwargs)
+
     @validator('image_sizes')
     def check_image_sizes(cls, image_sizes, values):
         if exists(values.get('image_size')) ^ exists(image_sizes):
@@ -86,17 +102,17 @@
     wd: float = 0.01
     max_grad_norm: float = 0.5
     save_every_n_samples: int = 100000
-    n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+    n_sample_images: int = 6                      # The number of example images to produce when sampling the train and test dataset
     device: str = 'cuda:0'
-    epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-    validation_samples: int = None # Same as above but for validation.
+    epoch_samples: int = None                     # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: int = None                # Same as above but for validation.
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
-    save_all: bool = False # Whether to preserve all checkpoints
-    save_latest: bool = True # Whether to always save the latest checkpoint
-    save_best: bool = True # Whether to save the best checkpoint
-    unet_training_mask: List[bool] = None # If None, use all unets
+    save_all: bool = False                        # Whether to preserve all checkpoints
+    save_latest: bool = True                      # Whether to always save the latest checkpoint
+    save_best: bool = True                        # Whether to save the best checkpoint
+    unet_training_mask: ListOrTuple(bool) = None  # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
@@ -120,7 +136,6 @@
     resume: bool = False # If using wandb, whether to resume the run
 
 class TrainDecoderConfig(BaseModel):
-    unets: List[UnetConfig]
     decoder: DecoderConfig
     data: DecoderDataConfig
     train: DecoderTrainConfig
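The new `DecoderConfig.create()` method is what replaces the `create_decoder` helper removed from train_decoder.py below. A minimal usage sketch, assuming the values from the example config above:

```python
# Sketch (not from the PR): build a Decoder directly from a nested config.
from dalle2_pytorch.train_configs import DecoderConfig

decoder_config = DecoderConfig(
    unets = [dict(
        dim = 128,
        image_embed_dim = 768,
        cond_dim = 64,
        channels = 3,
        dim_mults = [1, 2, 4, 8],
        attn_dim_head = 32,
        attn_heads = 16
    )],
    image_sizes = [64]  # one target resolution per unet
)

# create() pops the unet configs out of the model dict, instantiates each Unet,
# and passes the remaining fields through to the Decoder constructor.
decoder = decoder_config.create()
```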
diff --git a/setup.py b/setup.py
index 4c41836..0b14740 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.7',
+  version = '0.4.8',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
diff --git a/train_decoder.py b/train_decoder.py
index 0949b9e..af72bea 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -85,20 +85,6 @@
         "test_sampling": test_sampling_dataloader
     }
 
-
-def create_decoder(device, decoder_config, unets_config):
-    """Creates a sample decoder"""
-
-    unets = [Unet(**config.dict()) for config in unets_config]
-
-    decoder = Decoder(
-        unet=unets,
-        **decoder_config.dict()
-    )
-
-    decoder.to(device=device)
-    return decoder
-
 def get_dataset_keys(dataloader):
     """
     It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to do a process to recover it.
@@ -428,7 +414,7 @@
         **config.data.dict()
     )
 
-    decoder = create_decoder(device, config.decoder, config.unets)
+    decoder = config.decoder.create().to(device = device)
     num_parameters = sum(p.numel() for p in decoder.parameters())
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")
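With `create_decoder` removed, `initialize_training` builds the decoder straight from the validated config. An end-to-end sketch of the same flow (not from the PR), assuming the example config file shipped with this change:

```python
# Sketch: parse a config file and construct the decoder,
# mirroring what initialize_training now does internally.
import json
from dalle2_pytorch.train_configs import TrainDecoderConfig

with open("configs/train_decoder_config.example.json") as f:
    config = TrainDecoderConfig(**json.load(f))

decoder = config.decoder.create().to(device = config.train.device)
print(f"Number of parameters: {sum(p.numel() for p in decoder.parameters())}")
```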