Compare commits

...

5 Commits
0.4.0 ... 0.4.3

6 changed files with 206 additions and 15 deletions

README.md

@@ -1077,6 +1077,7 @@ This library would not have gotten to this working state without the help of
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [x] use pydantic for config driven training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1091,7 +1092,6 @@ This library would not have gotten to this working state without the help of
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] for all model classes with hyperparameters that change the network architecture, make it a requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
-- [ ] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

 ## Citations

train_decoder_config.json

@@ -0,0 +1,99 @@
+{
+    "unets": [
+        {
+            "dim": 128,
+            "image_embed_dim": 768,
+            "cond_dim": 64,
+            "channels": 3,
+            "dim_mults": [1, 2, 4, 8],
+            "attn_dim_head": 32,
+            "attn_heads": 16
+        }
+    ],
+    "decoder": {
+        "image_sizes": [64],
+        "channels": 3,
+        "timesteps": 1000,
+        "loss_type": "l2",
+        "beta_schedule": "cosine",
+        "learned_variance": true
+    },
+    "data": {
+        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
+        "embeddings_url": "s3://bucket/embeddings/path/",
+        "num_workers": 4,
+        "batch_size": 64,
+        "start_shard": 0,
+        "end_shard": 9999999,
+        "shard_width": 6,
+        "index_width": 4,
+        "splits": {
+            "train": 0.75,
+            "val": 0.15,
+            "test": 0.1
+        },
+        "shuffle_train": true,
+        "resample_train": false,
+        "preprocessing": {
+            "RandomResizedCrop": {
+                "size": [128, 128],
+                "scale": [0.75, 1.0],
+                "ratio": [1.0, 1.0]
+            },
+            "ToTensor": true
+        }
+    },
+    "train": {
+        "epochs": 20,
+        "lr": 1e-4,
+        "wd": 0.01,
+        "max_grad_norm": 0.5,
+        "save_every_n_samples": 100000,
+        "n_sample_images": 6,
+        "device": "cuda:0",
+        "epoch_samples": null,
+        "validation_samples": null,
+        "use_ema": true,
+        "ema_beta": 0.99,
+        "amp": false,
+        "save_all": false,
+        "save_latest": true,
+        "save_best": true,
+        "unet_training_mask": [true]
+    },
+    "evaluate": {
+        "n_evaluation_samples": 1000,
+        "FID": {
+            "feature": 64
+        },
+        "IS": {
+            "feature": 64,
+            "splits": 10
+        },
+        "KID": {
+            "feature": 64,
+            "subset_size": 10
+        },
+        "LPIPS": {
+            "net_type": "vgg",
+            "reduction": "mean"
+        }
+    },
+    "tracker": {
+        "tracker_type": "console",
+        "data_path": "./models",
+        "wandb_entity": "",
+        "wandb_project": "",
+        "verbose": false
+    },
+    "load": {
+        "source": null,
+        "run_path": "",
+        "file_path": "",
+        "resume": false
+    }
+}
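
Once saved to disk, this example config is parsed and validated in one step by the `from_json_path` classmethod added to `TrainDecoderConfig` below. A minimal sketch, assuming the file sits at the training script's default path `./train_decoder_config.json`:

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig

# parse the JSON and validate it against the pydantic models;
# an invalid config (e.g. splits that do not sum to 1.0) raises a ValidationError here
config = TrainDecoderConfig.from_json_path('./train_decoder_config.json')

# splits are now typed attribute access rather than raw dict lookups
print(config.data.splits.train)  # 0.75
```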

dalle2_pytorch/train_configs.py

@@ -1,5 +1,6 @@
+import json
 from torchvision import transforms as T
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any

 def exists(val):
@@ -38,6 +39,17 @@ class DecoderConfig(BaseModel):
     class Config:
         extra = "allow"

+class TrainSplitConfig(BaseModel):
+    train: float = 0.75
+    val: float = 0.15
+    test: float = 0.1
+
+    @root_validator
+    def validate_all(cls, fields):
+        if sum([*fields.values()]) != 1.:
+            raise ValueError(f'{fields.keys()} must sum to 1.0')
+        return fields
+
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str   # path to a webdataset with jpg images
     embeddings_url: str        # path to .npy files with embeddings
@@ -47,11 +59,7 @@ class DecoderDataConfig(BaseModel):
     end_shard: int = 9999999
     shard_width: int = 6
     index_width: int = 4
-    splits: Dict[str, float] = {
-        'train': 0.75,
-        'val': 0.15,
-        'test': 0.1
-    }
+    splits: TrainSplitConfig
     shuffle_train: bool = True
     resample_train: bool = False
     preprocessing: Dict[str, Any] = {'ToTensor': True}
@@ -104,6 +112,12 @@ class TrainDecoderConfig(BaseModel):
     tracker: TrackerConfig
     load: DecoderLoadConfig

+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
     @property
     def img_preproc(self):
         def _get_transformation(transformation_name, **kwargs):
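
A small sketch of what the new `root_validator` on `TrainSplitConfig` enforces; the passing case uses the class defaults, and the failing values are made up for illustration:

```python
from pydantic import ValidationError
from dalle2_pytorch.train_configs import TrainSplitConfig

TrainSplitConfig()  # ok: defaults 0.75 / 0.15 / 0.1 sum to 1.0

try:
    # bad split for illustration: sums to 0.85, not 1.0
    TrainSplitConfig(train = 0.5, val = 0.25, test = 0.1)
except ValidationError as err:
    print(err)  # pydantic surfaces the ValueError: "... must sum to 1.0"
```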

dalle2_pytorch/trainer.py

@@ -1,5 +1,6 @@
 import time
 import copy
+from pathlib import Path
 from math import ceil
 from functools import partial, wraps
 from collections.abc import Iterable
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr

+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
+
 # decorators

 def cast_torch_tensor(fn):
@@ -289,6 +294,44 @@ class DiffusionPriorTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0.]))

+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.diffusion_prior.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
+
+        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
@@ -410,6 +453,44 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0.]))

+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.decoder.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+
+        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
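
For context, a sketch of how the new checkpointing methods would be used from training code. The unet and decoder hyperparameters are illustrative only, mirroring the example JSON config above, and `DecoderTrainer` is assumed to be exported at the package top level:

```python
from dalle2_pytorch import Unet, Decoder, DecoderTrainer

# illustrative hyperparameters, mirroring the example config above
unet = Unet(dim = 128, image_embed_dim = 768, cond_dim = 64, dim_mults = (1, 2, 4, 8))
decoder = Decoder(unet = unet, image_sizes = (64,), timesteps = 1000)

trainer = DecoderTrainer(decoder, use_ema = True)

# save() persists model, optimizer and grad scaler state, stamps the checkpoint
# with the package version, and includes the EMA weights when use_ema is on
trainer.save('./checkpoints/decoder.pt')

# load() restores all of it, printing a warning when the checkpoint was written
# by a different package version; only_model = True skips optimizer/scaler state
trainer.load('./checkpoints/decoder.pt')
```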

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.0',
+  version = '0.4.3',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

train_decoder.py

@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer

-import json
 import torchvision
 import torch
 from torchmetrics.image.fid import FrechetInceptionDistance
@@ -422,9 +421,9 @@ def initialize_training(config):
     dataloaders = create_dataloaders (
         available_shards=all_shards,
         img_preproc = config.img_preproc,
-        train_prop = config.data["splits"]["train"],
-        val_prop = config.data["splits"]["val"],
-        test_prop = config.data["splits"]["test"],
+        train_prop = config.data.splits.train,
+        val_prop = config.data.splits.val,
+        test_prop = config.data.splits.test,
         n_sample_images=config.train.n_sample_images,
         **config.data.dict()
     )
@@ -449,9 +448,7 @@
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
     print("Recalling config from {}".format(config_file))
-    with open(config_file) as f:
-        config = json.load(f)
-    config = TrainDecoderConfig(**config)
+    config = TrainDecoderConfig.from_json_path(config_file)
     initialize_training(config)
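
The entry point itself is unchanged; assuming the example config from this compare is saved at the click option's default path, the script would be invoked as before:

```
python train_decoder.py --config_file ./train_decoder_config.json
```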