Compare commits

..

1 Commits

Author SHA1 Message Date
Romain Beaumont
3a1dea7d97 Fix decoder test by fixing the resizing output size 2022-07-09 11:36:22 +02:00
12 changed files with 266 additions and 861 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
github: [nousr, Veldrovive, lucidrains]
github: [lucidrains]

119
README.md
View File

@@ -45,7 +45,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -356,8 +355,7 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 1000,
sample_timesteps = 64,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
@@ -421,7 +419,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below
@@ -585,7 +583,6 @@ unet1 = Unet(
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
text_embed_dim = 512,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
@@ -601,8 +598,7 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
sample_timesteps = (250, 27),
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
@@ -628,82 +624,6 @@ images = dalle2(
Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting
Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
This repository uses the formulation put forth by <a href="https://arxiv.org/abs/2201.09865">Lugmayr et al. in Repaint</a>
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 1, 1, 1)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
# do the above for many steps for both unets
mock_image_embed = torch.randn(1, 512).cuda()
# then to do inpainting
inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
inpainted_images = decoder.sample(
image_embed = mock_image_embed,
inpaint_image = inpaint_image, # just pass in the inpaint image
inpaint_mask = inpaint_mask # and the mask
)
inpainted_images.shape # (1, 3, 256, 256)
```
## Experimental
### DALL-E2 with Latent Diffusion
@@ -1067,12 +987,26 @@ dataset = ImageEmbeddingDataset(
)
```
### Scripts
### Scripts (wip)
#### `train_diffusion_prior.py`
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1110,10 +1044,11 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [x] speed up inference, read up on papers (ddim)
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
## Citations
@@ -1231,14 +1166,4 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```
```bibtex
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -69,12 +69,14 @@ Settings for controlling the training hyperparameters.
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
@@ -161,10 +163,9 @@ All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
@@ -176,6 +177,7 @@ If using `huggingface`
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`

View File

@@ -56,6 +56,9 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -93,15 +96,14 @@
},
"save": [{
"save_to": "wandb",
"save_latest_to": "latest.pth"
"save_to": "wandb"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_latest_to": "path/to/model_dir/latest.pth",
"save_best_to": "path/to/model_dir/best.pth",
"save_meta_to": "path/to/directory/for/assorted/files",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_type": "model"
}]

View File

@@ -61,6 +61,9 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -93,8 +96,7 @@
},
"save": [{
"save_to": "local",
"save_latest_to": "latest.pth"
"save_to": "local"
}]
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,15 +4,13 @@ import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Any, Optional, List, Union
from typing import Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# constants
@@ -23,6 +21,16 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load file functions
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
def load_local_file(file_path, **kwargs):
return file_path
class BaseLogger:
"""
An abstract class representing an object that can log data.
@@ -226,7 +234,7 @@ class LocalLoader(BaseLoader):
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists() and not self.only_auto_resume:
if not self.file_path.exists():
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
@@ -275,9 +283,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = None,
save_best_to: Optional[Union[str, bool]] = None,
save_meta_to: Optional[str] = None,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_type: str = 'checkpoint',
**kwargs
):
@@ -287,10 +295,10 @@ class BaseSaver:
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -451,11 +459,6 @@ class Tracker:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -504,15 +507,8 @@ class Tracker:
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
if saver.saving_meta:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
self.save_metadata[state_dict_key] = metadata
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
@@ -522,38 +518,24 @@ class Tracker:
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
trainer.save(file_path, overwrite=True, **kwargs)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
prior: DiffusionPrior = trainer.unwrap_model(prior)
# Remove CLIP if it is part of the model
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
elif isinstance(trainer, DecoderTrainer):
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
original_clip = decoder.clip
decoder.clip = None
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
model_state_dict = decoder.state_dict()
state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
model_state_dict = decoder.state_dict()
decoder.clip = original_clip
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
torch.save(state_dict, file_path)
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):

View File

@@ -129,7 +129,6 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
@@ -137,7 +136,6 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
@@ -156,7 +154,6 @@ class DiffusionPriorConfig(BaseModel):
image_size: int
image_channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
@@ -225,7 +222,6 @@ class UnetConfig(BaseModel):
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
class Config:
extra = "allow"
@@ -237,7 +233,6 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
@@ -306,11 +301,9 @@ class DecoderTrainConfig(BaseModel):
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
save_immediately: bool = False
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False

View File

@@ -498,27 +498,23 @@ class DecoderTrainer(nn.Module):
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
warmup_schedulers.append(None)
else:
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
optimizers.append(optimizer)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
schedulers.append(scheduler)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -540,19 +536,11 @@ class DecoderTrainer(nn.Module):
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
self.train_loader = train_loader
self.val_loader = val_loader
self.decoder = decoder
# store optimizers
@@ -594,8 +582,7 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
state_dict = optimizer.state_dict() if optimizer is not None else None
save_obj = {**save_obj, optimizer_key: state_dict}
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -617,8 +604,8 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
warmup_scheduler = self.warmup_schedulers[ind]
if optimizer is not None:
optimizer.load_state_dict(loaded_obj[optimizer_key])
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
@@ -678,14 +665,8 @@ class DecoderTrainer(nn.Module):
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
if kwargs.pop('use_non_ema', False) or not self.use_ema:
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
return base_decoder.sample(*args, **kwargs, distributed = distributed)
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -698,7 +679,6 @@ class DecoderTrainer(nn.Module):
for ema in self.ema_unets:
ema.restore_ema_model_device()
base_decoder.train(was_training)
return output
@torch.no_grad()
@@ -719,32 +699,23 @@ class DecoderTrainer(nn.Module):
*args,
unet_number = None,
max_batch_size = None,
return_lowres_cond_image=False,
**kwargs
):
unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0.
cond_images = []
using_amp = self.accelerator.mixed_precision != 'no'
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
# loss_obj may be a tuple with loss and cond_image
if return_lowres_cond_image:
loss, cond_image = loss_obj
else:
loss = loss_obj
cond_image = None
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
if cond_image is not None:
cond_images.append(cond_image)
total_loss += loss.item()
if self.training:
self.accelerator.backward(loss)
if return_lowres_cond_image:
return total_loss, torch.stack(cond_images)
else:
return total_loss
return total_loss

View File

@@ -1 +1 @@
__version__ = '1.0.0'
__version__ = '0.18.0'

View File

@@ -1,6 +1,5 @@
from pathlib import Path
from typing import List
from datetime import timedelta
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -12,12 +11,11 @@ from clip import tokenize
import torchvision
import torch
from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
@@ -134,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -159,13 +157,6 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
sample_params["start_at_unet_number"] = start_unet
sample_params["stop_at_unet_number"] = end_unet
if start_unet > 1:
# If we are only training upsamplers
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
@@ -174,15 +165,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -192,7 +183,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -268,13 +259,11 @@ def train(
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
save_immediately=False,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
unet_training_mask=None,
condition_on_text_encodings=False,
cond_scale=1.0,
**kwargs
):
"""
@@ -282,21 +271,6 @@ def train(
"""
is_master = accelerator.process_index == 0
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * len(decoder.unets)
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]
last_trainable_unet = trainable_unet_numbers[-1]
def move_unets(unet_training_mask):
for i in range(len(decoder.unets)):
if not unet_training_mask[i]:
# Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
decoder.unets[i] = nn.Identity().to(inference_device)
# Remove non-trainable unets
move_unets(unet_training_mask)
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
@@ -311,7 +285,6 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -323,6 +296,13 @@ def train(
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
@@ -343,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -378,9 +358,8 @@ def train(
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -393,10 +372,10 @@ def train(
unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
log_data = {
"Epoch": epoch,
@@ -411,7 +390,7 @@ def train(
if is_master:
tracker.log(log_data, step=step())
if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
@@ -419,7 +398,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -437,7 +416,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -469,9 +448,8 @@ def train(
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -498,7 +476,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'
@@ -509,15 +487,15 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
average_loss = all_average_val_losses.mean(dim=0).item()
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
@@ -534,7 +512,6 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
}
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):
@@ -543,8 +520,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect
@@ -581,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info
decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -610,10 +586,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,

View File

@@ -126,9 +126,9 @@ def report_cosine_sims(
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings = trainer.embed_text(text_data)
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else:
text_embedding = text_data
@@ -146,12 +146,15 @@ def report_cosine_sims(
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
# prepare the text embedding