Compare commits

...

7 Commits

Author SHA1 Message Date
Phil Wang
dc50c6b34e allow for config driven creation of clip-less diffusion prior 2022-05-22 20:13:20 -07:00
Phil Wang
2b1fd1ad2e product management 2022-05-22 19:23:40 -07:00
zion
82a2ef37d9 Update README.md (#109)
block in a section that links to available pre-trained models for those who are interested
2022-05-22 19:22:30 -07:00
Phil Wang
5c397c9d66 move neural network creations off the configuration file into the pydantic classes 2022-05-22 19:18:18 -07:00
Phil Wang
0f4edff214 derived value for image preprocessing belongs to the data config class 2022-05-22 18:42:40 -07:00
Phil Wang
501a8c7c46 small cleanup 2022-05-22 15:39:38 -07:00
Phil Wang
4e49373fc5 project management 2022-05-22 15:27:40 -07:00
10 changed files with 119 additions and 70 deletions

View File

@@ -24,6 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
*ongoing at 21k steps*
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder 🚧
- DALL-E 2 🚧
## Install
```bash
@@ -1078,6 +1083,8 @@ This library would not have gotten to this working state without the help of
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1087,11 +1094,9 @@ This library would not have gotten to this working state without the help of
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations

View File

@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
**<ins>Unets</ins>:**
**<ins>Unet</ins>:**
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |

View File

@@ -1,16 +1,16 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"image_sizes": [64],
"channels": 3,
"timesteps": 1000,

View File

@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -1728,7 +1728,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
self.clip_image_size = image_size
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings

View File

@@ -3,15 +3,61 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
# pydantic classes
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
final_proj: bool = True
normformer: bool = False
rotary_emb: bool = True
class DiffusionPriorConfig(BaseModel):
# only clip-less diffusion prior config for now
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
image_channels: int = 3
timesteps: int = 1000
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
beta_schedule: str = 'cosine'
def create(self):
kwargs = self.dict()
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
class Config:
extra = "allow"
class UnetConfig(BaseModel):
dim: int
dim_mults: List[int]
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
@@ -22,13 +68,22 @@ class UnetConfig(BaseModel):
extra = "allow"
class DecoderConfig(BaseModel):
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
image_size: int = None
image_sizes: Union[List[int], Tuple[int]] = None
image_sizes: ListOrTuple(int) = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
@@ -64,23 +119,39 @@ class DecoderDataConfig(BaseModel):
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: List[bool] = None # If None, use all unets
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
@@ -104,7 +175,6 @@ class DecoderLoadConfig(BaseModel):
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
unets: List[UnetConfig]
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
@@ -117,19 +187,3 @@ class TrainDecoderConfig(BaseModel):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)

View File

@@ -133,12 +133,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior

View File

@@ -1,5 +1,7 @@
import time
# time helpers
class Timer:
def __init__(self):
self.reset()
@@ -9,3 +11,9 @@ class Timer:
def elapsed(self):
return time.time() - self.last_time
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.4.5',
version = '0.4.9',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -1,9 +1,9 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.utils import Timer, print_ribbon
import torchvision
import torch
@@ -85,20 +85,6 @@ def create_dataloaders(
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config.dict()) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config.dict()
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
@@ -420,7 +406,7 @@ def initialize_training(config):
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.img_preproc,
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
@@ -428,7 +414,7 @@ def initialize_training(config):
**config.data.dict()
)
decoder = create_decoder(device, config.decoder, config.unets)
decoder = config.decoder.create().to(device = device)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")

View File

@@ -9,10 +9,10 @@ from torch import nn
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.utils import Timer, print_ribbon
from embedding_reader import EmbeddingReader