Compare commits

..

18 Commits
0.3.7 ... 0.4.8

Author SHA1 Message Date
Phil Wang
5c397c9d66 move neural network creations off the configuration file into the pydantic classes 2022-05-22 19:18:18 -07:00
Phil Wang
0f4edff214 derived value for image preprocessing belongs to the data config class 2022-05-22 18:42:40 -07:00
Phil Wang
501a8c7c46 small cleanup 2022-05-22 15:39:38 -07:00
Phil Wang
4e49373fc5 project management 2022-05-22 15:27:40 -07:00
Phil Wang
49de72040c fix decoder trainer optimizer loading (since there are multiple for each unet), also save and load step number correctly 2022-05-22 15:21:00 -07:00
Phil Wang
271a376eaf 0.4.3 2022-05-22 15:10:28 -07:00
Phil Wang
e527002472 take care of saving and loading functions on the diffusion prior and decoder training classes 2022-05-22 15:10:15 -07:00
Phil Wang
c12e067178 let the pydantic config base model take care of loading configuration from json path 2022-05-22 14:47:23 -07:00
Phil Wang
c6629c431a make training splits into its own pydantic base model, validate it sums to 1, make decoder script cleaner 2022-05-22 14:43:22 -07:00
Phil Wang
7ac2fc79f2 add renamed train decoder json file 2022-05-22 14:32:50 -07:00
Phil Wang
a1ef023193 use pydantic to manage decoder training configs + defaults and refactor training script 2022-05-22 14:27:40 -07:00
Phil Wang
d49eca62fa dep 2022-05-21 11:27:52 -07:00
Phil Wang
8aab69b91e final thought 2022-05-21 10:47:45 -07:00
Phil Wang
b432df2f7b final cleanup to decoder script 2022-05-21 10:42:16 -07:00
Phil Wang
ebaa0d28c2 product management 2022-05-21 10:30:52 -07:00
Phil Wang
8b0d459b25 move config parsing logic to own file, consider whether to find an off-the-shelf solution at future date 2022-05-21 10:30:10 -07:00
Phil Wang
0064661729 small cleanup of decoder train script 2022-05-21 10:17:13 -07:00
Phil Wang
b895f52843 appreciation section 2022-05-21 08:32:12 -07:00
11 changed files with 359 additions and 243 deletions

View File

@@ -1034,6 +1034,18 @@ Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
... and many others. Thank you! 🙏
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1064,6 +1076,9 @@ Once built, images will be saved to the same directory the command is invoked
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1073,12 +1088,9 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations

View File

@@ -4,11 +4,12 @@ For more complex configuration, we provide the option of using a configuration f
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
**<ins>Unets</ins>:**
**<ins>Unet</ins>:**
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |

View File

@@ -1,82 +0,0 @@
"""
Defines the default values for the decoder config
"""
from enum import Enum
class ConfigField(Enum):
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
default_config = {
"unets": ConfigField.REQUIRED,
"decoder": {
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": True
},
"data": {
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": True,
"resample_train": False,
"preprocessing": {
"ToTensor": True
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
"device": "cuda:0",
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
"validation_samples": None, # Same as above but for validation.
"use_ema": True,
"ema_beta": 0.99,
"amp": False,
"save_all": False, # Whether to preserve all checkpoints
"save_latest": True, # Whether to always save the latest checkpoint
"save_best": True, # Whether to save the best checkpoint
"unet_training_mask": None # If None, use all unets
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": None,
"IS": None,
"KID": None,
"LPIPS": None
},
"tracker": {
"tracker_type": "console", # Decoder currently supports console and wandb
"data_path": "./models", # The path where files will be saved locally
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
"wandb_project": "",
"verbose": False # Whether to print console logging for non-console trackers
},
"load": {
"source": None, # Supports file and wandb
"run_path": "", # Used only if source is wandb
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
"resume": False # If using wandb, whether to resume the run
}
}

View File

@@ -1,18 +1,17 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"image_sizes": [64],
"image_size": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
@@ -63,7 +62,7 @@
"unet_training_mask": [true]
},
"evaluate": {
"n_evalation_samples": 1000,
"n_evaluation_samples": 1000,
"FID": {
"feature": 64
},

View File

@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -1728,7 +1728,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
self.clip_image_size = image_size
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings

View File

@@ -0,0 +1,150 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
# pydantic classes
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
image_size: int = None
image_sizes: ListOrTuple(int) = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
return image_sizes
raise ValueError('either image_size or image_sizes is required, but not both')
class Config:
extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
end_shard: int = 9999999
shard_width: int = 6
index_width: int = 4
splits: TrainSplitConfig
shuffle_train: bool = True
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)

View File

@@ -1,5 +1,6 @@
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
# decorators
def cast_torch_tensor(fn):
@@ -128,12 +133,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# saving and loading functions
# for diffusion prior
@@ -191,7 +190,7 @@ class EMA(nn.Module):
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('step', torch.tensor([0]))
def restore_ema_model_device(self):
device = self.initted.device
@@ -287,7 +286,47 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
step = self.step.item()
)
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
torch.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
def update(self):
if exists(self.max_grad_norm):
@@ -410,6 +449,57 @@ class DecoderTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
step = self.step.item()
)
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
torch.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
scaler.load_state_dict(loaded_obj[scaler_key])
optimizer.load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

View File

@@ -1,5 +1,7 @@
import time
# time helpers
class Timer:
def __init__(self):
self.reset()
@@ -9,3 +11,9 @@ class Timer:
def elapsed(self):
return time.time() - self.last_time
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.3.7',
version = '0.4.8',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -32,6 +32,7 @@ setup(
'kornia>=0.5.4',
'numpy',
'pillow',
'pydantic',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',

View File

@@ -1,13 +1,11 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from configs.decoder_defaults import default_config, ConfigField
import json
import torchvision
from torchvision import transforms as T
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
@@ -16,6 +14,17 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click
# constants
TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10
# helpers functions
def exists(val):
return val is not None
# main functions
def create_dataloaders(
available_shards,
@@ -76,23 +85,6 @@ def create_dataloaders(
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = []
for i in range(0, len(unets_config)):
unets.append(Unet(
**unets_config[i]
))
decoder = Decoder(
unet=unets,
**decoder_config
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
@@ -147,33 +139,33 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evalation_samples)
examples = get_example_data(dataloader, device, n_evaluation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
if FID is not None:
if exists(FID):
fid = FrechetInceptionDistance(**FID)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if IS is not None:
if exists(IS):
inception = InceptionScore(**IS)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if KID is not None:
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
@@ -181,7 +173,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if LPIPS is not None:
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
@@ -245,11 +237,11 @@ def train(
start_epoch = 0
validation_losses = []
if load_config is not None and load_config["source"] is not None:
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
trainer.to(device=inference_device)
if unet_training_mask is None:
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
@@ -264,7 +256,6 @@ def train(
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
trainer.train()
timer = Timer()
@@ -273,24 +264,28 @@ def train(
last_snapshot = 0
losses = []
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if unet_training_mask[unet-1]: # Unet index is the unet number - 1
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
if i % 10 == 0:
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
@@ -310,14 +305,15 @@ def train(
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if n_sample_images is not None and n_sample_images > 0:
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
trainer.train()
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if epoch_samples is not None and sample >= epoch_samples:
if exists(epoch_samples) and sample >= epoch_samples:
break
trainer.eval()
@@ -334,12 +330,12 @@ def train(
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % 10 == 0:
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if validation_samples is not None and sample >= validation_samples:
if exists(validation_samples) and sample >= validation_samples:
break
average_loss /= i+1
@@ -349,8 +345,7 @@ def train(
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
trainer.eval()
if evaluate_config is not None:
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
tracker.log(evaluation, step=step, verbose=True)
@@ -376,21 +371,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config["tracker"]
tracker_config = config.tracker
init_config = {}
init_config["config"] = config.config
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config["load"]
if load_config["source"] == "wandb" and load_config["resume"]:
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"]
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
@@ -399,106 +398,43 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def initialize_training(config):
# Create the save path
if "cuda" in config["train"]["device"]:
if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config["train"]["device"])
device = torch.device(config.train.device)
torch.cuda.set_device(device)
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.get_preprocessing(),
train_prop = config["data"]["splits"]["train"],
val_prop = config["data"]["splits"]["val"],
test_prop = config["data"]["splits"]["test"],
n_sample_images=config["train"]["n_sample_images"],
**config["data"]
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.dict()
)
decoder = create_decoder(device, config["decoder"], config["unets"])
decoder = config.decoder.create().to(device = device)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config["tracker"])
tracker = create_tracker(config, **config.tracker.dict())
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config["load"],
evaluate_config=config["evaluate"],
**config["train"],
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
)
class TrainDecoderConfig:
def __init__(self, config):
self.config = self.map_config(config, default_config)
def map_config(self, config, defaults):
"""
Returns a dictionary containing all config options in the union of config and defaults.
If the config value is an array, apply the default value to each element.
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
"""
def _check_option(option, option_config, option_defaults):
for key, value in option_defaults.items():
if key not in option_config:
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
option_config[key] = value
for key, value in defaults.items():
if key not in config:
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' not supplied".format(key))
elif isinstance(value, dict):
config[key] = {}
elif isinstance(value, list):
config[key] = [{}]
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
# If it is a list, then we need to check each element
if isinstance(value, list):
assert isinstance(config[key], list)
for element in config[key]:
_check_option(key, element, value[0])
elif isinstance(value, dict):
_check_option(key, config[key], value)
# This object does not support checking
return config
def get_preprocessing(self):
"""
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
"""
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transformations = []
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
if isinstance(transformation_kwargs, dict):
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
else:
transformations.append(_get_transformation(transformation_name))
return T.Compose(transformations)
def __getitem__(self, key):
return self.config[key]
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(config)
config = TrainDecoderConfig.from_json_path(config_file)
initialize_training(config)

View File

@@ -9,10 +9,10 @@ from torch import nn
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.utils import Timer, print_ribbon
from embedding_reader import EmbeddingReader