Compare commits

..

1 Commits

12 changed files with 186 additions and 456 deletions

View File

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Status
@@ -24,11 +24,6 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
*ongoing at 21k steps*
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- DALL-E 2 🚧
## Install
```bash
@@ -1081,10 +1076,6 @@ This library would not have gotten to this working state without the help of
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1094,9 +1085,12 @@ This library would not have gotten to this working state without the help of
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations
@@ -1195,12 +1189,4 @@ This library would not have gotten to this working state without the help of
}
```
```bibtex
@misc{Saharia2022,
title = {Imagen: unprecedented photorealism × deep level of language understanding},
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
year = {2022}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -4,12 +4,11 @@ For more complex configuration, we provide the option of using a configuration f
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
**<ins>Unet</ins>:**
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
**<ins>Unets</ins>:**
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
@@ -23,7 +22,6 @@ Any parameter from the `Unet` constructor can also be given here.
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |

View File

@@ -0,0 +1,82 @@
"""
Defines the default values for the decoder config
"""
from enum import Enum
class ConfigField(Enum):
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
default_config = {
"unets": ConfigField.REQUIRED,
"decoder": {
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": True
},
"data": {
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": True,
"resample_train": False,
"preprocessing": {
"ToTensor": True
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
"device": "cuda:0",
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
"validation_samples": None, # Same as above but for validation.
"use_ema": True,
"ema_beta": 0.99,
"amp": False,
"save_all": False, # Whether to preserve all checkpoints
"save_latest": True, # Whether to always save the latest checkpoint
"save_best": True, # Whether to save the best checkpoint
"unet_training_mask": None # If None, use all unets
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": None,
"IS": None,
"KID": None,
"LPIPS": None
},
"tracker": {
"tracker_type": "console", # Decoder currently supports console and wandb
"data_path": "./models", # The path where files will be saved locally
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
"wandb_project": "",
"verbose": False # Whether to print console logging for non-console trackers
},
"load": {
"source": None, # Supports file and wandb
"run_path": "", # Used only if source is wandb
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
"resume": False # If using wandb, whether to resume the run
}
}

View File

@@ -1,17 +1,18 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"image_sizes": [64],
"image_size": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
@@ -62,7 +63,7 @@
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 1000,
"n_evalation_samples": 1000,
"FID": {
"feature": 64
},

View File

@@ -890,8 +890,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1107,20 +1105,13 @@ class Block(nn.Module):
groups = 8
):
super().__init__()
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.project(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
class ResnetBlock(nn.Module):
def __init__(
@@ -1139,7 +1130,7 @@ class ResnetBlock(nn.Module):
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out * 2)
nn.Linear(time_cond_dim, dim_out)
)
self.cross_attn = None
@@ -1159,14 +1150,11 @@ class ResnetBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
@@ -1714,8 +1702,6 @@ class Decoder(BaseGaussianDiffusion):
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1724,19 +1710,12 @@ class Decoder(BaseGaussianDiffusion):
)
self.unconditional = unconditional
# text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1746,20 +1725,13 @@ class Decoder(BaseGaussianDiffusion):
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
# determine image size, with image_size and image_sizes taking precedence
if exists(image_size) or exists(image_sizes):
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
image_size = default(image_size, lambda: image_sizes[-1])
elif exists(clip):
image_size = clip.image_size
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
raise Error('either image_size, image_sizes, or clip must be given to decoder')
self.clip_image_size = image_size
self.channels = channels
# channels
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1801,7 +1773,7 @@ class Decoder(BaseGaussianDiffusion):
# unet image sizes
image_sizes = default(image_sizes, (image_size,))
image_sizes = default(image_sizes, (self.clip_image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1838,13 +1810,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
# dynamic thresholding settings, if clipping denoised during sampling
self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
@@ -1885,21 +1851,7 @@ class Decoder(BaseGaussianDiffusion):
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)

View File

@@ -12,7 +12,6 @@ def get_optimizer(
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
**kwargs
):
if filter_by_requires_grad:
@@ -22,13 +21,11 @@ def get_optimizer(
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
params = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -1,189 +0,0 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
# pydantic classes
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
final_proj: bool = True
normformer: bool = False
rotary_emb: bool = True
class DiffusionPriorConfig(BaseModel):
# only clip-less diffusion prior config for now
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
image_channels: int = 3
timesteps: int = 1000
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
beta_schedule: str = 'cosine'
def create(self):
kwargs = self.dict()
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
class Config:
extra = "allow"
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
return image_sizes
raise ValueError('either image_size or image_sizes is required, but not both')
class Config:
extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
end_shard: int = 9999999
shard_width: int = 6
index_width: int = 4
splits: TrainSplitConfig
shuffle_train: bool = True
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)

View File

@@ -1,6 +1,5 @@
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable
@@ -56,10 +55,6 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
# decorators
def cast_torch_tensor(fn):
@@ -133,6 +128,12 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
@@ -190,7 +191,7 @@ class EMA(nn.Module):
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
@@ -254,7 +255,6 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
@@ -280,7 +280,6 @@ class DiffusionPriorTrainer(nn.Module):
lr = lr,
wd = wd,
eps = eps,
group_wd_params = group_wd_params,
**kwargs
)
@@ -288,50 +287,7 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
step = self.step.item(),
**kwargs
)
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
torch.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
self.register_buffer('step', torch.tensor([0.]))
def update(self):
if exists(self.max_grad_norm):
@@ -412,7 +368,6 @@ class DecoderTrainer(nn.Module):
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
@@ -438,7 +393,6 @@ class DecoderTrainer(nn.Module):
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
@@ -456,60 +410,6 @@ class DecoderTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
torch.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
scaler.load_state_dict(loaded_obj[scaler_key])
optimizer.load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

View File

@@ -1,7 +1,5 @@
import time
# time helpers
class Timer:
def __init__(self):
self.reset()
@@ -11,9 +9,3 @@ class Timer:
def elapsed(self):
return time.time() - self.last_time
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.5.2',
version = '0.3.8',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -32,7 +32,6 @@ setup(
'kornia>=0.5.4',
'numpy',
'pillow',
'pydantic',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',

View File

@@ -1,10 +1,11 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.utils import Timer
import json
import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
@@ -85,6 +86,20 @@ def create_dataloaders(
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
@@ -139,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evaluation_samples)
examples = get_example_data(dataloader, device, n_evalation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -237,8 +252,8 @@ def train(
start_epoch = 0
validation_losses = []
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
if exists(load_config) and exists(load_config["source"]):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
trainer.to(device=inference_device)
if not exists(unet_training_mask):
@@ -256,6 +271,7 @@ def train(
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
trainer.train()
timer = Timer()
@@ -264,13 +280,11 @@ def train(
last_snapshot = 0
losses = []
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
@@ -285,7 +299,7 @@ def train(
timer.reset()
last_sample = sample
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
if i % CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
@@ -305,12 +319,11 @@ def train(
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
trainer.train()
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if exists(epoch_samples) and sample >= epoch_samples:
@@ -345,6 +358,7 @@ def train(
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
trainer.eval()
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
@@ -371,25 +385,21 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config.tracker
tracker_config = config["tracker"]
init_config = {}
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
init_config["config"] = config.config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
load_config = config["load"]
if load_config["source"] == "wandb" and load_config["resume"]:
# Then we are resuming the run load_config["run_path"]
run_id = load_config.run_path.split("/")[-1]
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"]
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
@@ -398,35 +408,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def initialize_training(config):
# Create the save path
if "cuda" in config.train.device:
if "cuda" in config["train"]["device"]:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config.train.device)
device = torch.device(config["train"]["device"])
torch.cuda.set_device(device)
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.dict()
img_preproc = config.get_preprocessing(),
train_prop = config["data"]["splits"]["train"],
val_prop = config["data"]["splits"]["val"],
test_prop = config["data"]["splits"]["test"],
n_sample_images=config["train"]["n_sample_images"],
**config["data"]
)
decoder = config.decoder.create().to(device = device)
decoder = create_decoder(device, config["decoder"], config["unets"])
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config.tracker.dict())
tracker = create_tracker(config, **config["tracker"])
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
load_config=config["load"],
evaluate_config=config["evaluate"],
**config["train"],
)
# Create a simple click command line interface to load the config and start the training
@@ -434,7 +444,9 @@ def initialize_training(config):
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
config = TrainDecoderConfig.from_json_path(config_file)
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(config)
initialize_training(config)

View File

@@ -9,10 +9,10 @@ from torch import nn
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.utils import Timer
from embedding_reader import EmbeddingReader