Compare commits

...

13 Commits

Author          SHA1        Date                         Message
Phil Wang       46a2558d53  2022-06-29 07:17:35 -07:00   bug in pydantic decoder config class
yytdfc          86109646e3  2022-06-29 07:16:44 -07:00   fix a bug of name error (#179)
Phil Wang       6a11b9678b  2022-06-26 21:59:55 -07:00   bring in the skip connection scaling factor, used by imagen in their unets, cite original paper using it
Phil Wang       b90364695d  2022-06-26 21:07:42 -07:00   fix remaining issues with deriving cond_on_text_encodings from child unet settings
zion            868c001199  2022-06-26 16:12:32 -07:00   bug fixes for text conditioning update (#175)
Phil Wang       032e83b0e0  2022-06-26 12:45:05 -07:00   nevermind, do not enforce text encodings on first unet
Phil Wang       2e85e736f3  2022-06-26 12:32:17 -07:00   remove unnecessary decoder setting, and if not unconditional, always make sure the first unet is condition-able on text
Aidan Dempster  f5760bdb92  2022-06-25 19:05:20 -07:00   Add data flexibility to decoder trainer (#165)
                                                         * Added the ability to train decoder with text embeddings
                                                         * Added the ability to train using on the fly generated embeddings with clip
                                                         * Clip now generates embeddings for whatever is not precomputed
zion            c453f468b1  2022-06-25 16:37:06 -07:00   autoswitch tqdm for notebooks (#171)
                                                         avoids printing the `tqdm` progress bar to a newline in notebooks when detected
zion            98f0c17759  2022-06-24 15:12:09 -07:00   add sampels-seen and ema decay (#166)
Phil Wang       a5b9fd6ca8  2022-06-24 08:15:05 -07:00   product management
Phil Wang       4b994601ae  2022-06-23 11:29:28 -07:00   just make sure decoder learning rate is reasonable and help out budding researchers
zion            fddf66e91e  2022-06-22 14:45:01 -07:00   fix params in decoder (#162)
7 changed files with 287 additions and 82 deletions

View File

@@ -368,7 +368,8 @@ unet1 = Unet(
     image_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
 ).cuda()
 
 unet2 = Unet(
@@ -385,8 +386,7 @@ decoder = Decoder(
     clip = clip,
     timesteps = 100,
     image_cond_drop_prob = 0.1,
-    text_cond_drop_prob = 0.5,
-    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 for unet_number in (1, 2):
@@ -1072,7 +1072,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
-- [ ] build infilling
+- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
 
 ## Citations
@@ -1189,4 +1189,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Saharia2021PaletteID,
+    title   = {Palette: Image-to-Image Diffusion Models},
+    author  = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2111.05826}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
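The README change above moves text conditioning off the Decoder and onto each Unet. A minimal sketch of the resulting usage, with illustrative dimensions and OpenAIClipAdapter standing in for whatever clip model is actually used:

```python
# Sketch of the updated API: cond_on_text_encodings is now set per Unet,
# and the Decoder derives its text conditioning from its unets instead of
# taking its own condition_on_text_encodings flag.
# Dimensions and the clip adapter below are illustrative, not from the diff.

from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

clip = OpenAIClipAdapter()

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    cond_on_text_encodings = True  # set per unet now
)

decoder = Decoder(
    unet = (unet1,),
    image_sizes = (128,),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
    # no condition_on_text_encodings here anymore; it is read off the unets
)
```

The Decoder deriving `condition_on_text_encodings` from its unets is exactly what the dalle2_pytorch.py changes in the next file implement.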

View File

@@ -1,6 +1,6 @@
 import math
 import random
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
@@ -1359,6 +1359,7 @@ class Unet(nn.Module):
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
+        scale_skip_connection = False,
         **kwargs
     ):
         super().__init__()
@@ -1440,6 +1441,10 @@ class Unet(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
 
+        # whether to scale skip connection, adopted in Imagen
+
+        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
+
         # attention related params
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
@@ -1687,7 +1692,9 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c, t)
 
         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
-            x = torch.cat((x, hiddens.pop()), dim = 1)
+            skip_connect = hiddens.pop() * self.skip_connect_scale
+            x = torch.cat((x, skip_connect), dim = 1)
+
             x = init_block(x, c, t)
             x = sparse_attn(x)
@@ -1766,14 +1773,13 @@ class Decoder(nn.Module):
         lowres_downsample_first = True,      # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         blur_sigma = (0.1, 0.2),             # cascading ddpm - blur sigma
         blur_kernel_size = 3,                # cascading ddpm - blur kernel size
-        condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
         clip_x_start = True,
         clip_adapter_overrides = dict(),
         learned_variance = True,
         learned_variance_constrain_frac = False,
         vb_loss_weight = 0.001,
-        unconditional = False,
+        unconditional = False,               # set to True for generating images without conditioning
         auto_normalize_img = True,           # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
         use_dynamic_thres = False,           # from the Imagen paper
         dynamic_thres_percentile = 0.9,
@@ -1782,13 +1788,6 @@ class Decoder(nn.Module):
     ):
         super().__init__()
 
-        self.unconditional = unconditional
-
-        # text conditioning
-
-        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        self.condition_on_text_encodings = condition_on_text_encodings
-
         # clip
 
         self.clip = None
@@ -1820,12 +1819,16 @@ class Decoder(nn.Module):
         self.channels = channels
 
-        # automatically take care of ensuring that first unet is unconditional
-        # while the rest of the unets are conditioned on the low resolution image produced by previous unet
+        # verify conditioning method
 
         unets = cast_tuple(unet)
         num_unets = len(unets)
 
+        self.unconditional = unconditional
+
+        # automatically take care of ensuring that first unet is unconditional
+        # while the rest of the unets are conditioned on the low resolution image produced by previous unet
+
         vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1852,8 +1855,8 @@ class Decoder(nn.Module):
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
-                cond_on_image_embeds = is_first and not unconditional,
-                cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
+                cond_on_image_embeds = not unconditional and is_first,
+                cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
                 channels = unet_channels,
                 channels_out = unet_channels_out
             )
@@ -1861,6 +1864,10 @@ class Decoder(nn.Module):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())
 
+        # determine from unets whether conditioning on text encoding is needed
+
+        self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
+
         # create noise schedulers per unet
 
         if not exists(beta_schedule):
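The skip connection change above is small enough to illustrate standalone. A toy sketch (not repo code) of what `scale_skip_connection = True` does in the upsampling path: each skip connection popped off the hiddens stack is multiplied by 2 ** -0.5 before concatenation, following Imagen, to keep the variance of the combined signal roughly constant.

```python
import torch

scale_skip_connection = True
skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

x = torch.randn(1, 64, 32, 32)           # upsampled decoder features
hiddens = [torch.randn(1, 64, 32, 32)]   # stored encoder features

skip_connect = hiddens.pop() * skip_connect_scale  # rescale before concat
x = torch.cat((x, skip_connect), dim = 1)          # channels 64 + 64 -> 128
print(x.shape)  # torch.Size([1, 128, 32, 32])
```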

View File

@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
""" """
return fs.glob(os.path.join(path, f"*.{file_format}"))[0] return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception): def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields""" """Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None previous_tar_url = None
current_embeddings = None current_embeddings = None
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
             # We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
             if torch.count_nonzero(embedding) == 0:
                 raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
-            sample["npy"] = embedding
+            sample[sample_key] = embedding
             yield sample
         except Exception as exn:  # From wds implementation
             if handler(exn):
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
                 continue
             else:
                 break
 
 skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
 
-def verify_keys(samples, handler=wds.handlers.reraise_exception):
+def join_embeddings(samples, handler=wds.handlers.reraise_exception):
     """
-    Requires that both the image and embedding are present in the sample
-    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
+    Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
+    either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
     """
     for sample in samples:
         try:
-            assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
-            assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
+            sample['emb'] = {}
+            if 'text_emb' in sample:
+                sample['emb']['text'] = sample['text_emb']
+            if 'img_emb' in sample:
+                sample['emb']['img'] = sample['img_emb']
             yield sample
         except Exception as exn:  # From wds implementation
             if handler(exn):
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
             else:
                 break
 
+def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
+    """
+    Requires that both the image and embedding are present in the sample
+    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
+    """
+    for sample in samples:
+        try:
+            for key in required_keys:
+                assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
+            yield sample
+        except Exception as exn:  # From wds implementation
+            if handler(exn):
+                continue
+            else:
+                break
+key_verifier = wds.filters.pipelinefilter(verify_keys)
+
 class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
     """
     A fluid interface wrapper for DataPipline that returns image embedding pairs
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
     def __init__(
         self,
         urls,
-        embedding_folder_url=None,
+        img_embedding_folder_url=None,
+        text_embedding_folder_url=None,
         index_width=None,
         img_preproc=None,
         extra_keys=[],
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
""" """
super().__init__() super().__init__()
keys = ["jpg", "npy"] + extra_keys keys = ["jpg", "emb"] + extra_keys
# if img_embedding_folder_url is not None:
# keys.append("img_emb")
# if text_embedding_folder_url is not None:
# keys.append("text_emb")
# keys.extend(extra_keys)
self.key_map = {key: i for i, key in enumerate(keys)} self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample self.resampling = resample
self.img_preproc = img_preproc self.img_preproc = img_preproc
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
             # Then this has an s3 link for the webdataset and we need extra packages
             if shutil.which("s3cmd") is None:
                 raise RuntimeError("s3cmd is required for s3 webdataset")
-        if "s3:" in embedding_folder_url:
+        if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
             # Then the embeddings are being loaded from s3 and fsspec requires s3fs
             try:
                 import s3fs
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
         if shuffle_shards:
             self.append(wds.filters.shuffle(1000))
 
-        if embedding_folder_url is not None:
+        if img_embedding_folder_url is not None:
             # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
-            self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
+            self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
+        if text_embedding_folder_url is not None:
+            self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
 
         self.append(wds.tarfile_to_samples(handler=handler))
         self.append(wds.decode("pilrgb", handler=handler))
-        if embedding_folder_url is not None:
-            # Then we are loading embeddings for a remote source
+        if img_embedding_folder_url is not None:
+            # Then we are loading image embeddings for a remote source
             assert index_width is not None, "Reading embeddings separately requires index width length to be given"
-            self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
-        self.append(verify_keys)
+            self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
+        if text_embedding_folder_url is not None:
+            # Then we are loading image embeddings for a remote source
+            assert index_width is not None, "Reading embeddings separately requires index width length to be given"
+            self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
+        self.append(join_embeddings)
+        self.append(key_verifier(required_keys=keys, handler=handler))
         # Apply preprocessing
         self.append(wds.map(self.preproc))
         self.append(wds.to_tuple(*keys))
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
     tar_url,
     num_workers,
     batch_size,
-    embeddings_url=None,
+    img_embeddings_url=None,
+    text_embeddings_url=None,
     index_width=None,
     shuffle_num = None,
     shuffle_shards = True,
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
""" """
ds = ImageEmbeddingDataset( ds = ImageEmbeddingDataset(
tar_url, tar_url,
embeddings_url, img_embedding_folder_url=img_embeddings_url,
text_embedding_folder_url=text_embeddings_url,
index_width=index_width, index_width=index_width,
shuffle_shards=shuffle_shards, shuffle_shards=shuffle_shards,
resample=resample_shards, resample=resample_shards,
@@ -228,4 +262,4 @@ def create_image_embedding_dataloader(
         prefetch_factor=2,  # This might be good to have high so the next npy file is prefetched
         pin_memory=True,
         shuffle=False
     )
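A hypothetical usage sketch of the reworked dataloader; the paths are placeholders and the import path is assumed. Image and text embeddings can now come from separate folders, and each sample carries an `emb` dict in place of the old single `npy` entry:

```python
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader

dataloader = create_image_embedding_dataloader(
    tar_url = "/data/webdataset/{00000..00099}.tar",  # placeholder
    num_workers = 4,
    batch_size = 32,
    img_embeddings_url = "/data/img_embeddings",      # placeholder
    text_embeddings_url = "/data/text_embeddings",    # placeholder
    index_width = 4,
    extra_keys = ["txt"]  # assumes captions are stored in the shards
)

for img, emb, txt in dataloader:
    img_emb = emb.get('img')    # present when an image embedding source exists
    text_emb = emb.get('text')  # present when a text embedding source exists
    break
```

Either url may be omitted; `join_embeddings` only adds the keys that exist, and `key_verifier` then asserts the pipeline produced everything the requested keys require.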

View File

@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
     Decoder,
     DiffusionPrior,
     DiffusionPriorNetwork,
-    XClipAdapter,
+    XClipAdapter
 )
 
 # helper functions
@@ -158,6 +158,8 @@ class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple(int)
     image_embed_dim: int = None
+    text_embed_dim: int = None
+    cond_on_text_encodings: bool = None
     cond_dim: int = None
     channels: int = 3
     attn_dim_head: int = 32
@@ -170,6 +172,7 @@ class DecoderConfig(BaseModel):
     unets: ListOrTuple(UnetConfig)
     image_size: int = None
     image_sizes: ListOrTuple(int) = None
+    clip: Optional[AdapterConfig]   # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
@@ -180,9 +183,16 @@ class DecoderConfig(BaseModel):
     def create(self):
         decoder_kwargs = self.dict()
 
         unet_configs = decoder_kwargs.pop('unets')
         unets = [Unet(**config) for config in unet_configs]
-        return Decoder(unets, **decoder_kwargs)
+
+        has_clip = exists(decoder_kwargs.pop('clip'))
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        return Decoder(unets, clip=clip, **decoder_kwargs)
 
     @validator('image_sizes')
     def check_image_sizes(cls, image_sizes, values):
@@ -194,8 +204,9 @@ class DecoderConfig(BaseModel):
extra = "allow" extra = "allow"
class DecoderDataConfig(BaseModel): class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4 num_workers: int = 4
batch_size: int = 64 batch_size: int = 64
start_shard: int = 0 start_shard: int = 0
@@ -268,3 +279,32 @@ class TrainDecoderConfig(BaseModel):
         with open(json_path) as f:
             config = json.load(f)
         return cls(**config)
+
+    @root_validator
+    def check_has_embeddings(cls, values):
+        # Makes sure that enough information is provided to get the embeddings specified for training
+        data_config, decoder_config = values.get('data'), values.get('decoder')
+        if not exists(data_config) or not exists(decoder_config):
+            # Then something else errored and we should just pass through
+            return values
+        using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
+        using_clip = exists(decoder_config.clip)
+        img_emb_url = data_config.img_embeddings_url
+        text_emb_url = data_config.text_embeddings_url
+
+        if using_text_embeddings:
+            # Then we need some way to get the embeddings
+            assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
+        if using_clip:
+            if using_text_embeddings:
+                assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
+            else:
+                assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
+        if text_emb_url:
+            assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
+        return values
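The new `check_has_embeddings` validator constrains which combinations of clip model and embedding urls a config may declare. A hypothetical config fragment (placeholders throughout) that it accepts: the unets request text conditioning and both embedding types are precomputed, so no clip entry is needed.

```python
# Hypothetical fragment of a train_decoder config, expressed as a Python dict.
# All urls are placeholders.
config_fragment = {
    "decoder": {
        "unets": [
            {"dim": 128, "dim_mults": [1, 2, 4, 8], "cond_on_text_encodings": True}
        ]
        # no "clip" entry: all embeddings are precomputed
    },
    "data": {
        "webdataset_base_url": "/data/webdataset/{}.tar",  # placeholder
        "img_embeddings_url": "/data/img_emb",             # placeholder
        "text_embeddings_url": "/data/text_emb"            # placeholder
    }
}
```

Supplying a clip model alongside both urls would trip the redundancy assert, and dropping `text_embeddings_url` while still conditioning on text would require a clip entry instead.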

View File

@@ -451,6 +451,8 @@ class DecoderTrainer(nn.Module):
         lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
 
+        assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
+
         optimizers = []
 
         for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
@@ -576,6 +578,18 @@ class DecoderTrainer(nn.Module):
         return output
 
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_image(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
+
     @cast_torch_tensor
     def forward(
         self,
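The new `embed_text` / `embed_image` helpers expose the decoder's clip model through the trainer, which is how train_decoder.py fills in embeddings that were not precomputed. A rough sketch, assuming DecoderTrainer is importable from the package top level, that it can be built without an explicit accelerator, and that the decoder holds a clip adapter; all hyperparameters are illustrative:

```python
import torch
from dalle2_pytorch import Unet, Decoder, DecoderTrainer, OpenAIClipAdapter

clip = OpenAIClipAdapter()

unet = Unet(dim = 16, image_embed_dim = 512, cond_dim = 16, channels = 3, dim_mults = (1, 2))
decoder = Decoder(unet = unet, clip = clip, image_sizes = (64,), timesteps = 10)

trainer = DecoderTrainer(decoder, lr = 1e-4)  # lr of 1e-3 or above now trips the new assert

images = torch.randn(2, 3, 64, 64)             # placeholder batch
image_embed, *_ = trainer.embed_image(images)  # clip image embeddings, as used in train_decoder.py
loss = trainer.forward(images, image_embed = image_embed, unet_number = 1)
```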

View File

@@ -1 +1 @@
-__version__ = '0.11.4'
+__version__ = '0.12.4'

View File

@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer, print_ribbon
 from dalle2_pytorch.dalle2_pytorch import resize_image_to
+from clip import tokenize
 import torchvision
 import torch
@@ -33,7 +34,8 @@ def exists(val):
 def create_dataloaders(
     available_shards,
     webdataset_base_url,
-    embeddings_url,
+    img_embeddings_url=None,
+    text_embeddings_url=None,
     shard_width=6,
     num_workers=4,
     batch_size=32,
@@ -63,14 +65,15 @@ def create_dataloaders(
     test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
     val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
 
-    create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
+    create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
         tar_url=tar_urls,
         num_workers=num_workers,
         batch_size=batch_size if not for_sampling else n_sample_images,
-        embeddings_url=embeddings_url,
+        img_embeddings_url=img_embeddings_url,
+        text_embeddings_url=text_embeddings_url,
         index_width=index_width,
         shuffle_num = None,
-        extra_keys= ["txt"] if with_text else [],
+        extra_keys= ["txt"],
         shuffle_shards = shuffle,
         resample_shards = resample,
         img_preproc=img_preproc,
@@ -79,8 +82,8 @@ def create_dataloaders(
     train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
     train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
-    val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
-    test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
+    val_dataloader = create_dataloader(val_urls, shuffle=False)
+    test_dataloader = create_dataloader(test_urls, shuffle=False)
     test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
     return {
         "train": train_dataloader,
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
     Samples the dataloader and returns a zipped list of examples
     """
     images = []
-    embeddings = []
+    img_embeddings = []
+    text_embeddings = []
     captions = []
-    dataset_keys = get_dataset_keys(dataloader)
-    has_caption = "txt" in dataset_keys
-    for data in dataloader:
-        if has_caption:
-            img, emb, txt = data
+    for img, emb, txt in dataloader:
+        img_emb, text_emb = emb.get('img'), emb.get('text')
+        if img_emb is not None:
+            img_emb = img_emb.to(device=device, dtype=torch.float)
+            img_embeddings.extend(list(img_emb))
         else:
-            img, emb = data
-            txt = [""] * emb.shape[0]
+            # Then we add None img.shape[0] times
+            img_embeddings.extend([None]*img.shape[0])
+        if text_emb is not None:
+            text_emb = text_emb.to(device=device, dtype=torch.float)
+            text_embeddings.extend(list(text_emb))
+        else:
+            # Then we add None img.shape[0] times
+            text_embeddings.extend([None]*img.shape[0])
         img = img.to(device=device, dtype=torch.float)
-        emb = emb.to(device=device, dtype=torch.float)
         images.extend(list(img))
-        embeddings.extend(list(emb))
         captions.extend(list(txt))
         if len(images) >= n:
             break
-    return list(zip(images[:n], embeddings[:n], captions[:n]))
+    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, text_prepend=""):
+def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
     """
-    real_images, embeddings, txts = zip(*example_data)
-    embeddings_tensor = torch.stack(embeddings)
-    samples = trainer.sample(embeddings_tensor)
+    real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
+    sample_params = {}
+    if img_embeddings[0] is None:
+        # Generate image embeddings from clip
+        imgs_tensor = torch.stack(real_images)
+        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        sample_params["image_embed"] = img_embeddings
+    else:
+        # Then we are using precomputed image embeddings
+        img_embeddings = torch.stack(img_embeddings)
+        sample_params["image_embed"] = img_embeddings
+    if condition_on_text_encodings:
+        if text_embeddings[0] is None:
+            # Generate text embeddings from text
+            tokenized_texts = tokenize(txts, truncate=True)
+            sample_params["text"] = tokenized_texts
+        else:
+            # Then we are using precomputed text embeddings
+            text_embeddings = torch.stack(text_embeddings)
+            sample_params["text_encodings"] = text_embeddings
+    samples = trainer.sample(**sample_params)
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     return real_images, generated_images, captions
 
-def generate_grid_samples(trainer, examples, text_prepend=""):
+def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
     real_image_size = real_images[0].shape[-1]
     generated_image_size = generated_images[0].shape[-1]
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples)
+    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -250,6 +276,7 @@ def train(
     save_latest=True,
     save_best=True,
     unet_training_mask=None,
+    condition_on_text_encodings=False,
     **kwargs
 ):
     """
@@ -258,8 +285,8 @@ def train(
     is_master = accelerator.process_index == 0
 
     trainer = DecoderTrainer(
-        accelerator,
-        decoder,
+        decoder=decoder,
+        accelerator=accelerator,
         **kwargs
     )
@@ -268,6 +295,7 @@ def train(
     validation_losses = []
     next_task = 'train'
     sample = 0
+    samples_seen = 0
     val_sample = 0
     step = lambda: int(trainer.step.item())
@@ -306,13 +334,22 @@ def train(
         last_snapshot = sample
 
         if next_task == 'train':
-            for i, (img, emb) in enumerate(dataloaders["train"]):
+            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
                 total_samples = all_samples.sum().item()
                 sample += total_samples
-                img, emb = send_to_device((img, emb))
+                samples_seen += total_samples
+                img_emb = emb.get('img')
+                has_img_embedding = img_emb is not None
+                if has_img_embedding:
+                    img_emb, = send_to_device((img_emb,))
+                text_emb = emb.get('text')
+                has_text_embedding = text_emb is not None
+                if has_text_embedding:
+                    text_emb, = send_to_device((text_emb,))
+                img, = send_to_device((img,))
 
                 trainer.train()
                 for unet in range(1, trainer.num_unets+1):
@@ -320,7 +357,20 @@ def train(
                     if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                         continue
-                    loss = trainer.forward(img, image_embed=emb, unet_number=unet)
+                    forward_params = {}
+                    if has_img_embedding:
+                        forward_params['image_embed'] = img_emb
+                    else:
+                        # Forward pass automatically generates embedding
+                        pass
+                    if condition_on_text_encodings:
+                        if has_text_embedding:
+                            forward_params['text_encodings'] = text_emb
+                        else:
+                            # Then we need to pass the text instead
+                            tokenized_texts = tokenize(txt, truncate=True)
+                            forward_params['text'] = tokenized_texts
+                    loss = trainer.forward(img, **forward_params, unet_number=unet)
                     trainer.update(unet_number=unet)
                     unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -334,14 +384,20 @@ def train(
                     mask = unet_all_losses != 0
                     unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
                     loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
+
+                    # gather decay rate on each UNet
+                    ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
+
                     log_data = {
                         "Epoch": epoch,
                         "Sample": sample,
                         "Step": i,
                         "Samples per second": samples_per_sec,
+                        "Samples Seen": samples_seen,
+                        **ema_decay_list,
                         **loss_map
                     }
+                    # print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
 
                     if is_master:
                         tracker.log(log_data, step=step(), verbose=True)
@@ -358,7 +414,7 @@ def train(
                         save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
                         if exists(n_sample_images) and n_sample_images > 0:
                             trainer.eval()
-                            train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+                            train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
                             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
                 if epoch_samples is not None and sample >= epoch_samples:
@@ -381,14 +437,35 @@ def train(
                 all_samples = accelerator.gather(val_sample_length_tensor)
                 total_samples = all_samples.sum().item()
                 val_sample += total_samples
-                img, emb = send_to_device((img, emb))
+                img_emb = emb.get('img')
+                has_img_embedding = img_emb is not None
+                if has_img_embedding:
+                    img_emb, = send_to_device((img_emb,))
+                text_emb = emb.get('text')
+                has_text_embedding = text_emb is not None
+                if has_text_embedding:
+                    text_emb, = send_to_device((text_emb,))
+                img, = send_to_device((img,))
 
                 for unet in range(1, len(decoder.unets)+1):
                     if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                         # No need to evaluate an unchanging unet
                         continue
-                    loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
+                    forward_params = {}
+                    if has_img_embedding:
+                        forward_params['image_embed'] = img_emb.float()
+                    else:
+                        # Forward pass automatically generates embedding
+                        pass
+                    if condition_on_text_encodings:
+                        if has_text_embedding:
+                            forward_params['text_encodings'] = text_emb.float()
+                        else:
+                            # Then we need to pass the text instead
+                            tokenized_texts = tokenize(txt, truncate=True)
+                            forward_params['text'] = tokenized_texts
+                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
                     average_val_loss_tensor[0, unet-1] += loss
 
                     if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -415,7 +492,7 @@ def train(
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
                 if is_master:
                     tracker.log(evaluation, step=step(), verbose=True)
             next_task = 'sample'
@@ -426,8 +503,8 @@ def train(
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -517,14 +594,37 @@ def initialize_training(config, config_path):
     # Create and initialize the tracker if we are the master
     tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
 
+    has_img_embeddings = config.data.img_embeddings_url is not None
+    has_text_embeddings = config.data.text_embeddings_url is not None
+    conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
+    has_clip_model = config.decoder.clip is not None
+    data_source_string = ""
+
+    if has_img_embeddings:
+        data_source_string += "precomputed image embeddings"
+    elif has_clip_model:
+        data_source_string += "clip image embeddings generation"
+    else:
+        raise ValueError("No image embeddings source specified")
+    if conditioning_on_text:
+        if has_text_embeddings:
+            data_source_string += " and precomputed text embeddings"
+        elif has_clip_model:
+            data_source_string += " and clip text encoding generation"
+        else:
+            raise ValueError("No text embeddings source specified")
+
     accelerator.print(print_ribbon("Loaded Config", repeat=40))
     accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
+    accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
     accelerator.print(f"Number of parameters: {num_parameters}")
     train(dataloaders, decoder, accelerator,
         tracker=tracker,
         inference_device=accelerator.device,
         load_config=config.load,
         evaluate_config=config.evaluate,
+        condition_on_text_encodings=conditioning_on_text,
         **config.train.dict(),
     )
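Finally, a small illustration of the tokenize fallback wired through train_decoder.py above: when text conditioning is requested but no precomputed text embeddings exist, the raw captions are tokenized with OpenAI CLIP's tokenizer and passed to the trainer as `text` rather than `text_encodings`.

```python
from clip import tokenize

captions = ["a photo of a dog", "a photo of a cat"]  # placeholder captions
tokenized_texts = tokenize(captions, truncate=True)  # LongTensor of shape [2, 77]
print(tokenized_texts.shape)
```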