Compare commits

...

19 Commits

Author SHA1 Message Date
Phil Wang
46a2558d53 bug in pydantic decoder config class 2022-06-29 07:17:35 -07:00
yytdfc
86109646e3 fix a bug of name error (#179) 2022-06-29 07:16:44 -07:00
Phil Wang
6a11b9678b bring in the skip connection scaling factor, used by imagen in their unets, cite original paper using it 2022-06-26 21:59:55 -07:00
Phil Wang
b90364695d fix remaining issues with deriving cond_on_text_encodings from child unet settings 2022-06-26 21:07:42 -07:00
zion
868c001199 bug fixes for text conditioning update (#175) 2022-06-26 16:12:32 -07:00
Phil Wang
032e83b0e0 nevermind, do not enforce text encodings on first unet 2022-06-26 12:45:05 -07:00
Phil Wang
2e85e736f3 remove unnecessary decoder setting, and if not unconditional, always make sure the first unet is condition-able on text 2022-06-26 12:32:17 -07:00
Aidan Dempster
f5760bdb92 Add data flexibility to decoder trainer (#165)
* Added the ability to train decoder with text embeddings

* Added the ability to train using on the fly generated embeddings with clip

* Clip now generates embeddings for whatever is not precomputed
2022-06-25 19:05:20 -07:00
zion
c453f468b1 autoswitch tqdm for notebooks (#171)
avoids printing the `tqdm` progress bar to a newline in notebooks when detected
2022-06-25 16:37:06 -07:00
zion
98f0c17759 add sampels-seen and ema decay (#166) 2022-06-24 15:12:09 -07:00
Phil Wang
a5b9fd6ca8 product management 2022-06-24 08:15:05 -07:00
Phil Wang
4b994601ae just make sure decoder learning rate is reasonable and help out budding researchers 2022-06-23 11:29:28 -07:00
zion
fddf66e91e fix params in decoder (#162) 2022-06-22 14:45:01 -07:00
Phil Wang
c8422ffd5d fix EMA updating buffers with non-float tensors 2022-06-22 07:16:39 -07:00
Conight
2aadc23c7c Fix train decoder config example (#160) 2022-06-21 22:17:06 -07:00
Phil Wang
c098f57e09 EMA for vqgan vae comes from ema_pytorch now 2022-06-20 15:29:08 -07:00
Phil Wang
0021535c26 move ema to external repo 2022-06-20 11:48:32 -07:00
Phil Wang
56883910fb cleanup 2022-06-20 11:14:55 -07:00
Phil Wang
893f270012 project management 2022-06-20 10:00:22 -07:00
10 changed files with 304 additions and 283 deletions

View File

@@ -368,7 +368,8 @@ unet1 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
).cuda()
unet2 = Unet(
@@ -385,8 +386,7 @@ decoder = Decoder(
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
text_cond_drop_prob = 0.5
).cuda()
for unet_number in (1, 2):
@@ -1017,33 +1017,6 @@ The most significant parameters for the script are as follows:
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
#### Loading and Saving the DiffusionPrior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
## CLI (wip)
```bash
@@ -1092,19 +1065,14 @@ Once built, images will be saved to the same directory the command is invoked
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
## Citations
@@ -1221,4 +1189,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Saharia2021PaletteID,
title = {Palette: Image-to-Image Diffusion Models},
author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
journal = {ArXiv},
year = {2021},
volume = {abs/2111.05826}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -15,7 +15,7 @@
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {

View File

@@ -1,6 +1,6 @@
import math
import random
from tqdm import tqdm
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
@@ -1359,6 +1359,7 @@ class Unet(nn.Module):
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
**kwargs
):
super().__init__()
@@ -1440,6 +1441,10 @@ class Unet(nn.Module):
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# whether to scale skip connection, adopted in Imagen
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
@@ -1687,7 +1692,9 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim = 1)
skip_connect = hiddens.pop() * self.skip_connect_scale
x = torch.cat((x, skip_connect), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
@@ -1766,14 +1773,13 @@ class Decoder(nn.Module):
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001,
unconditional = False,
unconditional = False, # set to True for generating images without conditioning
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9,
@@ -1782,13 +1788,6 @@ class Decoder(nn.Module):
):
super().__init__()
self.unconditional = unconditional
# text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
self.clip = None
@@ -1820,12 +1819,16 @@ class Decoder(nn.Module):
self.channels = channels
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
# verify conditioning method
unets = cast_tuple(unet)
num_unets = len(unets)
self.unconditional = unconditional
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1852,8 +1855,8 @@ class Decoder(nn.Module):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first and not unconditional,
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
cond_on_image_embeds = not unconditional and is_first,
cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
channels = unet_channels,
channels_out = unet_channels_out
)
@@ -1861,6 +1864,10 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
# create noise schedulers per unet
if not exists(beta_schedule):

View File

@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample["npy"] = embedding
sample[sample_key] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
"""
for sample in samples:
try:
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
sample['emb'] = {}
if 'text_emb' in sample:
sample['emb']['text'] = sample['text_emb']
if 'img_emb' in sample:
sample['emb']['img'] = sample['img_emb']
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
else:
break
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
"""
for sample in samples:
try:
for key in required_keys:
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
key_verifier = wds.filters.pipelinefilter(verify_keys)
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
A fluid interface wrapper for DataPipline that returns image embedding pairs
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
def __init__(
self,
urls,
embedding_folder_url=None,
img_embedding_folder_url=None,
text_embedding_folder_url=None,
index_width=None,
img_preproc=None,
extra_keys=[],
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
super().__init__()
keys = ["jpg", "npy"] + extra_keys
keys = ["jpg", "emb"] + extra_keys
# if img_embedding_folder_url is not None:
# keys.append("img_emb")
# if text_embedding_folder_url is not None:
# keys.append("text_emb")
# keys.extend(extra_keys)
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if "s3:" in embedding_folder_url:
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if embedding_folder_url is not None:
if img_embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
if text_embedding_folder_url is not None:
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:
# Then we are loading embeddings for a remote source
if img_embedding_folder_url is not None:
# Then we are loading image embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
self.append(verify_keys)
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
if text_embedding_folder_url is not None:
# Then we are loading image embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
self.append(join_embeddings)
self.append(key_verifier(required_keys=keys, handler=handler))
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
img_embeddings_url=None,
text_embeddings_url=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
"""
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
img_embedding_folder_url=img_embeddings_url,
text_embedding_folder_url=text_embeddings_url,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
@@ -228,4 +262,4 @@ def create_image_embedding_dataloader(
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
pin_memory=True,
shuffle=False
)
)

View File

@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter,
XClipAdapter
)
# helper functions
@@ -158,6 +158,8 @@ class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
attn_dim_head: int = 32
@@ -170,6 +172,7 @@ class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
@@ -180,9 +183,16 @@ class DecoderConfig(BaseModel):
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
has_clip = exists(decoder_kwargs.pop('clip'))
clip = None
if has_clip:
clip = self.clip.create()
return Decoder(unets, clip=clip, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
@@ -194,8 +204,9 @@ class DecoderConfig(BaseModel):
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
@@ -268,3 +279,32 @@ class TrainDecoderConfig(BaseModel):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@root_validator
def check_has_embeddings(cls, values):
# Makes sure that enough information is provided to get the embeddings specified for training
data_config, decoder_config = values.get('data'), values.get('decoder')
if not exists(data_config) or not exists(decoder_config):
# Then something else errored and we should just pass through
return values
using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
using_clip = exists(decoder_config.clip)
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_embeddings:
# Then we need some way to get the embeddings
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_embeddings:
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
else:
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values

View File

@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
from ema_pytorch import EMA
from accelerate import Accelerator
import numpy as np
@@ -62,16 +64,6 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr
def clamp(value, min_value = None, max_value = None):
assert exists(min_value) or exists(max_value)
if exists(min_value):
value = max(value, min_value)
if exists(max_value):
value = min(value, max_value)
return value
# decorators
def cast_torch_tensor(fn):
@@ -145,146 +137,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# saving and loading functions
# for diffusion prior
def load_diffusion_model(dprior_path, device):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print_ribbon('Saving checkpoint')
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# exponential moving average wrapper
class EMA(nn.Module):
"""
Implements exponential moving average shadowing for your model.
Utilizes an inverse decay schedule to manage longer term training runs.
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
def __init__(
self,
model,
beta = 0.9999,
update_after_step = 100,
update_every = 10,
inv_gamma = 1.0,
power = 2/3,
min_value = 0.0,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_every = update_every
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data)
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
ma_buffer.data.copy_(current_buffer.data)
def get_current_decay(self):
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
if epoch <= 0:
return 0.
return clamp(value, min_value = self.min_value, max_value = self.beta)
def update(self):
step = self.step.item()
self.step += 1
if (step % self.update_every) != 0:
return
if step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted.item():
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@torch.no_grad()
def update_moving_average(self, ma_model, current_model):
current_decay = self.get_current_decay()
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
difference = ma_params.data - current_params.data
difference.mul_(1.0 - current_decay)
ma_params.sub_(difference)
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
difference = ma_buffer - current_buffer
difference.mul_(1.0 - current_decay)
ma_buffer.sub_(difference)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@@ -505,26 +357,20 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
if self.use_ema:
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
else:
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
if self.use_ema:
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
else:
return self.diffusion_prior.sample(*args, **kwargs)
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
if self.use_ema:
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
else:
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample_batch_size(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@@ -605,6 +451,8 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
@@ -730,6 +578,18 @@ class DecoderTrainer(nn.Module):
return output
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_image(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
@cast_torch_tensor
def forward(
self,

View File

@@ -1 +1 @@
__version__ = '0.11.2'
__version__ = '0.12.4'

View File

@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
from ema_pytorch import EMA
# helpers
def exists(val):
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_after_step = 500,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False

View File

@@ -28,6 +28,7 @@ setup(
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',

View File

@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from clip import tokenize
import torchvision
import torch
@@ -33,7 +34,8 @@ def exists(val):
def create_dataloaders(
available_shards,
webdataset_base_url,
embeddings_url,
img_embeddings_url=None,
text_embeddings_url=None,
shard_width=6,
num_workers=4,
batch_size=32,
@@ -63,14 +65,15 @@ def create_dataloaders(
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
embeddings_url=embeddings_url,
img_embeddings_url=img_embeddings_url,
text_embeddings_url=text_embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"] if with_text else [],
extra_keys= ["txt"],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
@@ -79,8 +82,8 @@ def create_dataloaders(
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
val_dataloader = create_dataloader(val_urls, shuffle=False)
test_dataloader = create_dataloader(test_urls, shuffle=False)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
Samples the dataloader and returns a zipped list of examples
"""
images = []
embeddings = []
img_embeddings = []
text_embeddings = []
captions = []
dataset_keys = get_dataset_keys(dataloader)
has_caption = "txt" in dataset_keys
for data in dataloader:
if has_caption:
img, emb, txt = data
for img, emb, txt in dataloader:
img_emb, text_emb = emb.get('img'), emb.get('text')
if img_emb is not None:
img_emb = img_emb.to(device=device, dtype=torch.float)
img_embeddings.extend(list(img_emb))
else:
img, emb = data
txt = [""] * emb.shape[0]
# Then we add None img.shape[0] times
img_embeddings.extend([None]*img.shape[0])
if text_emb is not None:
text_emb = text_emb.to(device=device, dtype=torch.float)
text_embeddings.extend(list(text_emb))
else:
# Then we add None img.shape[0] times
text_embeddings.extend([None]*img.shape[0])
img = img.to(device=device, dtype=torch.float)
emb = emb.to(device=device, dtype=torch.float)
images.extend(list(img))
embeddings.extend(list(emb))
captions.extend(list(txt))
if len(images) >= n:
break
return list(zip(images[:n], embeddings[:n], captions[:n]))
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, embeddings, txts = zip(*example_data)
embeddings_tensor = torch.stack(embeddings)
samples = trainer.sample(embeddings_tensor)
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
sample_params = {}
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
img_embeddings = torch.stack(img_embeddings)
sample_params["image_embed"] = img_embeddings
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
tokenized_texts = tokenize(txts, truncate=True)
sample_params["text"] = tokenized_texts
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, text_prepend=""):
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -250,6 +276,7 @@ def train(
save_latest=True,
save_best=True,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
):
"""
@@ -258,8 +285,8 @@ def train(
is_master = accelerator.process_index == 0
trainer = DecoderTrainer(
accelerator,
decoder,
decoder=decoder,
accelerator=accelerator,
**kwargs
)
@@ -268,6 +295,7 @@ def train(
validation_losses = []
next_task = 'train'
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())
@@ -306,13 +334,22 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
img, emb = send_to_device((img, emb))
samples_seen += total_samples
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))
trainer.train()
for unet in range(1, trainer.num_unets+1):
@@ -320,7 +357,20 @@ def train(
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -334,14 +384,20 @@ def train(
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
log_data = {
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec,
"Samples Seen": samples_seen,
**ema_decay_list,
**loss_map
}
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
if is_master:
tracker.log(log_data, step=step(), verbose=True)
@@ -358,7 +414,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -381,14 +437,35 @@ def train(
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
val_sample += total_samples
img, emb = send_to_device((img, emb))
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))
for unet in range(1, len(decoder.unets)+1):
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
# No need to evaluate an unchanging unet
continue
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -415,7 +492,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
next_task = 'sample'
@@ -426,8 +503,8 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -517,14 +594,37 @@ def initialize_training(config, config_path):
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
has_clip_model = config.decoder.clip is not None
data_source_string = ""
if has_img_embeddings:
data_source_string += "precomputed image embeddings"
elif has_clip_model:
data_source_string += "clip image embeddings generation"
else:
raise ValueError("No image embeddings source specified")
if conditioning_on_text:
if has_text_embeddings:
data_source_string += " and precomputed text embeddings"
elif has_clip_model:
data_source_string += " and clip text encoding generation"
else:
raise ValueError("No text embeddings source specified")
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.dict(),
)