Compare commits

...

9 Commits
0.4.0 ... 0.4.7

8 changed files with 252 additions and 45 deletions

View File

@@ -1077,6 +1077,8 @@ This library would not have gotten to this working state without the help of
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [x] use pydantic for config drive training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
+- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1086,12 +1088,9 @@ This library would not have gotten to this working state without the help of
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
-- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
-- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
-- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

 ## Citations

View File

@@ -0,0 +1,99 @@
{
    "unets": [
        {
            "dim": 128,
            "image_embed_dim": 768,
            "cond_dim": 64,
            "channels": 3,
            "dim_mults": [1, 2, 4, 8],
            "attn_dim_head": 32,
            "attn_heads": 16
        }
    ],
    "decoder": {
        "image_sizes": [64],
        "channels": 3,
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": "cosine",
        "learned_variance": true
    },
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
        "embeddings_url": "s3://bucket/embeddings/path/",
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,
        "end_shard": 9999999,
        "shard_width": 6,
        "index_width": 4,
        "splits": {
            "train": 0.75,
            "val": 0.15,
            "test": 0.1
        },
        "shuffle_train": true,
        "resample_train": false,
        "preprocessing": {
            "RandomResizedCrop": {
                "size": [128, 128],
                "scale": [0.75, 1.0],
                "ratio": [1.0, 1.0]
            },
            "ToTensor": true
        }
    },
    "train": {
        "epochs": 20,
        "lr": 1e-4,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100000,
        "n_sample_images": 6,
        "device": "cuda:0",
        "epoch_samples": null,
        "validation_samples": null,
        "use_ema": true,
        "ema_beta": 0.99,
        "amp": false,
        "save_all": false,
        "save_latest": true,
        "save_best": true,
        "unet_training_mask": [true]
    },
    "evaluate": {
        "n_evaluation_samples": 1000,
        "FID": {
            "feature": 64
        },
        "IS": {
            "feature": 64,
            "splits": 10
        },
        "KID": {
            "feature": 64,
            "subset_size": 10
        },
        "LPIPS": {
            "net_type": "vgg",
            "reduction": "mean"
        }
    },
    "tracker": {
        "tracker_type": "console",
        "data_path": "./models",
        "wandb_entity": "",
        "wandb_project": "",
        "verbose": false
    },
    "load": {
        "source": null,
        "run_path": "",
        "file_path": "",
        "resume": false
    }
}
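
This new example config maps directly onto the pydantic models added in train_configs.py below. A minimal sketch of loading and inspecting it (the JSON's file name is not shown in this compare view, so the path here is an assumption for illustration):

    from dalle2_pytorch.train_configs import TrainDecoderConfig

    # parse and validate the JSON above in one step (path is assumed)
    config = TrainDecoderConfig.from_json_path('./train_decoder_config.json')

    print(config.data.splits.train)     # 0.75, checked to sum to 1.0 with val and test
    preproc = config.data.img_preproc   # torchvision Compose built from the "preprocessing" block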

View File

@@ -1,5 +1,6 @@
+import json
 from torchvision import transforms as T
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any

 def exists(val):
@@ -38,6 +39,17 @@ class DecoderConfig(BaseModel):
     class Config:
         extra = "allow"

+class TrainSplitConfig(BaseModel):
+    train: float = 0.75
+    val: float = 0.15
+    test: float = 0.1
+
+    @root_validator
+    def validate_all(cls, fields):
+        if sum([*fields.values()]) != 1.:
+            raise ValueError(f'{fields.keys()} must sum to 1.0')
+        return fields
+
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str  # path to a webdataset with jpg images
     embeddings_url: str  # path to .npy files with embeddings
@@ -47,15 +59,27 @@ class DecoderDataConfig(BaseModel):
     end_shard: int = 9999999
     shard_width: int = 6
     index_width: int = 4
-    splits: Dict[str, float] = {
-        'train': 0.75,
-        'val': 0.15,
-        'test': 0.1
-    }
+    splits: TrainSplitConfig
     shuffle_train: bool = True
     resample_train: bool = False
     preprocessing: Dict[str, Any] = {'ToTensor': True}
+
+    @property
+    def img_preproc(self):
+        def _get_transformation(transformation_name, **kwargs):
+            if transformation_name == "RandomResizedCrop":
+                return T.RandomResizedCrop(**kwargs)
+            elif transformation_name == "RandomHorizontalFlip":
+                return T.RandomHorizontalFlip()
+            elif transformation_name == "ToTensor":
+                return T.ToTensor()
+
+        transforms = []
+        for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
+            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
+            transforms.append(_get_transformation(transform_name, **transform_kwargs))
+        return T.Compose(transforms)

 class DecoderTrainConfig(BaseModel):
     epochs: int = 20
     lr: float = 1e-4
@@ -104,18 +128,8 @@ class TrainDecoderConfig(BaseModel):
     tracker: TrackerConfig
     load: DecoderLoadConfig

-    @property
-    def img_preproc(self):
-        def _get_transformation(transformation_name, **kwargs):
-            if transformation_name == "RandomResizedCrop":
-                return T.RandomResizedCrop(**kwargs)
-            elif transformation_name == "RandomHorizontalFlip":
-                return T.RandomHorizontalFlip()
-            elif transformation_name == "ToTensor":
-                return T.ToTensor()
-
-        transforms = []
-        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
-            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
-            transforms.append(_get_transformation(transform_name, **transform_kwargs))
-        return T.Compose(transforms)
+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
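
A quick sketch of what the new TrainSplitConfig validator buys: a bad split is rejected when the config is parsed, instead of surfacing later as a dataloader bug. Note the check uses exact float equality, so it relies on the given fractions summing to exactly 1.0 in floating point (the defaults 0.75 / 0.15 / 0.1 do).

    from dalle2_pytorch.train_configs import TrainSplitConfig

    TrainSplitConfig(train = 0.75, val = 0.15, test = 0.1)   # fine, sums to 1.0
    TrainSplitConfig(train = 0.8, val = 0.1, test = 0.05)    # raises a pydantic ValidationError (sums to 0.95)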

View File

@@ -1,5 +1,6 @@
 import time
 import copy
+from pathlib import Path
 from math import ceil
 from functools import partial, wraps
 from collections.abc import Iterable
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr

+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
+
 # decorators

 def cast_torch_tensor(fn):
@@ -128,12 +133,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
         chunk_size_frac = chunk_size / batch_size
         yield chunk_size_frac, (chunked_args, chunked_kwargs)

-# print helpers
-
-def print_ribbon(s, symbol = '=', repeat = 40):
-    flank = symbol * repeat
-    return f'{flank} {s} {flank}'
-
 # saving and loading functions

 # for diffusion prior
@@ -191,7 +190,7 @@ class EMA(nn.Module):
         self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0

         self.register_buffer('initted', torch.Tensor([False]))
-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('step', torch.tensor([0]))

     def restore_ema_model_device(self):
         device = self.initted.device
@@ -287,7 +286,47 @@ class DiffusionPriorTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('step', torch.tensor([0]))
+
+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.diffusion_prior.state_dict(),
+            version = get_pkg_version(),
+            step = self.step.item()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
+
+        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)

     def update(self):
         if exists(self.max_grad_norm):
@@ -410,6 +449,57 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0.]))

+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            model = self.decoder.state_dict(),
+            version = get_pkg_version(),
+            step = self.step.item()
+        )
+
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'scaler{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+
+        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+
+        if only_model:
+            return
+
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'scaler{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+
+            scaler.load_state_dict(loaded_obj[scaler_key])
+            optimizer.load_state_dict(loaded_obj[optimizer_key])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
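
A minimal sketch of the new checkpointing flow, assuming `trainer` is an already constructed DiffusionPriorTrainer or DecoderTrainer (their constructor arguments are not part of this diff) and the checkpoint path is illustrative:

    # writes model, optimizer and scaler state dicts, the installed package
    # version, the step counter, and (if enabled) the EMA weights to one file
    trainer.save('./checkpoints/latest.pt')

    # restores everything; a package version mismatch only prints a warning
    trainer.load('./checkpoints/latest.pt')

    # or restore just the model weights (and step), e.g. for inference
    trainer.load('./checkpoints/latest.pt', only_model = True)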

View File

@@ -1,5 +1,7 @@
 import time

+# time helpers
+
 class Timer:
     def __init__(self):
         self.reset()
@@ -9,3 +11,9 @@ class Timer:
     def elapsed(self):
         return time.time() - self.last_time

+# print helpers
+
+def print_ribbon(s, symbol = '=', repeat = 40):
+    flank = symbol * repeat
+    return f'{flank} {s} {flank}'
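
print_ribbon now lives in dalle2_pytorch.utils alongside Timer; a tiny usage sketch of the two helpers as the training scripts below use them:

    from dalle2_pytorch.utils import Timer, print_ribbon

    timer = Timer()
    print(print_ribbon('Validating'))   # 'Validating' flanked by 40 '=' characters on each side
    # ... run validation ...
    print(f'validation took {timer.elapsed():.1f}s')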

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.0',
+  version = '0.4.7',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -1,11 +1,10 @@
 from dalle2_pytorch import Unet, Decoder
-from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
+from dalle2_pytorch.trainer import DecoderTrainer
 from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
 from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
-from dalle2_pytorch.utils import Timer
-import json
+from dalle2_pytorch.utils import Timer, print_ribbon
 import torchvision
 import torch
 from torchmetrics.image.fid import FrechetInceptionDistance
@@ -421,10 +420,10 @@ def initialize_training(config):
     dataloaders = create_dataloaders (
         available_shards=all_shards,
-        img_preproc = config.img_preproc,
-        train_prop = config.data["splits"]["train"],
-        val_prop = config.data["splits"]["val"],
-        test_prop = config.data["splits"]["test"],
+        img_preproc = config.data.img_preproc,
+        train_prop = config.data.splits.train,
+        val_prop = config.data.splits.val,
+        test_prop = config.data.splits.test,
         n_sample_images=config.train.n_sample_images,
         **config.data.dict()
     )
@@ -449,9 +448,7 @@ def initialize_training(config):
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
     print("Recalling config from {}".format(config_file))
-    with open(config_file) as f:
-        config = json.load(f)
-    config = TrainDecoderConfig(**config)
+    config = TrainDecoderConfig.from_json_path(config_file)
     initialize_training(config)
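
With the config handling above, kicking off decoder training reduces to pointing the script at a JSON config file; the script name is inferred from the imports in this file, and the path is just the click default:

    python train_decoder.py --config_file ./train_decoder_config.json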

View File

@@ -9,10 +9,10 @@ from torch import nn
 from dalle2_pytorch.dataloaders import make_splits
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
-from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
+from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
 from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
-from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.utils import Timer, print_ribbon
 from embedding_reader import EmbeddingReader