Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)

Compare commits: 60 commits
Commit SHA1s in this comparison:
6647050c33, b4c3e5b854, b7f9607258, 2219348a6e, 9eea9b9862, 5d958713c0, 0f31980362, bee5bf3815, 350a3d6045, 1a81670718,
934c9728dc, ce4b0107c1, 64c2f9c4eb, 22cc613278, 83517849e5, 708809ed6c, 9cc475f6e7, ffd342e9d0, f8bfd3493a, 9025345e29,
8cc278447e, 38cd62010c, 1cc288af39, a851168633, 1ffeecd0ca, 3df899f7a4, 09534119a1, 6f8b90d4d7, b588286288, b693e0be03,
a0bed30a84, 387c5bf774, a13d2d89c5, 44d4b1bba9, f12a7589c5, b8af2210df, f4fe6c570d, 645e207441, 00743b3a0b, 01589aff6a,
7ecfd76cc0, 6161b61c55, 1ed0f9d80b, f326a95e26, d7a0a2ce4b, f23fab7ef7, 857b9fbf1e, 8864fd0aa7, 72bf159331, e5e47cfecb,
fa533962bd, 276abf337b, ae42d03006, 4d346e98d9, 2b1fd1ad2e, 82a2ef37d9, 5c397c9d66, 0f4edff214, 501a8c7c46, 4e49373fc5
README.md (43 changed lines)
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.

Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>

There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.

## Status

@@ -24,6 +24,13 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja

*ongoing at 21k steps*

- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application

## Pre-Trained Models

- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- DALL-E 2 🚧

## Install

```bash
@@ -936,7 +943,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
@@ -1043,6 +1050,7 @@ This library would not have gotten to this working state without the help of

- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship

... and many others. Thank you! 🙏
@@ -1078,20 +1086,20 @@ This library would not have gotten to this working state without the help of

- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [x] use pydantic for config drive training
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

## Citations
@@ -1135,8 +1143,9 @@ This library would not have gotten to this working state without the help of

```bibtex
@inproceedings{Tu2022MaxViTMV,
    title   = {MaxViT: Multi-Axis Vision Transformer},
    author  = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year    = {2022}
    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year    = {2022},
    url     = {https://arxiv.org/abs/2204.01697}
}
```
@@ -1190,4 +1199,22 @@ This library would not have gotten to this working state without the help of

}
```

```bibtex
@misc{Saharia2022,
    title   = {Imagen: unprecedented photorealism × deep level of language understanding},
    author  = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
    year    = {2022}
}
```

```bibtex
@article{Choi2022PerceptionPT,
    title   = {Perception Prioritized Training of Diffusion Models},
    author  = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2204.00227}
}
```

*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f

The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).

**<ins>Unets</ins>:**
**<ins>Unet</ins>:**

This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`

Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |

@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.

Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `unets` | Yes | N/A | A list of unets, using the configuration above |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |

@@ -81,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.

Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. |
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. |
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
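As an illustration, a metrics block that enables all three metrics with their torchmetrics defaults might look like the following; the enclosing section name `evaluate` is an assumption here, so check the example decoder configuration for the exact key:

```json
"evaluate": {
    "n_evaluation_samples": 1000,
    "FID": {},
    "IS": {},
    "KID": {}
}
```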
@@ -1,16 +1,16 @@

{
    "unets": [
        {
            "dim": 128,
            "image_embed_dim": 768,
            "cond_dim": 64,
            "channels": 3,
            "dim_mults": [1, 2, 4, 8],
            "attn_dim_head": 32,
            "attn_heads": 16
        }
    ],
    "decoder": {
        "unets": [
            {
                "dim": 128,
                "image_embed_dim": 768,
                "cond_dim": 64,
                "channels": 3,
                "dim_mults": [1, 2, 4, 8],
                "attn_dim_head": 32,
                "attn_heads": 16
            }
        ],
        "image_sizes": [64],
        "channels": 3,
        "timesteps": 1000,
configs/train_prior_config.example.json (new file, 70 lines)
@@ -0,0 +1,70 @@

{
    "prior": {
        "clip": {
            "make": "x-clip",
            "model": "ViT-L/14",
            "base_model_kwargs": {
                "dim_text": 768,
                "dim_image": 768,
                "dim_latent": 768
            }
        },
        "net": {
            "dim": 768,
            "depth": 12,
            "num_timesteps": 1000,
            "num_time_embeds": 1,
            "num_image_embeds": 1,
            "num_text_embeds": 1,
            "dim_head": 64,
            "heads": 12,
            "ff_mult": 4,
            "norm_out": true,
            "attn_dropout": 0.0,
            "ff_dropout": 0.0,
            "final_proj": true,
            "normformer": true,
            "rotary_emb": true
        },
        "image_embed_dim": 768,
        "image_size": 224,
        "image_channels": 3,
        "timesteps": 1000,
        "cond_drop_prob": 0.1,
        "loss_type": "l2",
        "predict_x_start": true,
        "beta_schedule": "cosine",
        "condition_on_text_encodings": true
    },
    "data": {
        "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
        "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
        "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
        "batch_size": 256,
        "splits": {
            "train": 0.9,
            "val": 1e-7,
            "test": 0.0999999
        }
    },
    "train": {
        "epochs": 1,
        "lr": 1.1e-4,
        "wd": 6.02e-2,
        "max_grad_norm": 0.5,
        "use_ema": true,
        "amp": false,
        "save_every": 10000
    },
    "load": {
        "source": null,
        "resume": false
    },
    "tracker": {
        "tracker_type": "wandb",
        "data_path": "./prior_checkpoints",
        "wandb_entity": "laion",
        "wandb_project": "diffusion-prior",
        "verbose": true
    }
}
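A minimal sketch of consuming this file with only the standard library (the pydantic-based config classes mentioned in the README todo list are not assumed here); the field names below come directly from the JSON above:

```python
import json

# load the example prior config and pull out a few sections
with open("configs/train_prior_config.example.json") as f:
    config = json.load(f)

prior_net_kwargs = config["prior"]["net"]   # dim, depth, heads, ... for the prior network
train_settings = config["train"]            # lr, wd, EMA and checkpointing settings

print(prior_net_kwargs["dim"], train_settings["lr"])  # 768 0.00011
```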
@@ -1,3 +1,4 @@

from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
@@ -1,6 +1,6 @@

import math
import random
from tqdm import tqdm
from inspect import isfunction
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
@@ -11,7 +11,7 @@ import torch.nn.functional as F

from torch import nn, einsum
import torchvision.transforms as T

from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
@@ -56,7 +56,7 @@ def maybe(fn):

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
    return d() if callable(d) else d

def cast_tuple(val, length = 1):
    if isinstance(val, list):
@@ -313,11 +313,6 @@ def extract(a, t, x_shape):

    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()

def meanflat(x):
    return x.mean(dim = tuple(range(1, len(x.shape))))
@@ -372,7 +367,7 @@ def quadratic_beta_schedule(timesteps):

    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2


def sigmoid_beta_schedule(timesteps):
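A quick standalone check of the corrected schedule (the function body is copied from the new line above); the betas should now span roughly [scale * 1e-4, scale * 2e-2] with quadratic spacing, rather than the quartic range produced by the old line:

```python
import torch

def quadratic_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps, dtype = torch.float64) ** 2

betas = quadratic_beta_schedule(1000)
print(betas[0].item(), betas[-1].item())  # ~1e-4 and ~2e-2, the intended endpoints
```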
@@ -384,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):


class BaseGaussianDiffusion(nn.Module):
    def __init__(self, *, beta_schedule, timesteps, loss_type):
    def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
        super().__init__()

        if beta_schedule == "cosine":
@@ -449,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):

        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # p2 loss reweighting

        self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
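A standalone sketch of the `p2_loss_weight` buffer registered above, following https://arxiv.org/abs/2204.00227; the linear `alphas_cumprod` below is only an illustrative stand-in for the real schedule:

```python
import torch

# weight_t = (k + SNR_t) ** -gamma, with SNR_t = alphas_cumprod / (1 - alphas_cumprod)
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)  # stand-in schedule, for illustration only
p2_loss_weight_k, p2_loss_weight_gamma = 1, 1.0        # gamma of 1. is the recommended setting
p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma

# low-noise (early) timesteps are strongly down-weighted, high-noise ones approach 1 / k
print(p2_loss_weight[0].item(), p2_loss_weight[-1].item())
```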
@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):

        )

        if exists(clip):
            assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'

            if isinstance(clip, CLIP):
                clip = XClipAdapter(clip, **clip_adapter_overrides)
            elif isinstance(clip, CoCa):
@@ -943,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):

        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
        noise = noise_like(x.shape, device, repeat_noise)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -1082,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
def Downsample(dim, *, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.Conv2d(dim, dim_out, 4, 2, 1)

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
@@ -1105,13 +1108,20 @@ class Block(nn.Module):
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x, scale_shift = None):
|
||||
x = self.project(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if exists(scale_shift):
|
||||
scale, shift = scale_shift
|
||||
x = x * (scale + 1) + shift
|
||||
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -1130,7 +1140,7 @@ class ResnetBlock(nn.Module):
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
nn.Linear(time_cond_dim, dim_out * 2)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
@@ -1150,11 +1160,14 @@ class ResnetBlock(nn.Module):
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||
scale_shift = time_emb.chunk(2, dim = 1)
|
||||
|
||||
h = self.block1(x, scale_shift = scale_shift)
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
@@ -1331,12 +1344,15 @@ class Unet(nn.Module):
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
num_resnet_blocks = 2,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1356,7 +1372,7 @@ class Unet(nn.Module):
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
init_dim = default(init_dim, dim)
|
||||
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
@@ -1383,11 +1399,16 @@ class Unet(nn.Module):
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
self.image_to_tokens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.to_image_hiddens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, time_cond_dim),
|
||||
nn.GELU()
|
||||
) if cond_on_image_embeds and add_image_embeds_to_time else None
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
@@ -1408,6 +1429,7 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
||||
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
@@ -1419,6 +1441,7 @@ class Unet(nn.Module):
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
@@ -1434,16 +1457,17 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
@@ -1452,19 +1476,23 @@ class Unet(nn.Module):
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
||||
up_in_out_slice = slice(1 if not memory_efficient else None, None)
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[up_in_out_slice]), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
final_dim_in = dim * (1 if memory_efficient else 2)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
ResnetBlock(final_dim_in, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
@@ -1549,7 +1577,23 @@ class Unet(nn.Module):
|
||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||
|
||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
||||
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||
|
||||
# image embedding to be summed to time embedding
|
||||
# discovered by @mhh0318 in the paper
|
||||
|
||||
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||
image_hiddens = self.to_image_hiddens(image_embed)
|
||||
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
||||
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
||||
|
||||
image_hiddens = torch.where(
|
||||
image_keep_mask_hidden,
|
||||
image_hiddens,
|
||||
null_image_hiddens
|
||||
)
|
||||
|
||||
t = t + image_hiddens
|
||||
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
@@ -1557,11 +1601,12 @@ class Unet(nn.Module):
|
||||
image_tokens = None
|
||||
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||
image_tokens = self.image_to_tokens(image_embed)
|
||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
image_tokens = torch.where(
|
||||
image_keep_mask,
|
||||
image_keep_mask_embed,
|
||||
image_tokens,
|
||||
null_image_embed
|
||||
)
|
||||
@@ -1616,12 +1661,20 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||
if exists(pre_downsample):
|
||||
x = pre_downsample(x)
|
||||
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
if exists(post_downsample):
|
||||
x = post_downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
|
||||
@@ -1630,20 +1683,26 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = block1(x, c, t)
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
x = upsample(x)
|
||||
|
||||
if len(hiddens) > 0:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_sigma = (0.1, 0.2),
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1667,6 +1726,18 @@ class LowresConditioner(nn.Module):

            # when training, blur the low resolution conditional image
            blur_sigma = default(blur_sigma, self.blur_sigma)
            blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)

            # allow for drawing a random sigma between lo and hi float values
            if isinstance(blur_sigma, tuple):
                blur_sigma = tuple(map(float, blur_sigma))
                blur_sigma = random.uniform(*blur_sigma)

            # allow for drawing a random kernel size between lo and hi int values
            if isinstance(blur_kernel_size, tuple):
                blur_kernel_size = tuple(map(int, blur_kernel_size))
                kernel_size_lo, kernel_size_hi = blur_kernel_size
                blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)

            cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))

        cond_fmap = resize_image_to(cond_fmap, target_image_size)
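A small sketch of the randomized blurring above in isolation; the `gaussian_blur2d` import from kornia is an assumption, since the diff only shows the call site:

```python
import random
import torch
from kornia.filters import gaussian_blur2d  # assumed import, not shown in the diff

cond_fmap = torch.rand(1, 3, 64, 64)        # stand-in low resolution conditioning image
blur_sigma = random.uniform(0.1, 0.2)       # drawn from the (0.1, 0.2) default range above
blur_kernel_size = 3
cond_fmap = gaussian_blur2d(cond_fmap, (blur_kernel_size, blur_kernel_size), (blur_sigma, blur_sigma))
```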
@@ -1692,30 +1763,44 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
learned_variance_constrain_frac = False,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False,
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||
p2_loss_weight_k = 1
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
loss_type = loss_type,
|
||||
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
||||
p2_loss_weight_k = p2_loss_weight_k
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1725,13 +1810,20 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
else:
|
||||
self.clip_image_size = image_size
|
||||
self.channels = channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
# determine image size, with image_size and image_sizes taking precedence
|
||||
|
||||
if exists(image_size) or exists(image_sizes):
|
||||
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||
image_size = default(image_size, lambda: image_sizes[-1])
|
||||
elif exists(clip):
|
||||
image_size = clip.image_size
|
||||
else:
|
||||
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||
|
||||
# channels
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
@@ -1743,6 +1835,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||
self.learned_variance = learned_variance
|
||||
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# construct unets and vaes
|
||||
@@ -1773,7 +1866,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
||||
image_sizes = default(image_sizes, (image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
@@ -1810,7 +1903,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.clip_denoised = clip_denoised
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
# dynamic thresholding settings, if clipping denoised during sampling
|
||||
|
||||
self.use_dynamic_thres = use_dynamic_thres
|
||||
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
@@ -1851,7 +1950,21 @@ class Decoder(BaseGaussianDiffusion):

        x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
            # s is the threshold amount
            # static thresholding would just be s = 1
            s = 1.
            if self.use_dynamic_thres:
                s = torch.quantile(
                    rearrange(x_recon, 'b ... -> b (...)').abs(),
                    self.dynamic_thres_percentile,
                    dim = -1
                )

                s.clamp_(min = 1.)
                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))

            # clip by threshold, depending on whether static or dynamic
            x_recon = x_recon.clamp(-s, s) / s

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
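For reference, the same dynamic thresholding step in isolation (a sketch, assuming the default `dynamic_thres_percentile` of 0.9):

```python
import torch
from einops import rearrange

x_recon = torch.randn(2, 3, 8, 8) * 3                       # pretend x_start prediction with outliers
s = torch.quantile(rearrange(x_recon, 'b ... -> b (...)').abs(), 0.9, dim = -1)
s = s.clamp(min = 1.).view(-1, 1, 1, 1)
x_recon = x_recon.clamp(-s, s) / s                          # every value now lies in [-1, 1]
print(x_recon.abs().max().item() <= 1.0)                    # True
```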
@@ -1863,16 +1976,19 @@ class Decoder(BaseGaussianDiffusion):

            max_log = extract(torch.log(self.betas), t, x.shape)
            var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)

            if self.learned_variance_constrain_frac:
                var_interp_frac = var_interp_frac.sigmoid()

            posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
            posterior_variance = posterior_log_variance.exp()

        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
        noise = noise_like(x.shape, device, repeat_noise)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -1936,7 +2052,13 @@ class Decoder(BaseGaussianDiffusion):

        target = noise if not predict_x_start else x_start

        loss = self.loss_fn(pred, target)
        loss = self.loss_fn(pred, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')

        if self.has_p2_loss_reweighting:
            loss = loss * extract(self.p2_loss_weight, times, loss.shape)

        loss = loss.mean()

        if not learned_variance:
            # return simple loss if not using learned variance
@@ -4,7 +4,7 @@ In order to make loading data simple and efficient, we include some general data

### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.

Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.

@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,

@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(

)
```
### Diffusion Prior: Prior Embedding Dataset
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.

To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects for your training run.

If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.

Usage:
```python
from dalle2_pytorch.dataloaders import get_reader, make_splits

# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"

reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)

# some config for training
TRAIN_ARGS = {
    "world_size": 3,
    "text_conditioned": True,
    "start": 0,
    "num_data_points": 10000,
    "batch_size": 2,
    "train_split": 0.5,
    "eval_split": 0.25,
    "image_reader": reader,
}

# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
```
@@ -1,2 +1,2 @@

from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
@@ -1,180 +0,0 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch import from_numpy
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
|
||||
class PriorEmbeddingLoader(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
device: str = "cpu",
|
||||
) -> None:
|
||||
super(PriorEmbeddingLoader).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.batch_size = batch_size
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
self.n += 1
|
||||
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
tokenized_caption = tokenize(
|
||||
caption["caption"].to_list(), truncate=True
|
||||
).to(self.device)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
device: str,
|
||||
img_url: str,
|
||||
meta_url: str = None,
|
||||
txt_url: str = None,
|
||||
):
|
||||
|
||||
assert img_url is not None, "Must supply some image embeddings"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply metadata url if text-conditioning"
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
dalle2_pytorch/dataloaders/prior_loader.py (new file, 273 lines)
@@ -0,0 +1,273 @@
|
||||
from math import ceil
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
from torch import from_numpy
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
|
||||
|
||||
class PriorEmbeddingDataset(IterableDataset):
|
||||
"""
|
||||
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
|
||||
|
||||
It enables one to simplify the logic necessary to yield samples from
|
||||
the different EmbeddingReader configurations available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
) -> None:
|
||||
super(PriorEmbeddingDataset).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.stop - self.start
|
||||
|
||||
def __iter__(self):
|
||||
# D.R.Y loader args
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
# if the data requested is text conditioned, only load images
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
# otherwise, include text embeddings and bypass metadata
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
|
||||
# return the data loader in its formatted state
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def __str__(self):
|
||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
text_embedding = from_numpy(text_embedding)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def distribute_to_rank(start, stop, rank, world_size):
|
||||
"""
|
||||
Distribute data to each rank given the world size.
|
||||
|
||||
Return:
|
||||
- New start and stop points for this rank.
|
||||
"""
|
||||
num_samples = int(stop - start)
|
||||
|
||||
per_rank = int(ceil((num_samples) / float(world_size)))
|
||||
|
||||
assert (
|
||||
per_rank > 0
|
||||
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
|
||||
|
||||
    rank_start = start + rank * per_rank
    rank_stop = min(rank_start + per_rank, stop)

    new_length = rank_stop - rank_start

    assert (
        new_length > 0
    ), "Calculated start and stop points result in a length of zero for this rank."

    return rank_start, rank_stop


def get_reader(
    text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
    """
    Create an EmbeddingReader object from the specified URLs

    get_reader() will always expect a url to image embeddings.

    If text-conditioned, it will also expect a meta_url for the captions.
    Otherwise, it will need txt_url for the matching text embeddings.

    Returns an image_reader object if text-conditioned.
    Otherwise it returns both an image_reader and a text_reader
    """

    assert img_url is not None, "Must supply a image url"

    if text_conditioned:
        assert meta_url is not None, "Must supply meta url if text-conditioned"

        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            # will assume the caption column exists and is the only one requested
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

        return image_reader

    # otherwise we will require text embeddings as well and return two readers
    assert (
        txt_url is not None
    ), "Must supply text embedding url if not text-conditioning"

    image_reader = EmbeddingReader(img_url, file_format="npy")
    text_reader = EmbeddingReader(txt_url, file_format="npy")

    return image_reader, text_reader


def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    image_reader: EmbeddingReader,
    text_reader: EmbeddingReader = None,
    start=0,
    rank=0,
    world_size=1,
):
    """
    Split an embedding reader object as needed.

    NOTE: make_splits() will infer the test set size from your train and eval.

    Input:
        - text_conditioned: whether to prepare text-conditioned training data
        - batch_size: the batch size for a single gpu
        - num_data_points: the total number of data points you wish to train on
        - train_split: the percentage of data you wish to train on
        - eval_split: the percentage of data you wish to validate on
        - image_reader: the image_reader you wish to split
        - text_reader: the text_reader you want to split (if !text_conditioned)
        - start: the starting point within your dataset
        - rank: the rank of your worker
        - world_size: the total world size of your distributed training run

    Returns:
        - PyTorch Dataloaders that yield tuples of (img, txt) data.
    """

    assert start < image_reader.count, "start position cannot exceed reader count."

    # verify that the num_data_points does not exceed the max points
    if num_data_points > (image_reader.count - start):
        print(
            "Specified count is larger than what's available...defaulting to reader's count."
        )
        num_data_points = image_reader.count

    # compute split points
    train_set_size = int(train_split * num_data_points)
    eval_set_size = int(eval_split * num_data_points)
    eval_start = train_set_size
    eval_stop = int(eval_start + eval_set_size)

    assert (
        train_split + eval_split
    ) < 1.0, "Specified train and eval split is too large to infer a test split."

    # distribute to rank
    rank_train_start, rank_train_stop = distribute_to_rank(
        start, train_set_size, rank, world_size
    )
    rank_eval_start, rank_eval_stop = distribute_to_rank(
        train_set_size, eval_stop, rank, world_size
    )
    rank_test_start, rank_test_stop = distribute_to_rank(
        eval_stop, num_data_points, rank, world_size
    )

    # wrap up splits into a dict
    train_split_args = dict(
        start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
    )
    eval_split_args = dict(
        start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
    )
    test_split_args = dict(
        start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
    )

    if text_conditioned:
        # add the text-conditioned args to a unified dict
        reader_args = dict(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
        )

        train_split_args = dict(**reader_args, **train_split_args)
        eval_split_args = dict(**reader_args, **eval_split_args)
        test_split_args = dict(**reader_args, **test_split_args)

        train = PriorEmbeddingDataset(**train_split_args)
        val = PriorEmbeddingDataset(**eval_split_args)
        test = PriorEmbeddingDataset(**test_split_args)

    else:
        # add the non-conditioned args to a unified dict
        reader_args = dict(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
        )

        train_split_args = dict(**reader_args, **train_split_args)
        eval_split_args = dict(**reader_args, **eval_split_args)
        test_split_args = dict(**reader_args, **test_split_args)

        train = PriorEmbeddingDataset(**train_split_args)
        val = PriorEmbeddingDataset(**eval_split_args)
        test = PriorEmbeddingDataset(**test_split_args)

    # true batch size is specified in the PriorEmbeddingDataset
    train_loader = DataLoader(train, batch_size=None)
    eval_loader = DataLoader(val, batch_size=None)
    test_loader = DataLoader(test, batch_size=None)

    return train_loader, eval_loader, test_loader
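For orientation, here is a minimal sketch of how these two helpers are meant to be wired together for the non-text-conditioned case. The embedding folder paths and split sizes below are placeholders, not values from this repository:

```python
# minimal sketch, assuming local folders of .npy embeddings (paths are placeholders)
from dalle2_pytorch.dataloaders import get_reader, make_splits

img_reader, txt_reader = get_reader(
    text_conditioned = False,
    img_url = './embeddings/img',   # hypothetical path to image embeddings
    txt_url = './embeddings/txt'    # hypothetical path to text embeddings
)

train_loader, eval_loader, test_loader = make_splits(
    text_conditioned = False,
    batch_size = 64,
    num_data_points = 1_000_000,
    train_split = 0.8,
    eval_split = 0.1,               # the test split is inferred as the remaining 0.1
    image_reader = img_reader,
    text_reader = txt_reader,
    rank = 0,
    world_size = 1
)

for image_embed, text_embed in train_loader:
    # each loader yields (img, txt) tuples; batching happens inside PriorEmbeddingDataset,
    # which is why the DataLoaders above are built with batch_size=None
    break
```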
@@ -1,17 +1,20 @@
from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
    no_wd_params = set([param for param in params if param.ndim < 2])
    wd_params = set(params) - no_wd_params
    wd_params, no_wd_params = [], []
    for param in params:
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.999),
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    if filter_by_requires_grad:
@@ -20,12 +23,12 @@ def get_optimizer(
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    params = set(params)
    wd_params, no_wd_params = separate_weight_decayable_params(params)
    if group_wd_params:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        param_groups = [
            {'params': list(wd_params)},
            {'params': list(no_wd_params), 'weight_decay': 0},
        ]
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
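The net effect of this change is that `get_optimizer` only builds the weight-decay / no-weight-decay parameter groups when `group_wd_params` is set, and hands them to AdamW as `params`. A small sketch of the grouping behaviour, assuming the post-change code above and a toy two-layer module:

```python
# minimal sketch: biases and norm weights (ndim < 2) land in the weight-decay-free group
import torch
from torch import nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

optimizer = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)

for group in optimizer.param_groups:
    # first group inherits weight_decay = 1e-2, second group explicitly sets it to 0
    print(group.get('weight_decay'), len(group['params']))
```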
@@ -2,7 +2,6 @@
# to give users a quick easy start to training DALL-E without doing BPE

import torch
import youtokentome as yttm

import html
import os
@@ -11,6 +10,8 @@ import regex as re
from functools import lru_cache
from pathlib import Path

from dalle2_pytorch.utils import import_or_print_error

# OpenAI simple tokenizer

@lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
        bpe_path = Path(bpe_path)
        assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'

        tokenizer = yttm.BPE(model = str(bpe_path))
        self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')

        tokenizer = self.yttm.BPE(model = str(bpe_path))
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size()

@@ -167,7 +170,7 @@ class YttmTokenizer:
        return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))

    def encode(self, texts):
        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
        encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
        return list(map(torch.tensor, encoded))

    def tokenize(self, texts, context_length = 256, truncate_text = False):
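With this change `youtokentome` is only imported when a `YttmTokenizer` is actually constructed, rather than at module import time. A hedged usage sketch; the BPE model path below is a placeholder and assumes you have already trained a youtokentome model yourself:

```python
# minimal sketch, assuming a trained youtokentome BPE model at a placeholder path
from dalle2_pytorch.tokenizer import YttmTokenizer

tokenizer = YttmTokenizer(bpe_path = './data/bpe.model')  # triggers the lazy youtokentome import

token_ids = tokenizer.tokenize(['a photo of a dog'], context_length = 256)
print(token_ids.shape)  # expected to come back as a (batch, context_length) tensor
```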
@@ -6,6 +6,8 @@ from itertools import zip_longest
import torch
from torch import nn

from dalle2_pytorch.utils import import_or_print_error

# constants

DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
    return val is not None

def import_or_print_error(pkg_name, err_str = None):
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError as e:
        if exists(err_str):
            print(err_str)
        exit()

# load state dict functions

def load_wandb_state_dict(run_path, file_path, **kwargs):
@@ -3,15 +3,160 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any

from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa

from dalle2_pytorch.dalle2_pytorch import (
    CoCaAdapter,
    OpenAIClipAdapter,
    Unet,
    Decoder,
    DiffusionPrior,
    DiffusionPriorNetwork,
    XClipAdapter,
)

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def ListOrTuple(inner_type):
    return Union[List[inner_type], Tuple[inner_type]]

def SingularOrIterable(inner_type):
    return Union[inner_type, ListOrTuple(inner_type)]

# general pydantic classes

class TrainSplitConfig(BaseModel):
    train: float = 0.75
    val: float = 0.15
    test: float = 0.1

    @root_validator
    def validate_all(cls, fields):
        actual_sum = sum([*fields.values()])
        if actual_sum != 1.:
            raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
        return fields

class TrackerConfig(BaseModel):
    tracker_type: str = 'console'  # Decoder currently supports console and wandb
    data_path: str = './models'  # The path where files will be saved locally
    init_config: Dict[str, Any] = None
    wandb_entity: str = ''  # Only needs to be set if tracker_type is wandb
    wandb_project: str = ''
    verbose: bool = False  # Whether to print console logging for non-console trackers

# diffusion prior pydantic classes

class AdapterConfig(BaseModel):
    make: str = "openai"
    model: str = "ViT-L/14"
    base_model_kwargs: Dict[str, Any] = None

    def create(self):
        if self.make == "openai":
            return OpenAIClipAdapter(self.model)
        elif self.make == "x-clip":
            return XClipAdapter(XCLIP(**self.base_model_kwargs))
        elif self.make == "coca":
            return CoCaAdapter(CoCa(**self.base_model_kwargs))
        else:
            raise AttributeError("No adapter with that name is available.")

class DiffusionPriorNetworkConfig(BaseModel):
    dim: int
    depth: int
    num_timesteps: int = None
    num_time_embeds: int = 1
    num_image_embeds: int = 1
    num_text_embeds: int = 1
    dim_head: int = 64
    heads: int = 8
    ff_mult: int = 4
    norm_out: bool = True
    attn_dropout: float = 0.
    ff_dropout: float = 0.
    final_proj: bool = True
    normformer: bool = False
    rotary_emb: bool = True

    def create(self):
        kwargs = self.dict()
        return DiffusionPriorNetwork(**kwargs)

class DiffusionPriorConfig(BaseModel):
    clip: AdapterConfig = None
    net: DiffusionPriorNetworkConfig
    image_embed_dim: int
    image_size: int
    image_channels: int = 3
    timesteps: int = 1000
    cond_drop_prob: float = 0.
    loss_type: str = 'l2'
    predict_x_start: bool = True
    beta_schedule: str = 'cosine'
    condition_on_text_encodings: bool = True

    class Config:
        extra = "allow"

    def create(self):
        kwargs = self.dict()

        has_clip = exists(kwargs.pop('clip'))
        kwargs.pop('net')

        clip = None
        if has_clip:
            clip = self.clip.create()

        diffusion_prior_network = self.net.create()
        return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)

class DiffusionPriorTrainConfig(BaseModel):
    epochs: int = 1
    lr: float = 1.1e-4
    wd: float = 6.02e-2
    max_grad_norm: float = 0.5
    use_ema: bool = True
    ema_beta: float = 0.99
    amp: bool = False
    save_every: int = 10000  # what steps to save on

class DiffusionPriorDataConfig(BaseModel):
    image_url: str  # path to embeddings folder
    meta_url: str  # path to metadata (captions) for images
    splits: TrainSplitConfig
    batch_size: int = 64

class DiffusionPriorLoadConfig(BaseModel):
    source: str = None
    resume: bool = False

class TrainDiffusionPriorConfig(BaseModel):
    prior: DiffusionPriorConfig
    data: DiffusionPriorDataConfig
    train: DiffusionPriorTrainConfig
    load: DiffusionPriorLoadConfig
    tracker: TrackerConfig

    @classmethod
    def from_json_path(cls, json_path):
        with open(json_path) as f:
            config = json.load(f)
        return cls(**config)

# decoder pydantic classes

class UnetConfig(BaseModel):
    dim: int
    dim_mults: List[int]
    dim_mults: ListOrTuple(int)
    image_embed_dim: int = None
    cond_dim: int = None
    channels: int = 3
@@ -22,13 +167,22 @@ class UnetConfig(BaseModel):
        extra = "allow"

class DecoderConfig(BaseModel):
    unets: ListOrTuple(UnetConfig)
    image_size: int = None
    image_sizes: Union[List[int], Tuple[int]] = None
    image_sizes: ListOrTuple(int) = None
    channels: int = 3
    timesteps: int = 1000
    loss_type: str = 'l2'
    beta_schedule: str = 'cosine'
    learned_variance: bool = True
    image_cond_drop_prob: float = 0.1
    text_cond_drop_prob: float = 0.5

    def create(self):
        decoder_kwargs = self.dict()
        unet_configs = decoder_kwargs.pop('unets')
        unets = [Unet(**config) for config in unet_configs]
        return Decoder(unets, **decoder_kwargs)

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
@@ -39,17 +193,6 @@ class DecoderConfig(BaseModel):
    class Config:
        extra = "allow"

class TrainSplitConfig(BaseModel):
    train: float = 0.75
    val: float = 0.15
    test: float = 0.1

    @root_validator
    def validate_all(cls, fields):
        if sum([*fields.values()]) != 1.:
            raise ValueError(f'{fields.keys()} must sum to 1.0')
        return fields

class DecoderDataConfig(BaseModel):
    webdataset_base_url: str  # path to a webdataset with jpg images
    embeddings_url: str  # path to .npy files with embeddings
@@ -64,23 +207,39 @@ class DecoderDataConfig(BaseModel):
    resample_train: bool = False
    preprocessing: Dict[str, Any] = {'ToTensor': True}

    @property
    def img_preproc(self):
        def _get_transformation(transformation_name, **kwargs):
            if transformation_name == "RandomResizedCrop":
                return T.RandomResizedCrop(**kwargs)
            elif transformation_name == "RandomHorizontalFlip":
                return T.RandomHorizontalFlip()
            elif transformation_name == "ToTensor":
                return T.ToTensor()

        transforms = []
        for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
            transforms.append(_get_transformation(transform_name, **transform_kwargs))
        return T.Compose(transforms)

class DecoderTrainConfig(BaseModel):
    epochs: int = 20
    lr: float = 1e-4
    wd: float = 0.01
    max_grad_norm: float = 0.5
    lr: SingularOrIterable(float) = 1e-4
    wd: SingularOrIterable(float) = 0.01
    max_grad_norm: SingularOrIterable(float) = 0.5
    save_every_n_samples: int = 100000
    n_sample_images: int = 6  # The number of example images to produce when sampling the train and test dataset
    n_sample_images: int = 6  # The number of example images to produce when sampling the train and test dataset
    device: str = 'cuda:0'
    epoch_samples: int = None  # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
    validation_samples: int = None  # Same as above but for validation.
    epoch_samples: int = None  # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
    validation_samples: int = None  # Same as above but for validation.
    use_ema: bool = True
    ema_beta: float = 0.99
    ema_beta: float = 0.999
    amp: bool = False
    save_all: bool = False  # Whether to preserve all checkpoints
    save_latest: bool = True  # Whether to always save the latest checkpoint
    save_best: bool = True  # Whether to save the best checkpoint
    unet_training_mask: List[bool] = None  # If None, use all unets
    save_all: bool = False  # Whether to preserve all checkpoints
    save_latest: bool = True  # Whether to always save the latest checkpoint
    save_best: bool = True  # Whether to save the best checkpoint
    unet_training_mask: ListOrTuple(bool) = None  # If None, use all unets

class DecoderEvaluateConfig(BaseModel):
    n_evaluation_samples: int = 1000
@@ -89,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
    KID: Dict[str, Any] = None
    LPIPS: Dict[str, Any] = None

class TrackerConfig(BaseModel):
    tracker_type: str = 'console'  # Decoder currently supports console and wandb
    data_path: str = './models'  # The path where files will be saved locally
    init_config: Dict[str, Any] = None
    wandb_entity: str = ''  # Only needs to be set if tracker_type is wandb
    wandb_project: str = ''
    verbose: bool = False  # Whether to print console logging for non-console trackers

class DecoderLoadConfig(BaseModel):
    source: str = None  # Supports file and wandb
    run_path: str = ''  # Used only if source is wandb
@@ -104,7 +255,6 @@ class DecoderLoadConfig(BaseModel):
    resume: bool = False  # If using wandb, whether to resume the run

class TrainDecoderConfig(BaseModel):
    unets: List[UnetConfig]
    decoder: DecoderConfig
    data: DecoderDataConfig
    train: DecoderTrainConfig
@@ -117,19 +267,3 @@ class TrainDecoderConfig(BaseModel):
        with open(json_path) as f:
            config = json.load(f)
        return cls(**config)

    @property
    def img_preproc(self):
        def _get_transformation(transformation_name, **kwargs):
            if transformation_name == "RandomResizedCrop":
                return T.RandomResizedCrop(**kwargs)
            elif transformation_name == "RandomHorizontalFlip":
                return T.RandomHorizontalFlip()
            elif transformation_name == "ToTensor":
                return T.ToTensor()

        transforms = []
        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
            transforms.append(_get_transformation(transform_name, **transform_kwargs))
        return T.Compose(transforms)
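The config classes above are intended to be loaded from a JSON file and then turned into live modules through their `create()` methods. A minimal sketch, assuming a decoder config JSON at a placeholder path:

```python
# minimal sketch, assuming a config file at a hypothetical path
from dalle2_pytorch.train_configs import TrainDecoderConfig

config = TrainDecoderConfig.from_json_path('configs/train_decoder_config.json')

decoder = config.decoder.create()        # builds the Unets and wraps them in a Decoder
img_preproc = config.data.img_preproc    # torchvision pipeline built from data.preprocessing
print(config.train.lr, config.data.splits.train)
```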
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler

from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version

import numpy as np

@@ -56,9 +58,15 @@ def num_to_groups(num, divisor):
        arr.append(remainder)
    return arr

def get_pkg_version():
    from pkg_resources import get_distribution
    return get_distribution('dalle2_pytorch').version
def clamp(value, min_value = None, max_value = None):
    assert exists(min_value) or exists(max_value)
    if exists(min_value):
        value = max(value, min_value)

    if exists(max_value):
        value = min(value, max_value)

    return value

# decorators

@@ -133,12 +141,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)

# print helpers

def print_ribbon(s, symbol = '=', repeat = 40):
    flank = symbol * repeat
    return f'{flank} {s} {flank}'

# saving and loading functions

# for diffusion prior
@@ -180,12 +182,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
# exponential moving average wrapper

class EMA(nn.Module):
    """
    Implements exponential moving average shadowing for your model.

    Utilizes an inverse decay schedule to manage longer term training runs.
    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.

    @crowsonkb's notes on EMA Warmup:

    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
    """
    def __init__(
        self,
        model,
        beta = 0.9999,
        update_after_step = 1000,
        update_after_step = 10000,
        update_every = 10,
        inv_gamma = 1.0,
        power = 2/3,
        min_value = 0.0,
    ):
        super().__init__()
        self.beta = beta
@@ -193,7 +217,11 @@ class EMA(nn.Module):
        self.ema_model = copy.deepcopy(model)

        self.update_every = update_every
        self.update_after_step = update_after_step // update_every  # only start EMA after this step number, starting at 0
        self.update_after_step = update_after_step

        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value

        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0]))
@@ -203,37 +231,51 @@ class EMA(nn.Module):
        self.ema_model.to(device)

    def copy_params_from_model_to_ema(self):
        self.ema_model.state_dict(self.online_model.state_dict())
        for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
            ma_param.data.copy_(current_param.data)

        for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
            ma_buffer.data.copy_(current_buffer.data)

    def get_current_decay(self):
        epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power

        if epoch <= 0:
            return 0.

        return clamp(value, min_value = self.min_value, max_value = self.beta)

    def update(self):
        step = self.step.item()
        self.step += 1

        if (self.step % self.update_every) != 0:
        if (step % self.update_every) != 0:
            return

        if self.step <= self.update_after_step:
        if step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        if not self.initted:
        if not self.initted.item():
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.Tensor([True]))

        self.update_moving_average(self.ema_model, self.online_model)

    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model):
        def calculate_ema(beta, old, new):
            if not exists(old):
                return new
            return old * beta + (1 - beta) * new
        current_decay = self.get_current_decay()

        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
        for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
            difference = ma_params.data - current_params.data
            difference.mul_(1.0 - current_decay)
            ma_params.sub_(difference)

        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
            ma_buffer.copy_(new_buffer_value)
        for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
            difference = ma_buffer - current_buffer
            difference.mul_(1.0 - current_decay)
            ma_buffer.sub_(difference)

    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
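The warmup schedule described in the EMA docstring can be sanity-checked in isolation: the decay at step t is `1 - (1 + t / inv_gamma) ** -power`, clamped between `min_value` and `beta`. A standalone sketch reproducing the numbers quoted from @crowsonkb's notes:

```python
# standalone sketch of the EMA warmup schedule described in the docstring above
def ema_decay(step, inv_gamma = 1.0, power = 2 / 3, min_value = 0.0, beta = 0.9999):
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max(value, min_value), beta)

print(round(ema_decay(31_600), 4))                 # ~0.999 with power = 2/3
print(round(ema_decay(1_000_000), 4))              # ~0.9999 (clamped at beta)
print(round(ema_decay(10_000, power = 3 / 4), 4))  # ~0.999 with power = 3/4
```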
@@ -260,6 +302,7 @@ class DiffusionPriorTrainer(nn.Module):
        eps = 1e-6,
        max_grad_norm = None,
        amp = False,
        group_wd_params = True,
        **kwargs
    ):
        super().__init__()
@@ -285,6 +328,7 @@ class DiffusionPriorTrainer(nn.Module):
            lr = lr,
            wd = wd,
            eps = eps,
            group_wd_params = group_wd_params,
            **kwargs
        )

@@ -294,7 +338,7 @@ class DiffusionPriorTrainer(nn.Module):

        self.register_buffer('step', torch.tensor([0]))

    def save(self, path, overwrite = True):
    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
        path.parent.mkdir(parents = True, exist_ok = True)
@@ -303,8 +347,9 @@ class DiffusionPriorTrainer(nn.Module):
            scaler = self.scaler.state_dict(),
            optimizer = self.optimizer.state_dict(),
            model = self.diffusion_prior.state_dict(),
            version = get_pkg_version(),
            step = self.step.item()
            version = __version__,
            step = self.step.item(),
            **kwargs
        )

        if self.use_ema:
@@ -318,14 +363,14 @@ class DiffusionPriorTrainer(nn.Module):

        loaded_obj = torch.load(str(path))

        if get_pkg_version() != loaded_obj['version']:
            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
        if version.parse(__version__) != loaded_obj['version']:
            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')

        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

        if only_model:
            return
            return loaded_obj

        self.scaler.load_state_dict(loaded_obj['scaler'])
        self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -334,6 +379,8 @@ class DiffusionPriorTrainer(nn.Module):
            assert 'ema' in loaded_obj
            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)

        return loaded_obj

    def update(self):
        if exists(self.max_grad_norm):
            self.scaler.unscale_(self.optimizer)
@@ -413,6 +460,7 @@ class DecoderTrainer(nn.Module):
        eps = 1e-8,
        max_grad_norm = 0.5,
        amp = False,
        group_wd_params = True,
        **kwargs
    ):
        super().__init__()
@@ -438,6 +486,7 @@ class DecoderTrainer(nn.Module):
                lr = unet_lr,
                wd = unet_wd,
                eps = unet_eps,
                group_wd_params = group_wd_params,
                **kwargs
            )

@@ -455,15 +504,16 @@ class DecoderTrainer(nn.Module):

        self.register_buffer('step', torch.tensor([0.]))

    def save(self, path, overwrite = True):
    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
        path.parent.mkdir(parents = True, exist_ok = True)

        save_obj = dict(
            model = self.decoder.state_dict(),
            version = get_pkg_version(),
            step = self.step.item()
            version = __version__,
            step = self.step.item(),
            **kwargs
        )

        for ind in range(0, self.num_unets):
@@ -484,14 +534,14 @@ class DecoderTrainer(nn.Module):

        loaded_obj = torch.load(str(path))

        if get_pkg_version() != loaded_obj['version']:
            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
        if version.parse(__version__) != loaded_obj['version']:
            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')

        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

        if only_model:
            return
            return loaded_obj

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
@@ -506,6 +556,8 @@ class DecoderTrainer(nn.Module):
            assert 'ema' in loaded_obj
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

        return loaded_obj

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
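With these changes both trainers stamp checkpoints with the package `__version__`, let extra metadata ride along through `**kwargs`, and have `load` return the raw checkpoint dict. A hedged sketch of the round trip, assuming `trainer` is an already constructed `DecoderTrainer` and the checkpoint path is a placeholder:

```python
# minimal sketch of the checkpoint round trip (path and metadata are placeholders)
trainer.save(
    './checkpoints/latest.pth',
    overwrite = True,
    epoch = 3,                    # extra kwargs are stored alongside model/optimizer state
    validation_losses = [0.11]
)

loaded_obj = trainer.load('./checkpoints/latest.pth')  # warns if the saved version differs
print(loaded_obj['version'], loaded_obj['step'], loaded_obj['epoch'])
```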
@@ -1,5 +1,7 @@
import time

# time helpers

class Timer:
    def __init__(self):
        self.reset()
@@ -9,3 +11,19 @@ class Timer:

    def elapsed(self):
        return time.time() - self.last_time

# print helpers

def print_ribbon(s, symbol = '=', repeat = 40):
    flank = symbol * repeat
    return f'{flank} {s} {flank}'

# import helpers

def import_or_print_error(pkg_name, err_str = None):
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError as e:
        if exists(err_str):
            print(err_str)
        exit()
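`import_or_print_error` is the helper that the tracker and tokenizer modules now share: it returns the module if the import succeeds, otherwise it prints the supplied message and exits. A small usage sketch; the wandb project name below is a placeholder:

```python
# minimal sketch of the shared lazy-import helper
from dalle2_pytorch.utils import import_or_print_error

wandb = import_or_print_error(
    'wandb',
    'wandb is required for this tracker, install it with `pip install wandb`'
)
wandb.init(project = 'dalle2-decoder')  # only reached if the import succeeded
```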
1 dalle2_pytorch/version.py Normal file
@@ -0,0 +1 @@
__version__ = '0.9.0'
@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
            return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, str):
    return str.startswith(prefix)
def string_begins_with(prefix, string_input):
    return string_input.startswith(prefix)

def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)
5 setup.py
@@ -1,4 +1,5 @@
from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())

setup(
  name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.4.5',
  version = __version__,
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
@@ -31,6 +32,7 @@ setup(
    'embedding-reader',
    'kornia>=0.5.4',
    'numpy',
    'packaging',
    'pillow',
    'pydantic',
    'resize-right>=0.0.2',
@@ -40,7 +42,6 @@ setup(
    'tqdm',
    'vector-quantize-pytorch',
    'x-clip>=0.4.4',
    'youtokentome',
    'webdataset>=0.2.5',
    'fsspec>=2022.1.0',
    'torchmetrics[image]>=0.8.0'
@@ -1,9 +1,10 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to

import torchvision
import torch
@@ -85,20 +86,6 @@ def create_dataloaders(
        "test_sampling": test_sampling_dataloader
    }

def create_decoder(device, decoder_config, unets_config):
    """Creates a sample decoder"""
    unets = [Unet(**config.dict()) for config in unets_config]

    decoder = Decoder(
        unet=unets,
        **decoder_config.dict()
    )

    decoder.to(device=device)
    return decoder

def get_dataset_keys(dataloader):
    """
    It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
@@ -150,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
    Generates samples and uses torchvision to put them in a side by side grid for easy viewing
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)

    real_image_size = real_images[0].shape[-1]
    generated_image_size = generated_images[0].shape[-1]

    # training images may be larger than the generated one
    if real_image_size > generated_image_size:
        real_images = [resize_image_to(image, generated_image_size) for image in real_images]

    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions

@@ -216,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
    Loads the model with an appropriate method depending on the tracker
    """
    print(print_ribbon(f"Loading model from {recall_source}"))
    state_dict = tracker.recall_state_dict(recall_source, **load_config)
    state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
    trainer.load_state_dict(state_dict["trainer"])
    print("Model loaded")
    return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
@@ -336,7 +331,7 @@ def train(
    sample = 0
    average_loss = 0
    timer = Timer()
    for i, (img, emb, txt) in enumerate(dataloaders["val"]):
    for i, (img, emb, *_) in enumerate(dataloaders["val"]):
        sample += img.shape[0]
        img, emb = send_to_device((img, emb))

@@ -361,7 +356,7 @@ def train(
    # Compute evaluation metrics
    if exists(evaluate_config):
        print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
        evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
        evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
        tracker.log(evaluation, step=step, verbose=True)

    # Generate sample images
@@ -420,7 +415,7 @@ def initialize_training(config):

    dataloaders = create_dataloaders (
        available_shards=all_shards,
        img_preproc = config.img_preproc,
        img_preproc = config.data.img_preproc,
        train_prop = config.data.splits.train,
        val_prop = config.data.splits.val,
        test_prop = config.data.splits.test,
@@ -428,7 +423,7 @@ def initialize_training(config):
        **config.data.dict()
    )

    decoder = create_decoder(device, config.decoder, config.unets)
    decoder = config.decoder.create().to(device = device)
    num_parameters = sum(p.numel() for p in decoder.parameters())
    print(print_ribbon("Loaded Config", repeat=40))
    print(f"Number of parameters: {num_parameters}")
@@ -7,14 +7,12 @@ import torch
import clip
from torch import nn

from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model

from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer

from embedding_reader import EmbeddingReader
from dalle2_pytorch.utils import Timer, print_ribbon

from tqdm import tqdm

@@ -31,7 +29,7 @@ def exists(val):

# functions

def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
    model.eval()

    with torch.no_grad():
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
        total_samples = 0.

        for image_embeddings, text_data in tqdm(dataloader):
            image_embeddings = image_embeddings.to(device)
            text_data = text_data.to(device)

            batches = image_embeddings.shape[0]

@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation

        tracker.log({f'{phase} {loss_type}': avg_loss})

def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
    diffusion_prior.eval()

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    for test_image_embeddings, text_data in tqdm(dataloader):
        test_image_embeddings = test_image_embeddings.to(device)
        text_data = text_data.to(device)

        # we are text conditioned, we produce an embedding from the tokenized text
        if text_conditioned:
@@ -240,7 +242,7 @@ def train(
    # Training loop
    # diffusion prior network

    prior_network = DiffusionPriorNetwork(
    prior_network = DiffusionPriorNetwork(
        dim = image_embed_dim,
        depth = dpn_depth,
        dim_head = dpn_dim_head,
@@ -249,16 +251,16 @@ def train(
        ff_dropout = dropout,
        normformer = dp_normformer
    )

    # Load clip model if text-conditioning
    if dp_condition_on_text_encodings:
        clip_adapter = OpenAIClipAdapter(clip)
    else:
        clip_adapter = None

    # diffusion prior with text embeddings and image embeddings pre-computed

    diffusion_prior = DiffusionPrior(
    diffusion_prior = DiffusionPrior(
        net = prior_network,
        clip = clip_adapter,
        image_embed_dim = image_embed_dim,
@@ -296,28 +298,46 @@ def train(

    # Utilize wrapper to abstract away loader logic
    print_ribbon("Downloading Embeddings")
    loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
                       train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
    reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)

    if dp_condition_on_text_encodings:
        loader_args = dict(**loader_args, meta_url=meta_url)
        reader_args = dict(**reader_args, meta_url=meta_url)
        img_reader = get_reader(**reader_args)
        train_loader, eval_loader, test_loader = make_splits(
            text_conditioned=dp_condition_on_text_encodings,
            batch_size=batch_size,
            num_data_points=num_data_points,
            train_split=train_percent,
            eval_split=val_percent,
            image_reader=img_reader
        )
    else:
        loader_args = dict(**loader_args, txt_url=text_embed_url)

    train_loader, eval_loader, test_loader = make_splits(**loader_args)
        reader_args = dict(**reader_args, txt_url=text_embed_url)
        img_reader, txt_reader = get_reader(**reader_args)
        train_loader, eval_loader, test_loader = make_splits(
            text_conditioned=dp_condition_on_text_encodings,
            batch_size=batch_size,
            num_data_points=num_data_points,
            train_split=train_percent,
            eval_split=val_percent,
            image_reader=img_reader,
            text_reader=txt_reader
        )

    ### Training code ###

    step = 1
    step = 1
    timer = Timer()
    epochs = num_epochs

    for _ in range(epochs):

        for image, text in tqdm(train_loader):

            diffusion_prior.train()

            image = image.to(device)
            text = text.to(device)

            input_args = dict(image_embed=image)
            if dp_condition_on_text_encodings:
                input_args = dict(**input_args, text = text)
@@ -350,9 +370,9 @@ def train(
    # Use NUM_TEST_EMBEDDINGS samples from the test set each time
    # Get embeddings from the most recently saved model
    if(step % REPORT_METRICS_EVERY) == 0:
        report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
        report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
        ### Evaluate model(validation run) ###
        eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
        eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)

    step += 1
    trainer.update()