mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
46a2558d53 | ||
|
|
86109646e3 | ||
|
|
6a11b9678b | ||
|
|
b90364695d | ||
|
|
868c001199 | ||
|
|
032e83b0e0 | ||
|
|
2e85e736f3 | ||
|
|
f5760bdb92 | ||
|
|
c453f468b1 | ||
|
|
98f0c17759 | ||
|
|
a5b9fd6ca8 | ||
|
|
4b994601ae | ||
|
|
fddf66e91e | ||
|
|
c8422ffd5d | ||
|
|
2aadc23c7c | ||
|
|
c098f57e09 | ||
|
|
0021535c26 | ||
|
|
56883910fb | ||
|
|
893f270012 |
58
README.md
58
README.md
@@ -368,7 +368,8 @@ unet1 = Unet(
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
@@ -385,8 +386,7 @@ decoder = Decoder(
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
@@ -1017,33 +1017,6 @@ The most significant parameters for the script are as follows:
|
||||
|
||||
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
|
||||
|
||||
#### Loading and Saving the DiffusionPrior model
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
```python
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
```
|
||||
|
||||
##### Loading
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
dprior_path : path to saved model(.pth)
|
||||
device : the cuda device you're running on
|
||||
|
||||
##### Saving
|
||||
|
||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||
save_path : path to save at
|
||||
model : object of Diffusion_Prior
|
||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
scaler : a GradScaler object.
|
||||
e.g: scaler = GradScaler(enabled=amp)
|
||||
config : config object created in train_diffusion_prior.py - see file for example.
|
||||
image_embed_dim - the dimension of the image_embedding
|
||||
e.g: 768
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
@@ -1092,19 +1065,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
|
||||
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
|
||||
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1221,4 +1189,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Saharia2021PaletteID,
|
||||
title = {Palette: Image-to-Image Diffusion Models},
|
||||
author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2111.05826}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"channels": 3,
|
||||
"timesteps": 1000,
|
||||
"loss_type": "l2",
|
||||
"beta_schedule": "cosine",
|
||||
"beta_schedule": ["cosine"],
|
||||
"learned_variance": true
|
||||
},
|
||||
"data": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import math
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
from tqdm.auto import tqdm
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
@@ -1359,6 +1359,7 @@ class Unet(nn.Module):
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
scale_skip_connection = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1440,6 +1441,10 @@ class Unet(nn.Module):
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# whether to scale skip connection, adopted in Imagen
|
||||
|
||||
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
|
||||
|
||||
# attention related params
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
@@ -1687,7 +1692,9 @@ class Unet(nn.Module):
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
skip_connect = hiddens.pop() * self.skip_connect_scale
|
||||
|
||||
x = torch.cat((x, skip_connect), dim = 1)
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
|
||||
@@ -1766,14 +1773,13 @@ class Decoder(nn.Module):
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
learned_variance_constrain_frac = False,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False,
|
||||
unconditional = False, # set to True for generating images without conditioning
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
@@ -1782,13 +1788,6 @@ class Decoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
@@ -1820,12 +1819,16 @@ class Decoder(nn.Module):
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
# verify conditioning method
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
@@ -1852,8 +1855,8 @@ class Decoder(nn.Module):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
cond_on_image_embeds = not unconditional and is_first,
|
||||
cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
@@ -1861,6 +1864,10 @@ class Decoder(nn.Module):
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# determine from unets whether conditioning on text encoding is needed
|
||||
|
||||
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
# create noise schedulers per unet
|
||||
|
||||
if not exists(beta_schedule):
|
||||
|
||||
@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
|
||||
"""
|
||||
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
||||
|
||||
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
|
||||
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
|
||||
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
||||
previous_tar_url = None
|
||||
current_embeddings = None
|
||||
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
|
||||
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
||||
if torch.count_nonzero(embedding) == 0:
|
||||
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
||||
sample["npy"] = embedding
|
||||
sample[sample_key] = embedding
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
||||
|
||||
def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
|
||||
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
|
||||
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
|
||||
sample['emb'] = {}
|
||||
if 'text_emb' in sample:
|
||||
sample['emb']['text'] = sample['text_emb']
|
||||
if 'img_emb' in sample:
|
||||
sample['emb']['img'] = sample['img_emb']
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
else:
|
||||
break
|
||||
|
||||
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
for key in required_keys:
|
||||
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
key_verifier = wds.filters.pipelinefilter(verify_keys)
|
||||
|
||||
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
"""
|
||||
A fluid interface wrapper for DataPipline that returns image embedding pairs
|
||||
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
embedding_folder_url=None,
|
||||
img_embedding_folder_url=None,
|
||||
text_embedding_folder_url=None,
|
||||
index_width=None,
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
keys = ["jpg", "npy"] + extra_keys
|
||||
keys = ["jpg", "emb"] + extra_keys
|
||||
# if img_embedding_folder_url is not None:
|
||||
# keys.append("img_emb")
|
||||
# if text_embedding_folder_url is not None:
|
||||
# keys.append("text_emb")
|
||||
# keys.extend(extra_keys)
|
||||
self.key_map = {key: i for i, key in enumerate(keys)}
|
||||
self.resampling = resample
|
||||
self.img_preproc = img_preproc
|
||||
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
# Then this has an s3 link for the webdataset and we need extra packages
|
||||
if shutil.which("s3cmd") is None:
|
||||
raise RuntimeError("s3cmd is required for s3 webdataset")
|
||||
if "s3:" in embedding_folder_url:
|
||||
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
|
||||
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
||||
try:
|
||||
import s3fs
|
||||
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
if shuffle_shards:
|
||||
self.append(wds.filters.shuffle(1000))
|
||||
|
||||
if embedding_folder_url is not None:
|
||||
if img_embedding_folder_url is not None:
|
||||
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
||||
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
|
||||
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
|
||||
if text_embedding_folder_url is not None:
|
||||
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
|
||||
|
||||
self.append(wds.tarfile_to_samples(handler=handler))
|
||||
self.append(wds.decode("pilrgb", handler=handler))
|
||||
if embedding_folder_url is not None:
|
||||
# Then we are loading embeddings for a remote source
|
||||
if img_embedding_folder_url is not None:
|
||||
# Then we are loading image embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
|
||||
self.append(verify_keys)
|
||||
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
|
||||
if text_embedding_folder_url is not None:
|
||||
# Then we are loading image embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
|
||||
self.append(join_embeddings)
|
||||
self.append(key_verifier(required_keys=keys, handler=handler))
|
||||
# Apply preprocessing
|
||||
self.append(wds.map(self.preproc))
|
||||
self.append(wds.to_tuple(*keys))
|
||||
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
|
||||
tar_url,
|
||||
num_workers,
|
||||
batch_size,
|
||||
embeddings_url=None,
|
||||
img_embeddings_url=None,
|
||||
text_embeddings_url=None,
|
||||
index_width=None,
|
||||
shuffle_num = None,
|
||||
shuffle_shards = True,
|
||||
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
|
||||
"""
|
||||
ds = ImageEmbeddingDataset(
|
||||
tar_url,
|
||||
embeddings_url,
|
||||
img_embedding_folder_url=img_embeddings_url,
|
||||
text_embedding_folder_url=text_embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_shards=shuffle_shards,
|
||||
resample=resample_shards,
|
||||
@@ -228,4 +262,4 @@ def create_image_embedding_dataloader(
|
||||
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
|
||||
pin_memory=True,
|
||||
shuffle=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
|
||||
Decoder,
|
||||
DiffusionPrior,
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter,
|
||||
XClipAdapter
|
||||
)
|
||||
|
||||
# helper functions
|
||||
@@ -158,6 +158,8 @@ class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: ListOrTuple(int)
|
||||
image_embed_dim: int = None
|
||||
text_embed_dim: int = None
|
||||
cond_on_text_encodings: bool = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
attn_dim_head: int = 32
|
||||
@@ -170,6 +172,7 @@ class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
image_size: int = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
loss_type: str = 'l2'
|
||||
@@ -180,9 +183,16 @@ class DecoderConfig(BaseModel):
|
||||
|
||||
def create(self):
|
||||
decoder_kwargs = self.dict()
|
||||
|
||||
unet_configs = decoder_kwargs.pop('unets')
|
||||
unets = [Unet(**config) for config in unet_configs]
|
||||
return Decoder(unets, **decoder_kwargs)
|
||||
|
||||
has_clip = exists(decoder_kwargs.pop('clip'))
|
||||
clip = None
|
||||
if has_clip:
|
||||
clip = self.clip.create()
|
||||
|
||||
return Decoder(unets, clip=clip, **decoder_kwargs)
|
||||
|
||||
@validator('image_sizes')
|
||||
def check_image_sizes(cls, image_sizes, values):
|
||||
@@ -194,8 +204,9 @@ class DecoderConfig(BaseModel):
|
||||
extra = "allow"
|
||||
|
||||
class DecoderDataConfig(BaseModel):
|
||||
webdataset_base_url: str # path to a webdataset with jpg images
|
||||
embeddings_url: str # path to .npy files with embeddings
|
||||
webdataset_base_url: str # path to a webdataset with jpg images
|
||||
img_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||
text_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||
num_workers: int = 4
|
||||
batch_size: int = 64
|
||||
start_shard: int = 0
|
||||
@@ -268,3 +279,32 @@ class TrainDecoderConfig(BaseModel):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
@root_validator
|
||||
def check_has_embeddings(cls, values):
|
||||
# Makes sure that enough information is provided to get the embeddings specified for training
|
||||
data_config, decoder_config = values.get('data'), values.get('decoder')
|
||||
|
||||
if not exists(data_config) or not exists(decoder_config):
|
||||
# Then something else errored and we should just pass through
|
||||
return values
|
||||
|
||||
using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
|
||||
using_clip = exists(decoder_config.clip)
|
||||
img_emb_url = data_config.img_embeddings_url
|
||||
text_emb_url = data_config.text_embeddings_url
|
||||
|
||||
if using_text_embeddings:
|
||||
# Then we need some way to get the embeddings
|
||||
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
|
||||
|
||||
if using_clip:
|
||||
if using_text_embeddings:
|
||||
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
|
||||
else:
|
||||
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
|
||||
if text_emb_url:
|
||||
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
|
||||
return values
|
||||
|
||||
@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
import numpy as np
|
||||
@@ -62,16 +64,6 @@ def num_to_groups(num, divisor):
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
def clamp(value, min_value = None, max_value = None):
|
||||
assert exists(min_value) or exists(max_value)
|
||||
if exists(min_value):
|
||||
value = max(value, min_value)
|
||||
|
||||
if exists(max_value):
|
||||
value = min(value, max_value)
|
||||
|
||||
return value
|
||||
|
||||
# decorators
|
||||
|
||||
def cast_torch_tensor(fn):
|
||||
@@ -145,146 +137,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
"""
|
||||
Implements exponential moving average shadowing for your model.
|
||||
|
||||
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||
good values for models you plan to train for a million or more steps (reaches decay
|
||||
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||
215.4k steps).
|
||||
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.9999,
|
||||
update_after_step = 100,
|
||||
update_every = 10,
|
||||
inv_gamma = 1.0,
|
||||
power = 2/3,
|
||||
min_value = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.update_every = update_every
|
||||
self.update_after_step = update_after_step
|
||||
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
self.min_value = min_value
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def copy_params_from_model_to_ema(self):
|
||||
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||
ma_param.data.copy_(current_param.data)
|
||||
|
||||
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||
ma_buffer.data.copy_(current_buffer.data)
|
||||
|
||||
def get_current_decay(self):
|
||||
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||
|
||||
if epoch <= 0:
|
||||
return 0.
|
||||
|
||||
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||
|
||||
def update(self):
|
||||
step = self.step.item()
|
||||
self.step += 1
|
||||
|
||||
if (step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted.item():
|
||||
self.copy_params_from_model_to_ema()
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
@torch.no_grad()
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
current_decay = self.get_current_decay()
|
||||
|
||||
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||
difference = ma_params.data - current_params.data
|
||||
difference.mul_(1.0 - current_decay)
|
||||
ma_params.sub_(difference)
|
||||
|
||||
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||
difference = ma_buffer - current_buffer
|
||||
difference.mul_(1.0 - current_decay)
|
||||
ma_buffer.sub_(difference)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@@ -505,26 +357,20 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.sample(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@@ -605,6 +451,8 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||
@@ -730,6 +578,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_image(self, *args, **kwargs):
|
||||
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.11.2'
|
||||
__version__ = '0.12.4'
|
||||
|
||||
@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_after_step = 500,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
|
||||
1
setup.py
1
setup.py
@@ -28,6 +28,7 @@ setup(
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'ema-pytorch>=0.0.7',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
|
||||
176
train_decoder.py
176
train_decoder.py
@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
from clip import tokenize
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
@@ -33,7 +34,8 @@ def exists(val):
|
||||
def create_dataloaders(
|
||||
available_shards,
|
||||
webdataset_base_url,
|
||||
embeddings_url,
|
||||
img_embeddings_url=None,
|
||||
text_embeddings_url=None,
|
||||
shard_width=6,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
@@ -63,14 +65,15 @@ def create_dataloaders(
|
||||
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
||||
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
||||
|
||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
|
||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
|
||||
tar_url=tar_urls,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size if not for_sampling else n_sample_images,
|
||||
embeddings_url=embeddings_url,
|
||||
img_embeddings_url=img_embeddings_url,
|
||||
text_embeddings_url=text_embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_num = None,
|
||||
extra_keys= ["txt"] if with_text else [],
|
||||
extra_keys= ["txt"],
|
||||
shuffle_shards = shuffle,
|
||||
resample_shards = resample,
|
||||
img_preproc=img_preproc,
|
||||
@@ -79,8 +82,8 @@ def create_dataloaders(
|
||||
|
||||
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
||||
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
||||
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
|
||||
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
|
||||
val_dataloader = create_dataloader(val_urls, shuffle=False)
|
||||
test_dataloader = create_dataloader(test_urls, shuffle=False)
|
||||
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
||||
return {
|
||||
"train": train_dataloader,
|
||||
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
|
||||
Samples the dataloader and returns a zipped list of examples
|
||||
"""
|
||||
images = []
|
||||
embeddings = []
|
||||
img_embeddings = []
|
||||
text_embeddings = []
|
||||
captions = []
|
||||
dataset_keys = get_dataset_keys(dataloader)
|
||||
has_caption = "txt" in dataset_keys
|
||||
for data in dataloader:
|
||||
if has_caption:
|
||||
img, emb, txt = data
|
||||
for img, emb, txt in dataloader:
|
||||
img_emb, text_emb = emb.get('img'), emb.get('text')
|
||||
if img_emb is not None:
|
||||
img_emb = img_emb.to(device=device, dtype=torch.float)
|
||||
img_embeddings.extend(list(img_emb))
|
||||
else:
|
||||
img, emb = data
|
||||
txt = [""] * emb.shape[0]
|
||||
# Then we add None img.shape[0] times
|
||||
img_embeddings.extend([None]*img.shape[0])
|
||||
if text_emb is not None:
|
||||
text_emb = text_emb.to(device=device, dtype=torch.float)
|
||||
text_embeddings.extend(list(text_emb))
|
||||
else:
|
||||
# Then we add None img.shape[0] times
|
||||
text_embeddings.extend([None]*img.shape[0])
|
||||
img = img.to(device=device, dtype=torch.float)
|
||||
emb = emb.to(device=device, dtype=torch.float)
|
||||
images.extend(list(img))
|
||||
embeddings.extend(list(emb))
|
||||
captions.extend(list(txt))
|
||||
if len(images) >= n:
|
||||
break
|
||||
return list(zip(images[:n], embeddings[:n], captions[:n]))
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, text_prepend=""):
|
||||
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
"""
|
||||
real_images, embeddings, txts = zip(*example_data)
|
||||
embeddings_tensor = torch.stack(embeddings)
|
||||
samples = trainer.sample(embeddings_tensor)
|
||||
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
|
||||
sample_params = {}
|
||||
if img_embeddings[0] is None:
|
||||
# Generate image embeddings from clip
|
||||
imgs_tensor = torch.stack(real_images)
|
||||
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
else:
|
||||
# Then we are using precomputed image embeddings
|
||||
img_embeddings = torch.stack(img_embeddings)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
if condition_on_text_encodings:
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
sample_params["text"] = tokenized_texts
|
||||
else:
|
||||
# Then we are using precomputed text embeddings
|
||||
text_embeddings = torch.stack(text_embeddings)
|
||||
sample_params["text_encodings"] = text_embeddings
|
||||
samples = trainer.sample(**sample_params)
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
|
||||
|
||||
real_image_size = real_images[0].shape[-1]
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
@@ -250,6 +276,7 @@ def train(
|
||||
save_latest=True,
|
||||
save_best=True,
|
||||
unet_training_mask=None,
|
||||
condition_on_text_encodings=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -258,8 +285,8 @@ def train(
|
||||
is_master = accelerator.process_index == 0
|
||||
|
||||
trainer = DecoderTrainer(
|
||||
accelerator,
|
||||
decoder,
|
||||
decoder=decoder,
|
||||
accelerator=accelerator,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -268,6 +295,7 @@ def train(
|
||||
validation_losses = []
|
||||
next_task = 'train'
|
||||
sample = 0
|
||||
samples_seen = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
|
||||
@@ -306,13 +334,22 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
total_samples = all_samples.sum().item()
|
||||
sample += total_samples
|
||||
img, emb = send_to_device((img, emb))
|
||||
samples_seen += total_samples
|
||||
img_emb = emb.get('img')
|
||||
has_img_embedding = img_emb is not None
|
||||
if has_img_embedding:
|
||||
img_emb, = send_to_device((img_emb,))
|
||||
text_emb = emb.get('text')
|
||||
has_text_embedding = text_emb is not None
|
||||
if has_text_embedding:
|
||||
text_emb, = send_to_device((text_emb,))
|
||||
img, = send_to_device((img,))
|
||||
|
||||
trainer.train()
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
@@ -320,7 +357,20 @@ def train(
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
continue
|
||||
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
forward_params = {}
|
||||
if has_img_embedding:
|
||||
forward_params['image_embed'] = img_emb
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
|
||||
@@ -334,14 +384,20 @@ def train(
|
||||
mask = unet_all_losses != 0
|
||||
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
||||
|
||||
# gather decay rate on each UNet
|
||||
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
|
||||
|
||||
log_data = {
|
||||
"Epoch": epoch,
|
||||
"Sample": sample,
|
||||
"Step": i,
|
||||
"Samples per second": samples_per_sec,
|
||||
"Samples Seen": samples_seen,
|
||||
**ema_decay_list,
|
||||
**loss_map
|
||||
}
|
||||
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
|
||||
|
||||
if is_master:
|
||||
tracker.log(log_data, step=step(), verbose=True)
|
||||
|
||||
@@ -358,7 +414,7 @@ def train(
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
@@ -381,14 +437,35 @@ def train(
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
val_sample += total_samples
|
||||
img, emb = send_to_device((img, emb))
|
||||
img_emb = emb.get('img')
|
||||
has_img_embedding = img_emb is not None
|
||||
if has_img_embedding:
|
||||
img_emb, = send_to_device((img_emb,))
|
||||
text_emb = emb.get('text')
|
||||
has_text_embedding = text_emb is not None
|
||||
if has_text_embedding:
|
||||
text_emb, = send_to_device((text_emb,))
|
||||
img, = send_to_device((img,))
|
||||
|
||||
for unet in range(1, len(decoder.unets)+1):
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
# No need to evaluate an unchanging unet
|
||||
continue
|
||||
|
||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||
|
||||
forward_params = {}
|
||||
if has_img_embedding:
|
||||
forward_params['image_embed'] = img_emb.float()
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb.float()
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
@@ -415,7 +492,7 @@ def train(
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step(), verbose=True)
|
||||
next_task = 'sample'
|
||||
@@ -426,8 +503,8 @@ def train(
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
@@ -517,14 +594,37 @@ def initialize_training(config, config_path):
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
|
||||
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
data_source_string = ""
|
||||
|
||||
if has_img_embeddings:
|
||||
data_source_string += "precomputed image embeddings"
|
||||
elif has_clip_model:
|
||||
data_source_string += "clip image embeddings generation"
|
||||
else:
|
||||
raise ValueError("No image embeddings source specified")
|
||||
if conditioning_on_text:
|
||||
if has_text_embeddings:
|
||||
data_source_string += " and precomputed text embeddings"
|
||||
elif has_clip_model:
|
||||
data_source_string += " and clip text encoding generation"
|
||||
else:
|
||||
raise ValueError("No text embeddings source specified")
|
||||
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=conditioning_on_text,
|
||||
**config.train.dict(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user