Compare commits

..

1 Commits

5 changed files with 53 additions and 245 deletions

View File

@@ -44,7 +44,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -582,8 +581,7 @@ unet1 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
@@ -600,11 +598,12 @@ decoder = Decoder(
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -991,7 +990,34 @@ dataset = ImageEmbeddingDataset(
#### `train_diffusion_prior.py`
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
#### Usage
```bash
$ python train_diffusion_prior.py
```
The most significant parameters for the script are as follows:
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
- `learning-rate`, default = `1.1e-4`
- `weight-decay`, default = `6.02e-2`
- `max-grad-norm`, default = `0.5`
- `batch-size`, default = `10 ** 4`
- `num-epochs`, default = `5`
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
## CLI (wip)

View File

@@ -125,28 +125,14 @@ def log(t, eps = 1e-12):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(
image,
target_image_size,
clamp_range = None,
nearest = False,
**kwargs
):
def resize_image_to(image, target_image_size):
orig_image_size = image.shape[-1]
if orig_image_size == target_image_size:
return image
if not nearest:
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
if exists(clamp_range):
out = out.clamp(*clamp_range)
return out
scale_factors = target_image_size / orig_image_size
return resize(image, scale_factors = scale_factors)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
@@ -351,7 +337,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
# attempting to correct nan gradients when learned variance is turned on
# in the setting of deepspeed fp16
eps = 1e-12 if x.dtype == torch.float32 else 1e-3
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
@@ -359,8 +345,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus, eps = eps)
log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
@@ -1790,17 +1776,11 @@ class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
downsample_mode_nearest = False,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
@@ -1814,7 +1794,7 @@ class LowresConditioner(nn.Module):
blur_kernel_size = None
):
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
if self.training:
# when training, blur the low resolution conditional image
@@ -1834,7 +1814,7 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
cond_fmap = resize_image_to(cond_fmap, target_image_size)
return cond_fmap
@@ -1857,7 +1837,6 @@ class Decoder(nn.Module):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
clip_denoised = True,
@@ -1951,6 +1930,10 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
# create noise schedulers per unet
if not exists(beta_schedule):
@@ -1989,10 +1972,6 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -2000,10 +1979,8 @@ class Decoder(nn.Module):
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range
)
# classifier free guidance
@@ -2035,10 +2012,6 @@ class Decoder(nn.Module):
def device(self):
return self._dummy.device
@property
def condition_on_text_encodings(self):
return any([unet.cond_on_text_encodings for unet in self.unets])
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1

View File

@@ -192,7 +192,6 @@ class DiffusionPriorTrainer(nn.Module):
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
# save model
@@ -229,7 +228,7 @@ class DiffusionPriorTrainer(nn.Module):
# track steps internally
self.register_buffer('step', torch.tensor([0], device = self.device))
self.register_buffer('step', torch.tensor([0], device = device))
# accelerator wrappers
@@ -474,7 +473,7 @@ class DecoderTrainer(nn.Module):
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
schedulers = []
@@ -527,17 +526,6 @@ class DecoderTrainer(nn.Module):
self.warmup_schedulers = warmup_schedulers
def validate_and_return_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
return unet_number
def num_steps_taken(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
return self.steps[unet_number - 1].item()
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -606,7 +594,10 @@ class DecoderTrainer(nn.Module):
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
def update(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
@@ -672,7 +663,8 @@ class DecoderTrainer(nn.Module):
max_batch_size = None,
**kwargs
):
unet_number = self.validate_and_return_unet_number(unet_number)
if self.num_unets == 1:
unet_number = default(unet_number, 1)
total_loss = 0.

View File

@@ -1 +1 @@
__version__ = '0.16.19'
__version__ = '0.16.11'

183
prior.md
View File

@@ -1,183 +0,0 @@
# Diffusion Prior
This readme serves as an introduction to the diffusion prior.
## Intro
A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected some way—then ability the translate between them could extremely helpful.
### Motivation
Before we dive into the model, lets look at a quick example of where the model may be helpful.
For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption, however, there is no guarantee that these embeddings are in the same space. While the embeddings generated are ***close*** the image and text embeddings occupy two disjoint sets.
```python
# Load Models
clip_model = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
# Retrieve prompt from user and encode with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)
# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```
> **Question**: *Can you spot the issue here?*
>
> **Answer**: *Were trying to generate an image from a text embedding!*
Unfortunately, we run into the issue previously mentioned--the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution
```python
# Load Models
prior= Prior(checkpoint="prior.pth") # A decoder trained to go from: text-> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
# Retrieve prompt from user and encode with a prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```
With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
> **You may be asking yourself the following question:**
>
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
>
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...its just an example :smile:
## Usage
To utilize a pre-trained prior, its quite simple.
### Loading Checkpoints
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer
def load_diffusion_model(dprior_path):
prior_network = DiffusionPriorNetwork(
dim=768,
depth=24,
dim_head=64,
heads=32,
normformer=True,
attn_dropout=5e-2,
ff_dropout=5e-2,
num_time_embeds=1,
num_image_embeds=1,
num_text_embeds=1,
num_timesteps=1000,
ff_mult=4
)
diffusion_prior = DiffusionPrior(
net=prior_network,
clip=OpenAIClipAdapter("ViT-L/14"),
image_embed_dim=768,
timesteps=1000,
cond_drop_prob=0.1,
loss_type="l2",
condition_on_text_encodings=True,
)
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=1.1e-4,
wd=6.02e-2,
max_grad_norm=0.5,
amp=False,
group_wd_params=True,
use_ema=True,
device=device,
accelerator=None,
)
trainer.load(dprior_path)
return trainer
```
Here we instantiate a model matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*)
### Sampling
Once we have a pre-trained model, generating embeddings is quite simple!
```python
# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```
The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
> For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenizer_text) as a drop in replacement for clip.encode_text().
**Some things to note:**
* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two; and selecting the one with the higher cosine similarity with the prompt.
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.
---
## Training
### Overview
Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move onto configuration
## Dataset
To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2datset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.
## Configuration
The configuration file allows for you to easily track and reproduce experiments. It is a simple JSON file that will specify the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.
## Distributed Training
If you would like to train in a distributed manner we have opted to leverage huggingface new Accelerate library. HFA makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to follow the simple CLI configuration tool [more information here](https://huggingface.co/docs/accelerate/accelerator).
## Evaluation
There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
| Metric | Description | Comments |
| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |
| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform in the long-term. |
| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
| Similarity With Original Image | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining wether your prior is generating images with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity--then you likely are suffering from some kind of training error or inefficiency (i.e. not using EMA) |
| Difference From Baseline Similarity | Sometimes its useful to visualize a metric in another light. This metric will show you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01`+/- of `0.0`. If this climbs to high, (~>`0.02`) then this may be a sign that your model is overfitting somehow. |
| Similarity With Text | This metric is your bread and butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be on of your main focuses and is probably the second most important behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity--it could be an indication that your model is overfitting. |
| Similarity With Unrelated Caption | This metric will attempt to exposed an overfit prior by feeding it arbitrary prompts (from your dataset) and then measure the similarity of this predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images were high (when in fact the caption and image were completely unrelated). With this in mind--a low value is ideal, anything below `0.1` is probably safe. |
## Launching the script
Now that youve done all the prep its time for the easy part! 🚀
To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to launch with distributed training & huggingface accelerate or `python train_diffusion_prior.py` if you would like to train on your gpu/cpu without huggingface accelerate.
## Checkpointing
Checkpoints will be saved to the directory specified in your configuration file.
Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled “latest.pth”. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.
## Things To Keep In Mind
The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least yet.
As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you dont see documentation for!