mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e37072a48c | ||
|
|
41ca896413 | ||
|
|
fe19b508ca | ||
|
|
6651eafa93 | ||
|
|
e6bb75e5ab | ||
|
|
b4c3e5b854 | ||
|
|
b7f9607258 | ||
|
|
2219348a6e | ||
|
|
9eea9b9862 | ||
|
|
5d958713c0 | ||
|
|
0f31980362 | ||
|
|
bee5bf3815 | ||
|
|
350a3d6045 | ||
|
|
1a81670718 | ||
|
|
934c9728dc | ||
|
|
ce4b0107c1 | ||
|
|
64c2f9c4eb | ||
|
|
22cc613278 | ||
|
|
83517849e5 | ||
|
|
708809ed6c | ||
|
|
9cc475f6e7 | ||
|
|
ffd342e9d0 | ||
|
|
f8bfd3493a | ||
|
|
9025345e29 | ||
|
|
8cc278447e | ||
|
|
38cd62010c | ||
|
|
1cc288af39 | ||
|
|
a851168633 | ||
|
|
1ffeecd0ca | ||
|
|
3df899f7a4 | ||
|
|
09534119a1 | ||
|
|
6f8b90d4d7 | ||
|
|
b588286288 | ||
|
|
b693e0be03 | ||
|
|
a0bed30a84 | ||
|
|
387c5bf774 | ||
|
|
a13d2d89c5 | ||
|
|
44d4b1bba9 | ||
|
|
f12a7589c5 | ||
|
|
b8af2210df | ||
|
|
f4fe6c570d | ||
|
|
645e207441 | ||
|
|
00743b3a0b | ||
|
|
01589aff6a | ||
|
|
7ecfd76cc0 | ||
|
|
6161b61c55 | ||
|
|
1ed0f9d80b | ||
|
|
f326a95e26 | ||
|
|
d7a0a2ce4b | ||
|
|
f23fab7ef7 | ||
|
|
857b9fbf1e | ||
|
|
8864fd0aa7 | ||
|
|
72bf159331 | ||
|
|
e5e47cfecb | ||
|
|
fa533962bd | ||
|
|
276abf337b | ||
|
|
ae42d03006 | ||
|
|
4d346e98d9 | ||
|
|
2b1fd1ad2e | ||
|
|
82a2ef37d9 | ||
|
|
5c397c9d66 | ||
|
|
0f4edff214 | ||
|
|
501a8c7c46 | ||
|
|
4e49373fc5 | ||
|
|
49de72040c | ||
|
|
271a376eaf | ||
|
|
e527002472 |
71
README.md
71
README.md
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
||||
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
|
||||
|
||||
## Status
|
||||
|
||||
@@ -24,6 +24,28 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
|
||||
|
||||
## Pre-Trained Models
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||
- DALL-E 2 🚧
|
||||
|
||||
## Appreciation
|
||||
|
||||
This library would not have gotten to this working state without the help of
|
||||
|
||||
- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
|
||||
- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
|
||||
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -936,7 +958,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
@@ -1034,18 +1056,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||
|
||||
## Appreciation
|
||||
|
||||
This library would not have gotten to this working state without the help of
|
||||
|
||||
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
|
||||
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
## Todo
|
||||
|
||||
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
||||
@@ -1077,21 +1087,21 @@ This library would not have gotten to this working state without the help of
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [x] use pydantic for config drive training
|
||||
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||
|
||||
## Citations
|
||||
@@ -1135,8 +1145,9 @@ This library would not have gotten to this working state without the help of
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022},
|
||||
url = {https://arxiv.org/abs/2204.01697}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1190,4 +1201,22 @@ This library would not have gotten to this working state without the help of
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Saharia2022,
|
||||
title = {Imagen: unprecedented photorealism × deep level of language understanding},
|
||||
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Choi2022PerceptionPT,
|
||||
title = {Perception Prioritized Training of Diffusion Models},
|
||||
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2204.00227}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
|
||||
|
||||
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
||||
|
||||
**<ins>Unets</ins>:**
|
||||
**<ins>Unet</ins>:**
|
||||
|
||||
This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`
|
||||
|
||||
Each member of this array defines a single unet that will be added to the decoder.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `dim` | Yes | N/A | The starting channels of the unet. |
|
||||
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
|
||||
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `unets` | Yes | N/A | A list of unets, using the configuration above |
|
||||
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
||||
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
||||
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
||||
@@ -81,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
|
||||
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
{
|
||||
"unets": [
|
||||
{
|
||||
"dim": 128,
|
||||
"image_embed_dim": 768,
|
||||
"cond_dim": 64,
|
||||
"channels": 3,
|
||||
"dim_mults": [1, 2, 4, 8],
|
||||
"attn_dim_head": 32,
|
||||
"attn_heads": 16
|
||||
}
|
||||
],
|
||||
"decoder": {
|
||||
"unets": [
|
||||
{
|
||||
"dim": 128,
|
||||
"image_embed_dim": 768,
|
||||
"cond_dim": 64,
|
||||
"channels": 3,
|
||||
"dim_mults": [1, 2, 4, 8],
|
||||
"attn_dim_head": 32,
|
||||
"attn_heads": 16
|
||||
}
|
||||
],
|
||||
"image_sizes": [64],
|
||||
"channels": 3,
|
||||
"timesteps": 1000,
|
||||
|
||||
70
configs/train_prior_config.example.json
Normal file
70
configs/train_prior_config.example.json
Normal file
@@ -0,0 +1,70 @@
|
||||
{
|
||||
"prior": {
|
||||
"clip": {
|
||||
"make": "x-clip",
|
||||
"model": "ViT-L/14",
|
||||
"base_model_kwargs": {
|
||||
"dim_text": 768,
|
||||
"dim_image": 768,
|
||||
"dim_latent": 768
|
||||
}
|
||||
},
|
||||
"net": {
|
||||
"dim": 768,
|
||||
"depth": 12,
|
||||
"num_timesteps": 1000,
|
||||
"num_time_embeds": 1,
|
||||
"num_image_embeds": 1,
|
||||
"num_text_embeds": 1,
|
||||
"dim_head": 64,
|
||||
"heads": 12,
|
||||
"ff_mult": 4,
|
||||
"norm_out": true,
|
||||
"attn_dropout": 0.0,
|
||||
"ff_dropout": 0.0,
|
||||
"final_proj": true,
|
||||
"normformer": true,
|
||||
"rotary_emb": true
|
||||
},
|
||||
"image_embed_dim": 768,
|
||||
"image_size": 224,
|
||||
"image_channels": 3,
|
||||
"timesteps": 1000,
|
||||
"cond_drop_prob": 0.1,
|
||||
"loss_type": "l2",
|
||||
"predict_x_start": true,
|
||||
"beta_schedule": "cosine",
|
||||
"condition_on_text_encodings": true
|
||||
},
|
||||
"data": {
|
||||
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
||||
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
||||
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
||||
"batch_size": 256,
|
||||
"splits": {
|
||||
"train": 0.9,
|
||||
"val": 1e-7,
|
||||
"test": 0.0999999
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 1,
|
||||
"lr": 1.1e-4,
|
||||
"wd": 6.02e-2,
|
||||
"max_grad_norm": 0.5,
|
||||
"use_ema": true,
|
||||
"amp": false,
|
||||
"save_every": 10000
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
"resume": false
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "wandb",
|
||||
"data_path": "./prior_checkpoints",
|
||||
"wandb_entity": "laion",
|
||||
"wandb_project": "diffusion-prior",
|
||||
"verbose": true
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
from dalle2_pytorch.version import __version__
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import math
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
import torchvision.transforms as T
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
@@ -56,7 +56,7 @@ def maybe(fn):
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
return d() if callable(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
if isinstance(val, list):
|
||||
@@ -313,11 +313,6 @@ def extract(a, t, x_shape):
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
def meanflat(x):
|
||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||
|
||||
@@ -372,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
||||
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
||||
|
||||
|
||||
def sigmoid_beta_schedule(timesteps):
|
||||
@@ -384,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
|
||||
|
||||
|
||||
class BaseGaussianDiffusion(nn.Module):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
||||
super().__init__()
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -449,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
# p2 loss reweighting
|
||||
|
||||
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
||||
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
)
|
||||
|
||||
if exists(clip):
|
||||
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -943,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
@@ -1082,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def Upsample(dim):
|
||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
|
||||
def Downsample(dim):
|
||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||
def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
@@ -1105,13 +1108,20 @@ class Block(nn.Module):
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x, scale_shift = None):
|
||||
x = self.project(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if exists(scale_shift):
|
||||
scale, shift = scale_shift
|
||||
x = x * (scale + 1) + shift
|
||||
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -1130,7 +1140,7 @@ class ResnetBlock(nn.Module):
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
nn.Linear(time_cond_dim, dim_out * 2)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
@@ -1150,11 +1160,14 @@ class ResnetBlock(nn.Module):
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||
scale_shift = time_emb.chunk(2, dim = 1)
|
||||
|
||||
h = self.block1(x, scale_shift = scale_shift)
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
@@ -1331,12 +1344,15 @@ class Unet(nn.Module):
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
num_resnet_blocks = 2,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1356,7 +1372,7 @@ class Unet(nn.Module):
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
init_dim = default(init_dim, dim)
|
||||
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
@@ -1383,11 +1399,16 @@ class Unet(nn.Module):
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
self.image_to_tokens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.to_image_hiddens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, time_cond_dim),
|
||||
nn.GELU()
|
||||
) if cond_on_image_embeds and add_image_embeds_to_time else None
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
@@ -1408,6 +1429,7 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
||||
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
@@ -1419,6 +1441,7 @@ class Unet(nn.Module):
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
@@ -1434,16 +1457,17 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
@@ -1452,19 +1476,19 @@ class Unet(nn.Module):
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Upsample(dim_in)
|
||||
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
]))
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
@@ -1536,6 +1560,7 @@ class Unet(nn.Module):
|
||||
# initial convolution
|
||||
|
||||
x = self.init_conv(x)
|
||||
r = x.clone() # final residual
|
||||
|
||||
# time conditioning
|
||||
|
||||
@@ -1549,7 +1574,23 @@ class Unet(nn.Module):
|
||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||
|
||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
||||
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||
|
||||
# image embedding to be summed to time embedding
|
||||
# discovered by @mhh0318 in the paper
|
||||
|
||||
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||
image_hiddens = self.to_image_hiddens(image_embed)
|
||||
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
||||
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
||||
|
||||
image_hiddens = torch.where(
|
||||
image_keep_mask_hidden,
|
||||
image_hiddens,
|
||||
null_image_hiddens
|
||||
)
|
||||
|
||||
t = t + image_hiddens
|
||||
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
@@ -1557,11 +1598,12 @@ class Unet(nn.Module):
|
||||
image_tokens = None
|
||||
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||
image_tokens = self.image_to_tokens(image_embed)
|
||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
image_tokens = torch.where(
|
||||
image_keep_mask,
|
||||
image_keep_mask_embed,
|
||||
image_tokens,
|
||||
null_image_embed
|
||||
)
|
||||
@@ -1616,12 +1658,20 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||
if exists(pre_downsample):
|
||||
x = pre_downsample(x)
|
||||
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
if exists(post_downsample):
|
||||
x = post_downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
|
||||
@@ -1630,20 +1680,24 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = block1(x, c, t)
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
x = upsample(x)
|
||||
|
||||
x = torch.cat((x, r), dim = 1)
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_sigma = (0.1, 0.2),
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1667,6 +1721,18 @@ class LowresConditioner(nn.Module):
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
|
||||
# allow for drawing a random sigma between lo and hi float values
|
||||
if isinstance(blur_sigma, tuple):
|
||||
blur_sigma = tuple(map(float, blur_sigma))
|
||||
blur_sigma = random.uniform(*blur_sigma)
|
||||
|
||||
# allow for drawing a random kernel size between lo and hi int values
|
||||
if isinstance(blur_kernel_size, tuple):
|
||||
blur_kernel_size = tuple(map(int, blur_kernel_size))
|
||||
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
||||
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
||||
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||
@@ -1692,30 +1758,44 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
learned_variance_constrain_frac = False,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False,
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||
p2_loss_weight_k = 1
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
loss_type = loss_type,
|
||||
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
||||
p2_loss_weight_k = p2_loss_weight_k
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1725,13 +1805,20 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
else:
|
||||
self.clip_image_size = image_size
|
||||
self.channels = channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
# determine image size, with image_size and image_sizes taking precedence
|
||||
|
||||
if exists(image_size) or exists(image_sizes):
|
||||
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||
image_size = default(image_size, lambda: image_sizes[-1])
|
||||
elif exists(clip):
|
||||
image_size = clip.image_size
|
||||
else:
|
||||
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||
|
||||
# channels
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
@@ -1743,6 +1830,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||
self.learned_variance = learned_variance
|
||||
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# construct unets and vaes
|
||||
@@ -1773,7 +1861,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
||||
image_sizes = default(image_sizes, (image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
@@ -1810,7 +1898,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.clip_denoised = clip_denoised
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
# dynamic thresholding settings, if clipping denoised during sampling
|
||||
|
||||
self.use_dynamic_thres = use_dynamic_thres
|
||||
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
@@ -1851,7 +1945,21 @@ class Decoder(BaseGaussianDiffusion):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
# s is the threshold amount
|
||||
# static thresholding would just be s = 1
|
||||
s = 1.
|
||||
if self.use_dynamic_thres:
|
||||
s = torch.quantile(
|
||||
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
||||
self.dynamic_thres_percentile,
|
||||
dim = -1
|
||||
)
|
||||
|
||||
s.clamp_(min = 1.)
|
||||
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
||||
|
||||
# clip by threshold, depending on whether static or dynamic
|
||||
x_recon = x_recon.clamp(-s, s) / s
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
|
||||
@@ -1863,16 +1971,19 @@ class Decoder(BaseGaussianDiffusion):
|
||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||
|
||||
if self.learned_variance_constrain_frac:
|
||||
var_interp_frac = var_interp_frac.sigmoid()
|
||||
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
@@ -1936,7 +2047,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
loss = self.loss_fn(pred, target, reduction = 'none')
|
||||
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||
|
||||
if self.has_p2_loss_reweighting:
|
||||
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
if not learned_variance:
|
||||
# return simple loss if not using learned variance
|
||||
|
||||
@@ -4,7 +4,7 @@ In order to make loading data simple and efficient, we include some general data
|
||||
### Decoder: Image Embedding Dataset
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
|
||||
)
|
||||
```
|
||||
|
||||
### Diffusion Prior: Prior Embedding Dataset
|
||||
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
|
||||
|
||||
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
|
||||
|
||||
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
|
||||
# grab embeddings from some specified location
|
||||
IMG_URL = "data/img_emb/"
|
||||
META_URL = "data/meta/"
|
||||
|
||||
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
|
||||
|
||||
# some config for training
|
||||
TRAIN_ARGS = {
|
||||
"world_size": 3,
|
||||
"text_conditioned": True,
|
||||
"start": 0,
|
||||
"num_data_points": 10000,
|
||||
"batch_size": 2,
|
||||
"train_split": 0.5,
|
||||
"eval_split": 0.25,
|
||||
"image_reader": reader,
|
||||
}
|
||||
|
||||
# specifying a rank will handle allocation internally
|
||||
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
|
||||
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
|
||||
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
|
||||
```
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
|
||||
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch import from_numpy
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
|
||||
class PriorEmbeddingLoader(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
device: str = "cpu",
|
||||
) -> None:
|
||||
super(PriorEmbeddingLoader).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.batch_size = batch_size
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
self.n += 1
|
||||
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
tokenized_caption = tokenize(
|
||||
caption["caption"].to_list(), truncate=True
|
||||
).to(self.device)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
device: str,
|
||||
img_url: str,
|
||||
meta_url: str = None,
|
||||
txt_url: str = None,
|
||||
):
|
||||
|
||||
assert img_url is not None, "Must supply some image embeddings"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply metadata url if text-conditioning"
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from math import ceil
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
from torch import from_numpy
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
|
||||
|
||||
class PriorEmbeddingDataset(IterableDataset):
|
||||
"""
|
||||
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
|
||||
|
||||
It enables one to simplify the logic necessary to yield samples from
|
||||
the different EmbeddingReader configurations available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
) -> None:
|
||||
super(PriorEmbeddingDataset).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.stop - self.start
|
||||
|
||||
def __iter__(self):
|
||||
# D.R.Y loader args
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
# if the data requested is text conditioned, only load images
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
# otherwise, include text embeddings and bypass metadata
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
|
||||
# return the data loader in its formatted state
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def __str__(self):
|
||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
text_embedding = from_numpy(text_embedding)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def distribute_to_rank(start, stop, rank, world_size):
|
||||
"""
|
||||
Distribute data to each rank given the world size.
|
||||
|
||||
Return:
|
||||
- New start and stop points for this rank.
|
||||
"""
|
||||
num_samples = int(stop - start)
|
||||
|
||||
per_rank = int(ceil((num_samples) / float(world_size)))
|
||||
|
||||
assert (
|
||||
per_rank > 0
|
||||
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
|
||||
|
||||
rank_start = start + rank * per_rank
|
||||
|
||||
rank_stop = min(rank_start + per_rank, stop)
|
||||
|
||||
new_length = rank_stop - rank_start
|
||||
|
||||
assert (
|
||||
new_length > 0
|
||||
), "Calculated start and stop points result in a length of zero for this rank."
|
||||
|
||||
return rank_start, rank_stop
|
||||
|
||||
|
||||
def get_reader(
|
||||
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
|
||||
):
|
||||
"""
|
||||
Create an EmbeddingReader object from the specified URLs
|
||||
|
||||
get_reader() will always expect a url to image embeddings.
|
||||
|
||||
If text-conditioned, it will also expect a meta_url for the captions.
|
||||
Otherwise, it will need txt_url for the matching text embeddings.
|
||||
|
||||
Returns an image_reader object if text-conditioned.
|
||||
Otherwise it returns both an image_reader and a text_reader
|
||||
"""
|
||||
|
||||
assert img_url is not None, "Must supply a image url"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply meta url if text-conditioned"
|
||||
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
# will assume the caption column exists and is the only one requested
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
return image_reader
|
||||
|
||||
# otherwise we will require text embeddings as well and return two readers
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
return image_reader, text_reader
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
image_reader: EmbeddingReader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
start=0,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
):
|
||||
"""
|
||||
Split an embedding reader object as needed.
|
||||
|
||||
NOTE: make_splits() will infer the test set size from your train and eval.
|
||||
|
||||
Input:
|
||||
- text_conditioned: whether to prepare text-conditioned training data
|
||||
- batch_size: the batch size for a single gpu
|
||||
- num_data_points: the total number of data points you wish to train on
|
||||
- train_split: the percentage of data you wish to train on
|
||||
- eval_split: the percentage of data you wish to validate on
|
||||
- image_reader: the image_reader you wish to split
|
||||
- text_reader: the text_reader you want to split (if !text_conditioned)
|
||||
- start: the starting point within your dataset
|
||||
- rank: the rank of your worker
|
||||
- world_size: the total world size of your distributed training run
|
||||
|
||||
Returns:
|
||||
- PyTorch Dataloaders that yield tuples of (img, txt) data.
|
||||
"""
|
||||
|
||||
assert start < image_reader.count, "start position cannot exceed reader count."
|
||||
|
||||
# verify that the num_data_points does not exceed the max points
|
||||
if num_data_points > (image_reader.count - start):
|
||||
print(
|
||||
"Specified count is larger than what's available...defaulting to reader's count."
|
||||
)
|
||||
num_data_points = image_reader.count
|
||||
|
||||
# compute split points
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_start = train_set_size
|
||||
eval_stop = int(eval_start + eval_set_size)
|
||||
|
||||
assert (
|
||||
train_split + eval_split
|
||||
) < 1.0, "Specified train and eval split is too large to infer a test split."
|
||||
|
||||
# distribute to rank
|
||||
rank_train_start, rank_train_stop = distribute_to_rank(
|
||||
start, train_set_size, rank, world_size
|
||||
)
|
||||
rank_eval_start, rank_eval_stop = distribute_to_rank(
|
||||
train_set_size, eval_stop, rank, world_size
|
||||
)
|
||||
rank_test_start, rank_test_stop = distribute_to_rank(
|
||||
eval_stop, num_data_points, rank, world_size
|
||||
)
|
||||
|
||||
# wrap up splits into a dict
|
||||
train_split_args = dict(
|
||||
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
|
||||
)
|
||||
eval_split_args = dict(
|
||||
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
|
||||
)
|
||||
test_split_args = dict(
|
||||
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
|
||||
)
|
||||
|
||||
if text_conditioned:
|
||||
# add the text-conditioned args to a unified dict
|
||||
reader_args = dict(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
)
|
||||
|
||||
train_split_args = dict(**reader_args, **train_split_args)
|
||||
eval_split_args = dict(**reader_args, **eval_split_args)
|
||||
test_split_args = dict(**reader_args, **test_split_args)
|
||||
|
||||
train = PriorEmbeddingDataset(**train_split_args)
|
||||
val = PriorEmbeddingDataset(**eval_split_args)
|
||||
test = PriorEmbeddingDataset(**test_split_args)
|
||||
|
||||
else:
|
||||
# add the non-conditioned args to a unified dict
|
||||
reader_args = dict(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
)
|
||||
|
||||
train_split_args = dict(**reader_args, **train_split_args)
|
||||
eval_split_args = dict(**reader_args, **eval_split_args)
|
||||
test_split_args = dict(**reader_args, **test_split_args)
|
||||
|
||||
train = PriorEmbeddingDataset(**train_split_args)
|
||||
val = PriorEmbeddingDataset(**eval_split_args)
|
||||
test = PriorEmbeddingDataset(**test_split_args)
|
||||
|
||||
# true batch size is specifed in the PriorEmbeddingDataset
|
||||
train_loader = DataLoader(train, batch_size=None)
|
||||
eval_loader = DataLoader(val, batch_size=None)
|
||||
test_loader = DataLoader(test, batch_size=None)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
@@ -1,17 +1,20 @@
|
||||
from torch.optim import AdamW, Adam
|
||||
|
||||
def separate_weight_decayable_params(params):
|
||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
||||
wd_params = set(params) - no_wd_params
|
||||
wd_params, no_wd_params = [], []
|
||||
for param in params:
|
||||
param_list = no_wd_params if param.ndim < 2 else wd_params
|
||||
param_list.append(param)
|
||||
return wd_params, no_wd_params
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
betas = (0.9, 0.99),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False,
|
||||
group_wd_params = True,
|
||||
**kwargs
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
@@ -20,12 +23,12 @@ def get_optimizer(
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
if group_wd_params:
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
|
||||
param_groups = [
|
||||
{'params': list(wd_params)},
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
params = [
|
||||
{'params': wd_params},
|
||||
{'params': no_wd_params, 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# to give users a quick easy start to training DALL-E without doing BPE
|
||||
|
||||
import torch
|
||||
import youtokentome as yttm
|
||||
|
||||
import html
|
||||
import os
|
||||
@@ -11,6 +10,8 @@ import regex as re
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
|
||||
# OpenAI simple tokenizer
|
||||
|
||||
@lru_cache()
|
||||
@@ -156,7 +157,9 @@ class YttmTokenizer:
|
||||
bpe_path = Path(bpe_path)
|
||||
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
||||
|
||||
tokenizer = yttm.BPE(model = str(bpe_path))
|
||||
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
|
||||
|
||||
tokenizer = self.yttm.BPE(model = str(bpe_path))
|
||||
self.tokenizer = tokenizer
|
||||
self.vocab_size = tokenizer.vocab_size()
|
||||
|
||||
@@ -167,7 +170,7 @@ class YttmTokenizer:
|
||||
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
||||
|
||||
def encode(self, texts):
|
||||
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
|
||||
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
|
||||
return list(map(torch.tensor, encoded))
|
||||
|
||||
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
||||
|
||||
@@ -6,6 +6,8 @@ from itertools import zip_longest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DATA_PATH = './.tracker-data'
|
||||
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def import_or_print_error(pkg_name, err_str = None):
|
||||
try:
|
||||
return importlib.import_module(pkg_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if exists(err_str):
|
||||
print(err_str)
|
||||
exit()
|
||||
|
||||
# load state dict functions
|
||||
|
||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||
|
||||
@@ -3,15 +3,160 @@ from torchvision import transforms as T
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||
|
||||
from x_clip import CLIP as XCLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import (
|
||||
CoCaAdapter,
|
||||
OpenAIClipAdapter,
|
||||
Unet,
|
||||
Decoder,
|
||||
DiffusionPrior,
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter,
|
||||
)
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def ListOrTuple(inner_type):
|
||||
return Union[List[inner_type], Tuple[inner_type]]
|
||||
|
||||
def SingularOrIterable(inner_type):
|
||||
return Union[inner_type, ListOrTuple(inner_type)]
|
||||
|
||||
# general pydantic classes
|
||||
|
||||
class TrainSplitConfig(BaseModel):
|
||||
train: float = 0.75
|
||||
val: float = 0.15
|
||||
test: float = 0.1
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, fields):
|
||||
actual_sum = sum([*fields.values()])
|
||||
if actual_sum != 1.:
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
|
||||
return fields
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
|
||||
# diffusion prior pydantic classes
|
||||
|
||||
class AdapterConfig(BaseModel):
|
||||
make: str = "openai"
|
||||
model: str = "ViT-L/14"
|
||||
base_model_kwargs: Dict[str, Any] = None
|
||||
|
||||
def create(self):
|
||||
if self.make == "openai":
|
||||
return OpenAIClipAdapter(self.model)
|
||||
elif self.make == "x-clip":
|
||||
return XClipAdapter(XCLIP(**self.base_model_kwargs))
|
||||
elif self.make == "coca":
|
||||
return CoCaAdapter(CoCa(**self.base_model_kwargs))
|
||||
else:
|
||||
raise AttributeError("No adapter with that name is available.")
|
||||
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
num_text_embeds: int = 1
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
final_proj: bool = True
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
|
||||
class DiffusionPriorConfig(BaseModel):
|
||||
clip: AdapterConfig = None
|
||||
net: DiffusionPriorNetworkConfig
|
||||
image_embed_dim: int
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
beta_schedule: str = 'cosine'
|
||||
condition_on_text_encodings: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
|
||||
has_clip = exists(kwargs.pop('clip'))
|
||||
kwargs.pop('net')
|
||||
|
||||
clip = None
|
||||
if has_clip:
|
||||
clip = self.clip.create()
|
||||
|
||||
diffusion_prior_network = self.net.create()
|
||||
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
|
||||
|
||||
class DiffusionPriorTrainConfig(BaseModel):
|
||||
epochs: int = 1
|
||||
lr: float = 1.1e-4
|
||||
wd: float = 6.02e-2
|
||||
max_grad_norm: float = 0.5
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_every: int = 10000 # what steps to save on
|
||||
|
||||
class DiffusionPriorDataConfig(BaseModel):
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig
|
||||
batch_size: int = 64
|
||||
|
||||
class DiffusionPriorLoadConfig(BaseModel):
|
||||
source: str = None
|
||||
resume: bool = False
|
||||
|
||||
class TrainDiffusionPriorConfig(BaseModel):
|
||||
prior: DiffusionPriorConfig
|
||||
data: DiffusionPriorDataConfig
|
||||
train: DiffusionPriorTrainConfig
|
||||
load: DiffusionPriorLoadConfig
|
||||
tracker: TrackerConfig
|
||||
|
||||
@classmethod
|
||||
def from_json_path(cls, json_path):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
# decoder pydantic classes
|
||||
|
||||
class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: List[int]
|
||||
dim_mults: ListOrTuple(int)
|
||||
image_embed_dim: int = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
@@ -22,13 +167,22 @@ class UnetConfig(BaseModel):
|
||||
extra = "allow"
|
||||
|
||||
class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
image_size: int = None
|
||||
image_sizes: Union[List[int], Tuple[int]] = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: str = 'cosine'
|
||||
learned_variance: bool = True
|
||||
image_cond_drop_prob: float = 0.1
|
||||
text_cond_drop_prob: float = 0.5
|
||||
|
||||
def create(self):
|
||||
decoder_kwargs = self.dict()
|
||||
unet_configs = decoder_kwargs.pop('unets')
|
||||
unets = [Unet(**config) for config in unet_configs]
|
||||
return Decoder(unets, **decoder_kwargs)
|
||||
|
||||
@validator('image_sizes')
|
||||
def check_image_sizes(cls, image_sizes, values):
|
||||
@@ -39,17 +193,6 @@ class DecoderConfig(BaseModel):
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
class TrainSplitConfig(BaseModel):
|
||||
train: float = 0.75
|
||||
val: float = 0.15
|
||||
test: float = 0.1
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, fields):
|
||||
if sum([*fields.values()]) != 1.:
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0')
|
||||
return fields
|
||||
|
||||
class DecoderDataConfig(BaseModel):
|
||||
webdataset_base_url: str # path to a webdataset with jpg images
|
||||
embeddings_url: str # path to .npy files with embeddings
|
||||
@@ -64,23 +207,39 @@ class DecoderDataConfig(BaseModel):
|
||||
resample_train: bool = False
|
||||
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
||||
|
||||
@property
|
||||
def img_preproc(self):
|
||||
def _get_transformation(transformation_name, **kwargs):
|
||||
if transformation_name == "RandomResizedCrop":
|
||||
return T.RandomResizedCrop(**kwargs)
|
||||
elif transformation_name == "RandomHorizontalFlip":
|
||||
return T.RandomHorizontalFlip()
|
||||
elif transformation_name == "ToTensor":
|
||||
return T.ToTensor()
|
||||
|
||||
transforms = []
|
||||
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
|
||||
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
||||
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
||||
return T.Compose(transforms)
|
||||
|
||||
class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: float = 1e-4
|
||||
wd: float = 0.01
|
||||
max_grad_norm: float = 0.5
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
device: str = 'cuda:0'
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: List[bool] = None # If None, use all unets
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
n_evaluation_samples: int = 1000
|
||||
@@ -89,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
|
||||
KID: Dict[str, Any] = None
|
||||
LPIPS: Dict[str, Any] = None
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
|
||||
class DecoderLoadConfig(BaseModel):
|
||||
source: str = None # Supports file and wandb
|
||||
run_path: str = '' # Used only if source is wandb
|
||||
@@ -104,7 +255,6 @@ class DecoderLoadConfig(BaseModel):
|
||||
resume: bool = False # If using wandb, whether to resume the run
|
||||
|
||||
class TrainDecoderConfig(BaseModel):
|
||||
unets: List[UnetConfig]
|
||||
decoder: DecoderConfig
|
||||
data: DecoderDataConfig
|
||||
train: DecoderTrainConfig
|
||||
@@ -117,19 +267,3 @@ class TrainDecoderConfig(BaseModel):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
@property
|
||||
def img_preproc(self):
|
||||
def _get_transformation(transformation_name, **kwargs):
|
||||
if transformation_name == "RandomResizedCrop":
|
||||
return T.RandomResizedCrop(**kwargs)
|
||||
elif transformation_name == "RandomHorizontalFlip":
|
||||
return T.RandomHorizontalFlip()
|
||||
elif transformation_name == "ToTensor":
|
||||
return T.ToTensor()
|
||||
|
||||
transforms = []
|
||||
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
|
||||
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
||||
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
||||
return T.Compose(transforms)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from collections.abc import Iterable
|
||||
@@ -10,6 +11,10 @@ from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -55,6 +60,16 @@ def num_to_groups(num, divisor):
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
def clamp(value, min_value = None, max_value = None):
|
||||
assert exists(min_value) or exists(max_value)
|
||||
if exists(min_value):
|
||||
value = max(value, min_value)
|
||||
|
||||
if exists(max_value):
|
||||
value = min(value, max_value)
|
||||
|
||||
return value
|
||||
|
||||
# decorators
|
||||
|
||||
def cast_torch_tensor(fn):
|
||||
@@ -128,12 +143,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
@@ -175,12 +184,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
"""
|
||||
Implements exponential moving average shadowing for your model.
|
||||
|
||||
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||
good values for models you plan to train for a million or more steps (reaches decay
|
||||
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||
215.4k steps).
|
||||
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.9999,
|
||||
update_after_step = 1000,
|
||||
update_after_step = 100,
|
||||
update_every = 10,
|
||||
inv_gamma = 1.0,
|
||||
power = 2/3,
|
||||
min_value = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
@@ -188,51 +219,70 @@ class EMA(nn.Module):
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.update_every = update_every
|
||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
||||
self.update_after_step = update_after_step
|
||||
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
self.min_value = min_value
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def copy_params_from_model_to_ema(self):
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||
ma_param.data.copy_(current_param.data)
|
||||
|
||||
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||
ma_buffer.data.copy_(current_buffer.data)
|
||||
|
||||
def get_current_decay(self):
|
||||
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||
|
||||
if epoch <= 0:
|
||||
return 0.
|
||||
|
||||
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||
|
||||
def update(self):
|
||||
step = self.step.item()
|
||||
self.step += 1
|
||||
|
||||
if (self.step % self.update_every) != 0:
|
||||
if (step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if self.step <= self.update_after_step:
|
||||
if step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
if not self.initted.item():
|
||||
self.copy_params_from_model_to_ema()
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
@torch.no_grad()
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
current_decay = self.get_current_decay()
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||
difference = ma_params.data - current_params.data
|
||||
difference.mul_(1.0 - current_decay)
|
||||
ma_params.sub_(difference)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||
difference = ma_buffer - current_buffer
|
||||
difference.mul_(1.0 - current_decay)
|
||||
ma_buffer.sub_(difference)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@@ -255,44 +305,190 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# assign some helpful member vars
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
|
||||
self.amp = amp
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
|
||||
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
||||
|
||||
self.optimizer = get_optimizer(
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
self.diffusion_prior.parameters(),
|
||||
**self.optim_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# distribute the model if using HFA
|
||||
if exists(self.accelerator):
|
||||
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
|
||||
|
||||
# exponential moving average stuff
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
def unwrap_model(self, model):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.unwrap_model(model)
|
||||
else:
|
||||
return model
|
||||
|
||||
def wait_for_everyone(self):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def is_main_process(self):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.is_main_process
|
||||
else:
|
||||
return True
|
||||
|
||||
def clip_grad_norm_(self, *args):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.clip_grad_norm_(*args)
|
||||
else:
|
||||
return torch.nn.utils.clip_grad_norm_(*args)
|
||||
|
||||
def backprop(self, x):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.backward(x)
|
||||
else:
|
||||
try:
|
||||
x.backward()
|
||||
except Exception as e:
|
||||
self.print(f"Caught error in backprop call: {e}")
|
||||
|
||||
# utility
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
# ensure we sync gradients before continuing
|
||||
self.wait_for_everyone()
|
||||
|
||||
# only save on the main process
|
||||
if self.is_main_process():
|
||||
self.print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
save_obj = dict(
|
||||
scaler = self.scaler.state_dict(),
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
|
||||
version = version.parse(__version__),
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {
|
||||
**save_obj,
|
||||
'ema': self.ema_diffusion_prior.state_dict(),
|
||||
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
|
||||
}
|
||||
|
||||
torch.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, overwrite_lr = True, strict = True):
|
||||
"""
|
||||
Load a checkpoint of a diffusion prior trainer.
|
||||
|
||||
Will load the entire trainer, including the optimizer and EMA.
|
||||
|
||||
Params:
|
||||
- path (str): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
||||
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
||||
|
||||
Returns:
|
||||
loaded_obj (dict): The loaded checkpoint dictionary
|
||||
"""
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
|
||||
if overwrite_lr:
|
||||
new_lr = self.optim_kwargs["lr"]
|
||||
|
||||
self.print(f"Overriding LR to be {new_lr}")
|
||||
|
||||
for group in self.optimizer.param_groups:
|
||||
group["lr"] = new_lr
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
||||
|
||||
# sync and inform
|
||||
self.wait_for_everyone()
|
||||
self.print(f"Loaded model")
|
||||
|
||||
return loaded_obj
|
||||
|
||||
# model functionality
|
||||
|
||||
def update(self):
|
||||
# only continue with updates until all ranks finish
|
||||
self.wait_for_everyone()
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
# utilize HFA clipping where applicable
|
||||
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
@@ -307,17 +503,32 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
@@ -335,8 +546,10 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# backprop with accelerate if applicable
|
||||
|
||||
if self.training:
|
||||
self.scaler.scale(loss).backward()
|
||||
self.backprop(self.scaler.scale(loss))
|
||||
|
||||
return total_loss
|
||||
|
||||
@@ -368,6 +581,7 @@ class DecoderTrainer(nn.Module):
|
||||
eps = 1e-8,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -393,6 +607,7 @@ class DecoderTrainer(nn.Module):
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
group_wd_params = group_wd_params,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -410,6 +625,60 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
save_obj = dict(
|
||||
model = self.decoder.state_dict(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
scaler_key = f'scaler{ind}'
|
||||
optimizer_key = f'scaler{ind}'
|
||||
scaler = getattr(self, scaler_key)
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
|
||||
torch.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path))
|
||||
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
scaler_key = f'scaler{ind}'
|
||||
optimizer_key = f'scaler{ind}'
|
||||
scaler = getattr(self, scaler_key)
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scaler.load_state_dict(loaded_obj[scaler_key])
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import time
|
||||
|
||||
# time helpers
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -9,3 +11,19 @@ class Timer:
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# import helpers
|
||||
|
||||
def import_or_print_error(pkg_name, err_str = None):
|
||||
try:
|
||||
return importlib.import_module(pkg_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if exists(err_str):
|
||||
print(err_str)
|
||||
exit()
|
||||
|
||||
1
dalle2_pytorch/version.py
Normal file
1
dalle2_pytorch/version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = '0.10.0'
|
||||
@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
def string_begins_with(prefix, string_input):
|
||||
return string_input.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
6
setup.py
6
setup.py
@@ -1,4 +1,5 @@
|
||||
from setuptools import setup, find_packages
|
||||
exec(open('dalle2_pytorch/version.py').read())
|
||||
|
||||
setup(
|
||||
name = 'dalle2-pytorch',
|
||||
@@ -10,7 +11,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.4.2',
|
||||
version = __version__,
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -23,6 +24,7 @@ setup(
|
||||
'text to image'
|
||||
],
|
||||
install_requires=[
|
||||
'accelerate',
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
@@ -31,6 +33,7 @@ setup(
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'numpy',
|
||||
'packaging',
|
||||
'pillow',
|
||||
'pydantic',
|
||||
'resize-right>=0.0.2',
|
||||
@@ -40,7 +43,6 @@ setup(
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.4.4',
|
||||
'youtokentome',
|
||||
'webdataset>=0.2.5',
|
||||
'fsspec>=2022.1.0',
|
||||
'torchmetrics[image]>=0.8.0'
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from dalle2_pytorch import Unet, Decoder
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||
from dalle2_pytorch.trainer import DecoderTrainer
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
@@ -85,20 +86,6 @@ def create_dataloaders(
|
||||
"test_sampling": test_sampling_dataloader
|
||||
}
|
||||
|
||||
|
||||
def create_decoder(device, decoder_config, unets_config):
|
||||
"""Creates a sample decoder"""
|
||||
|
||||
unets = [Unet(**config.dict()) for config in unets_config]
|
||||
|
||||
decoder = Decoder(
|
||||
unet=unets,
|
||||
**decoder_config.dict()
|
||||
)
|
||||
|
||||
decoder.to(device=device)
|
||||
return decoder
|
||||
|
||||
def get_dataset_keys(dataloader):
|
||||
"""
|
||||
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||
@@ -150,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||
|
||||
real_image_size = real_images[0].shape[-1]
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
|
||||
# training images may be larger than the generated one
|
||||
if real_image_size > generated_image_size:
|
||||
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
@@ -216,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
state_dict = tracker.recall_state_dict(recall_source, **load_config)
|
||||
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
|
||||
trainer.load_state_dict(state_dict["trainer"])
|
||||
print("Model loaded")
|
||||
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||
@@ -336,7 +331,7 @@ def train(
|
||||
sample = 0
|
||||
average_loss = 0
|
||||
timer = Timer()
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
|
||||
sample += img.shape[0]
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
@@ -361,7 +356,7 @@ def train(
|
||||
# Compute evaluation metrics
|
||||
if exists(evaluate_config):
|
||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||
tracker.log(evaluation, step=step, verbose=True)
|
||||
|
||||
# Generate sample images
|
||||
@@ -420,7 +415,7 @@ def initialize_training(config):
|
||||
|
||||
dataloaders = create_dataloaders (
|
||||
available_shards=all_shards,
|
||||
img_preproc = config.img_preproc,
|
||||
img_preproc = config.data.img_preproc,
|
||||
train_prop = config.data.splits.train,
|
||||
val_prop = config.data.splits.val,
|
||||
test_prop = config.data.splits.test,
|
||||
@@ -428,7 +423,7 @@ def initialize_training(config):
|
||||
**config.data.dict()
|
||||
)
|
||||
|
||||
decoder = create_decoder(device, config.decoder, config.unets)
|
||||
decoder = config.decoder.create().to(device = device)
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
print(print_ribbon("Loaded Config", repeat=40))
|
||||
print(f"Number of parameters: {num_parameters}")
|
||||
|
||||
@@ -1,75 +1,136 @@
|
||||
from pathlib import Path
|
||||
# TODO: add start, num_data_points, eval_every and group to config
|
||||
# TODO: switch back to repo's wandb
|
||||
|
||||
START = 0
|
||||
NUM_DATA_POINTS = 250e6
|
||||
EVAL_EVERY = 1000
|
||||
GROUP = "distributed"
|
||||
|
||||
import os
|
||||
import click
|
||||
import math
|
||||
import numpy as np
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
from torch.nn.functional import normalize
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
import numpy as np
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from accelerate import Accelerator
|
||||
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.utils import Timer
|
||||
from dalle2_pytorch.train_configs import (
|
||||
DiffusionPriorTrainConfig,
|
||||
TrainDiffusionPriorConfig,
|
||||
)
|
||||
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
from tqdm import tqdm
|
||||
# helpers
|
||||
|
||||
# constants
|
||||
|
||||
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tracker = WandbTracker()
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
return val is not None
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
model.eval()
|
||||
def make_model(
|
||||
prior_config, train_config, device: str = None, accelerator: Accelerator = None
|
||||
):
|
||||
# create model from config
|
||||
diffusion_prior = prior_config.create()
|
||||
|
||||
# instantiate the trainer
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior=diffusion_prior,
|
||||
lr=train_config.lr,
|
||||
wd=train_config.wd,
|
||||
max_grad_norm=train_config.max_grad_norm,
|
||||
amp=train_config.amp,
|
||||
use_ema=train_config.use_ema,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
# eval functions
|
||||
|
||||
|
||||
def eval_model(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
loss_type: str,
|
||||
phase: str,
|
||||
tracker: BaseTracker = None,
|
||||
use_ema: bool = True,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
total_loss = 0.0
|
||||
total_samples = 0.0
|
||||
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
for image_embeddings, text_data in dataloader:
|
||||
image_embeddings = image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text = text_data)
|
||||
input_args = dict(**input_args, text=text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
loss = model(**input_args)
|
||||
if use_ema:
|
||||
loss = trainer.ema_diffusion_prior(**input_args)
|
||||
else:
|
||||
loss = trainer(**input_args)
|
||||
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
avg_loss = total_loss / total_samples
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
stats = {f"{phase}/{loss_type}": avg_loss}
|
||||
trainer.print(stats)
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
diffusion_prior.eval()
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
def report_cosine_sims(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
tracker: BaseTracker,
|
||||
phase: str = "validation",
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
|
||||
|
||||
for test_image_embeddings, text_data in dataloader:
|
||||
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
|
||||
text_data)
|
||||
text_cond = dict(text_embed=text_embedding,
|
||||
text_encodings=text_encodings, mask=text_mask)
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
@@ -80,8 +141,7 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
# roll the text to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
@@ -90,276 +150,272 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
|
||||
text_embed = normalize(text_embedding / text_embedding)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_image_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond
|
||||
)
|
||||
|
||||
predicted_image_embeddings = predicted_image_embeddings / normalize(
|
||||
predicted_image_embeddings
|
||||
)
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled
|
||||
)
|
||||
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
|
||||
predicted_unrelated_embeddings
|
||||
)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = (
|
||||
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
)
|
||||
predicted_img_similarity = (
|
||||
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
)
|
||||
|
||||
stats = {
|
||||
f"{phase}/baseline similarity": np.mean(original_similarity),
|
||||
f"{phase}/similarity with text": np.mean(predicted_similarity),
|
||||
f"{phase}/similarity with original image": np.mean(
|
||||
predicted_img_similarity
|
||||
),
|
||||
f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
|
||||
f"{phase}/difference from baseline similarity": np.mean(
|
||||
predicted_similarity - original_similarity
|
||||
),
|
||||
}
|
||||
|
||||
for k, v in stats.items():
|
||||
trainer.print(f"{phase}/{k}: {v}")
|
||||
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
|
||||
# training script
|
||||
|
||||
|
||||
def train(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
train_loader: DataLoader,
|
||||
eval_loader: DataLoader,
|
||||
test_loader: DataLoader,
|
||||
config: DiffusionPriorTrainConfig,
|
||||
):
|
||||
# distributed tracking with wandb
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
|
||||
tracker = wandb.init(
|
||||
name=f"RANK:{trainer.device}",
|
||||
entity=config.tracker.wandb_entity,
|
||||
project=config.tracker.wandb_project,
|
||||
config=config.dict(),
|
||||
group=GROUP,
|
||||
)
|
||||
|
||||
# sync after tracker init
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# init a timer
|
||||
timer = Timer()
|
||||
|
||||
# do training
|
||||
for img, txt in train_loader:
|
||||
trainer.train()
|
||||
current_step = trainer.step.item() + 1
|
||||
|
||||
# place data on device
|
||||
img = img.to(trainer.device)
|
||||
txt = txt.to(trainer.device)
|
||||
|
||||
# pass to model
|
||||
loss = trainer(text=txt, image_embed=img)
|
||||
|
||||
# display & log loss (will only print from main process)
|
||||
trainer.print(f"Step {current_step}: Loss {loss}")
|
||||
|
||||
# perform backprop & apply EMA updates
|
||||
trainer.update()
|
||||
|
||||
# track samples/sec/rank
|
||||
samples_per_sec = img.shape[0] / timer.elapsed()
|
||||
|
||||
# samples seen
|
||||
samples_seen = (
|
||||
config.data.batch_size * trainer.accelerator.num_processes * current_step
|
||||
)
|
||||
|
||||
# ema decay
|
||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||
|
||||
# Log on all processes for debugging
|
||||
tracker.log(
|
||||
{
|
||||
"training/loss": loss,
|
||||
"samples/sec/rank": samples_per_sec,
|
||||
"samples/seen": samples_seen,
|
||||
"ema/decay": ema_decay,
|
||||
},
|
||||
step=current_step,
|
||||
)
|
||||
|
||||
# Metric Tracking & Checkpointing (outside of timer's scope)
|
||||
if current_step % EVAL_EVERY == 0:
|
||||
eval_model(
|
||||
trainer,
|
||||
eval_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
config.prior.loss_type,
|
||||
"training/validation",
|
||||
tracker,
|
||||
use_ema=False,
|
||||
)
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss=config.prior.loss_type,
|
||||
phase="ema/validation",
|
||||
tracker=tracker,
|
||||
use_ema=True,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
tracker=tracker,
|
||||
phase="ema/validation",
|
||||
)
|
||||
|
||||
if current_step % config.train.save_every == 0:
|
||||
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
|
||||
|
||||
# reset timer for next round
|
||||
timer.reset()
|
||||
|
||||
# evaluate on test data
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=test_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
phase="test",
|
||||
tracker=tracker,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
test_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
tracker,
|
||||
phase="test",
|
||||
)
|
||||
|
||||
|
||||
def initialize_training(config, accelerator=None):
|
||||
"""
|
||||
Parse the configuration file, and prepare everything necessary for training
|
||||
"""
|
||||
|
||||
# get a device
|
||||
|
||||
if accelerator:
|
||||
device = accelerator.device
|
||||
click.secho(f"Accelerating on: {device}", fg="yellow")
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
|
||||
device = "cuda:0"
|
||||
else:
|
||||
click.secho("No GPU detected...using cpu", fg="yellow")
|
||||
device = "cpu"
|
||||
|
||||
# make the trainer (will automatically distribute if possible & configured)
|
||||
|
||||
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
|
||||
|
||||
# reload from chcekpoint
|
||||
|
||||
if config.load.resume == True:
|
||||
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
|
||||
trainer.load(config.load.source)
|
||||
|
||||
# fetch and prepare data
|
||||
|
||||
if trainer.is_main_process():
|
||||
click.secho("Grabbing data from source", fg="blue", blink=True)
|
||||
|
||||
img_reader = get_reader(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
img_url=config.data.image_url,
|
||||
meta_url=config.data.meta_url,
|
||||
)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
batch_size=config.data.batch_size,
|
||||
num_data_points=NUM_DATA_POINTS,
|
||||
train_split=config.data.splits.train,
|
||||
eval_split=config.data.splits.val,
|
||||
image_reader=img_reader,
|
||||
rank=accelerator.state.process_index if exists(accelerator) else 0,
|
||||
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
|
||||
start=START,
|
||||
)
|
||||
|
||||
# wait for everyone to load data before continuing
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# start training
|
||||
train(
|
||||
trainer=trainer,
|
||||
train_loader=train_loader,
|
||||
eval_loader=eval_loader,
|
||||
test_loader=test_loader,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@click.option("--wandb-dataset", default="LAION-5B")
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--num-data-points", default=250e6)
|
||||
@click.option("--batch-size", default=320)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.9)
|
||||
@click.option("--val-percent", default=1e-7)
|
||||
@click.option("--test-percent", default=0.0999999)
|
||||
@click.option("--dpn-depth", default=12)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=12)
|
||||
@click.option("--dp-condition-on-text-encodings", default=True)
|
||||
@click.option("--dp-timesteps", default=1000)
|
||||
@click.option("--dp-normformer", default=True)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default="ViT-L/14")
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=120)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
@click.option("--gpu-device", default=0)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
wandb_dataset,
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
meta_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
num_data_points,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path,
|
||||
gpu_device
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
"architecture": wandb_arch,
|
||||
"dataset": wandb_dataset,
|
||||
"weight_decay": weight_decay,
|
||||
"max_gradient_clipping_norm": max_grad_norm,
|
||||
"batch_size": batch_size,
|
||||
"epochs": num_epochs,
|
||||
"diffusion_prior_network": {
|
||||
"depth": dpn_depth,
|
||||
"dim_head": dpn_dim_head,
|
||||
"heads": dpn_heads,
|
||||
"normformer": dp_normformer
|
||||
},
|
||||
"diffusion_prior": {
|
||||
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||
"timesteps": dp_timesteps,
|
||||
"cond_drop_prob": dp_cond_drop_prob,
|
||||
"loss_type": dp_loss_type,
|
||||
"clip": clip
|
||||
}
|
||||
}
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
config = config
|
||||
)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device(f"cuda:{gpu_device}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
@click.option("--hfa", default=True)
|
||||
@click.option("--config_path", default="configs/prior.json")
|
||||
def main(hfa, config_path):
|
||||
# start HFA if requested
|
||||
if hfa:
|
||||
accelerator = Accelerator()
|
||||
else:
|
||||
clip_adapter = None
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
accelerator = None
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||
)
|
||||
# load the configuration file on main process
|
||||
if not exists(accelerator) or accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_path}", fg="green")
|
||||
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_path)
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior = diffusion_prior,
|
||||
lr = learning_rate,
|
||||
wd = weight_decay,
|
||||
max_grad_norm = max_grad_norm,
|
||||
amp = amp,
|
||||
).to(device)
|
||||
|
||||
# load optimizer and scaler
|
||||
|
||||
if RESUME:
|
||||
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
|
||||
### Training code ###
|
||||
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text)
|
||||
|
||||
loss = trainer(**input_args)
|
||||
|
||||
# Samples per second
|
||||
|
||||
samples_per_sec = batch_size * step / timer.elapsed()
|
||||
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||
timer.reset()
|
||||
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
diffusion_prior,
|
||||
trainer.optimizer,
|
||||
trainer.scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
tracker.log({"Training loss": loss,
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
|
||||
# send config to get processed
|
||||
initialize_training(config, accelerator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user