Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)
Compare commits
44 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 46be8c32d3 | |
| | 900f086a6d | |
| | b3e646fd3b | |
| | 6a59c7093d | |
| | a6cdbe0b9c | |
| | e928ae5c34 | |
| | 1bd8a7835a | |
| | f33453df9f | |
| | 1e4bb2bafb | |
| | ee75515c7d | |
| | ec68243479 | |
| | 3afdcdfe86 | |
| | b9a908ff75 | |
| | e1fe3089df | |
| | 6d477d7654 | |
| | 531fe4b62f | |
| | ec5a77fc55 | |
| | fac63c61bc | |
| | 3d23ba4aa5 | |
| | 282c35930f | |
| | 27b0f7ca0d | |
| | 7b0edf9e42 | |
| | a922a539de | |
| | 8f2466f1cd | |
| | 908ab83799 | |
| | 46a2558d53 | |
| | 86109646e3 | |
| | 6a11b9678b | |
| | b90364695d | |
| | 868c001199 | |
| | 032e83b0e0 | |
| | 2e85e736f3 | |
| | f5760bdb92 | |
| | c453f468b1 | |
| | 98f0c17759 | |
| | a5b9fd6ca8 | |
| | 4b994601ae | |
| | fddf66e91e | |
| | c8422ffd5d | |
| | 2aadc23c7c | |
| | c098f57e09 | |
| | 0021535c26 | |
| | 56883910fb | |
| | 893f270012 | |
README.md (111 changed lines)
@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu

- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.

<img src="./samples/oxford.png" width="600px" />
<img src="./samples/oxford.png" width="450px" />

*ongoing at 21k steps*

- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application

- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues

## Pre-Trained Models

- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
@@ -42,6 +44,7 @@ This library would not have gotten to this working state without the help of

- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -368,7 +371,8 @@ unet1 = Unet(
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
).cuda()

unet2 = Unet(

@@ -385,8 +389,7 @@ decoder = Decoder(
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
    text_cond_drop_prob = 0.5
).cuda()

for unet_number in (1, 2):

@@ -579,7 +582,8 @@ unet1 = Unet(
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()

unet2 = Unet(

@@ -596,12 +600,11 @@ decoder = Decoder(
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
    text_cond_drop_prob = 0.5
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

    # do above for many steps
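After both unets have taken enough optimizer steps, images can be generated from an image embedding (and, when conditioning on text encodings, the tokenized text). The following is a minimal sketch, assuming the `Decoder.sample` interface used elsewhere in this readme; `image_embed` here stands for CLIP image embeddings (for example, the output of the diffusion prior) and is not defined in the training snippet above.

```python
# hypothetical sampling sketch - decoder, text and image_embed are assumed from context
with torch.no_grad():
    images = decoder.sample(
        image_embed = image_embed,  # CLIP image embeddings, e.g. produced by the diffusion prior
        text = text                 # only needed if the unets were conditioned on text encodings
    )

images.shape # (batch, 3, image_size, image_size)
```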
@@ -988,61 +991,7 @@ dataset = ImageEmbeddingDataset(

#### `train_diffusion_prior.py`

This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below (a sketch of such a training step follows the parameter list).

#### Usage

```bash
$ python train_diffusion_prior.py
```

The most significant parameters for the script are as follows:

- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT-L/14 embedding size; change it to whatever your chosen ViT produces
- `learning-rate`, default = `1.1e-4`
- `weight-decay`, default = `6.02e-2`
- `max-grad-norm`, default = `0.5`
- `batch-size`, default = `10 ** 4`
- `num-epochs`, default = `5`
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
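Because the script works on pre-computed embeddings, the prior never sees raw text or images during training. The following is a minimal sketch of what a single training step looks like under that assumption; the keyword names `text_embed` and `image_embed` are taken from the note above, and the optimizer is assumed to have been created as described in the saving section below.

```python
# hypothetical training step on pre-computed CLIP embeddings
# text_embed / image_embed: float tensors of shape (batch, image_embed_dim), e.g. 768 for ViT-L/14
loss = diffusion_prior(
    text_embed = text_embed,
    image_embed = image_embed
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```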
#### Loading and Saving the DiffusionPrior model

Two methods are provided, `load_diffusion_model` and `save_diffusion_model`; the names are self-explanatory.

```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```

##### Loading

`load_diffusion_model(dprior_path, device)`

- dprior_path : path to the saved model (.pth)
- device : the cuda device you are running on

##### Saving

`save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)`

- save_path : path to save at
- model : the DiffusionPrior object
- optimizer : optimizer object - see train_diffusion_prior.py for how to create one, e.g. optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
- scaler : a GradScaler object, e.g. scaler = GradScaler(enabled=amp)
- config : config object created in train_diffusion_prior.py - see the file for an example
- image_embed_dim : the dimension of the image embedding, e.g. 768
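Putting the two together, a save/load round trip looks roughly like the sketch below. It assumes the optimizer, scaler and config objects have been created as in `train_diffusion_prior.py`; the checkpoint path is only illustrative.

```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model

# hypothetical checkpoint path
checkpoint_path = './prior_checkpoint.pth'

# save the prior along with its training state
save_diffusion_model(checkpoint_path, diffusion_prior, optimizer, scaler, config, image_embed_dim = 768)

# later, restore it onto a given device for inference or further training
diffusion_prior = load_diffusion_model(checkpoint_path, device = 'cuda:0')
```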
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md).

## CLI (wip)
@@ -1092,19 +1041,14 @@ Once built, images will be saved to the same directory the command is invoked

- [x] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesn't work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865

## Citations
@@ -1144,15 +1088,6 @@ Once built, images will be saved to the same directory the command is invoked
}
```

```bibtex
@inproceedings{Tu2022MaxViTMV,
    title = {MaxViT: Multi-Axis Vision Transformer},
    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year = {2022},
    url = {https://arxiv.org/abs/2204.01697}
}
```

```bibtex
@article{Yu2021VectorquantizedIM,
    title = {Vector-quantized Image Modeling with Improved VQGAN},

@@ -1221,4 +1156,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```

```bibtex
@article{Saharia2021PaletteID,
    title = {Palette: Image-to-Image Diffusion Models},
    author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
    journal = {ArXiv},
    year = {2021},
    volume = {abs/2111.05826}
}
```

*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -91,21 +91,83 @@ Each metric can be enabled by setting its configuration. The configuration keys

**<ins>Tracker</ins>:**

Selects which tracker to use and configures it.
Selects how the experiment will be tracked.

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
| `log` | Yes | N/A | Logging configuration. |
| `load` | No | `None` | Checkpoint loading configuration. |
| `save` | Yes | N/A | Checkpoint/Model saving configuration. |

Tracking is split up into three sections:
* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.

Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).

**Logging:**

**<ins>Load</ins>:**

Selects where to load a pretrained model from.
If using `console` there is no further configuration other than setting `log_type` to `console`.

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker supports resuming the run, resume it. |
| `log_type` | Yes | N/A | Must be `console`. |

Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).

If using `wandb`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | Must be `wandb`. |
| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |

**Loading:**

If using `local`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `local`. |
| `file_path` | Yes | N/A | The path to the checkpoint file. |

If using `url`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `url`. |
| `url` | Yes | N/A | The url of the checkpoint file. |

If using `wandb`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
**Saving:**

Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a single run.

All save locations have these configuration options:

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (saves with EMA weights if EMA is enabled). |

If using `local`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`. |

If using `huggingface`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |

If using `wandb`

| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
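Putting the three sections together, a tracker configuration might look like the sketch below. It is shown as a Python dict so it can be dropped into the training config; the key names come from the tables above and every value is a placeholder only.

```python
# hypothetical tracker section assembled from the options documented above
tracker_config = {
    "data_path": "./.tracker-data",
    "overwrite_data_path": True,
    "log": {
        "log_type": "wandb",             # or "console"
        "wandb_entity": "your_entity",   # placeholder
        "wandb_project": "your_project"  # placeholder
    },
    "load": {
        "load_from": "local",
        "file_path": "./checkpoints/latest.pth"  # placeholder path
    },
    "save": [{
        "save_to": "local",
        "save_latest_to": "latest.pth",
        "save_best_to": "best.pth",
        "save_type": "checkpoint"
    }]
}
```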
@@ -15,7 +15,7 @@
    "channels": 3,
    "timesteps": 1000,
    "loss_type": "l2",
    "beta_schedule": "cosine",
    "beta_schedule": ["cosine"],
    "learned_variance": true
},
"data": {

@@ -80,20 +80,32 @@
    }
},
"tracker": {
    "tracker_type": "console",
    "data_path": "./models",
    "overwrite_data_path": true,

    "wandb_entity": "",
    "wandb_project": "",
    "log": {
        "log_type": "wandb",

        "verbose": false
    },
    "load": {
        "source": null,
        "wandb_entity": "your_wandb",
        "wandb_project": "your_project",

        "run_path": "",
        "file_path": "",
        "verbose": true
    },

    "resume": false
    "load": {
        "load_from": null
    },

    "save": [{
        "save_to": "wandb"
    }, {
        "save_to": "huggingface",
        "huggingface_repo": "Veldrovive/test_model",

        "save_all": true,
        "save_latest": true,
        "save_best": true,

        "save_type": "model"
    }]
}
}
@@ -1,6 +1,6 @@
import math
import random
from tqdm import tqdm
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple

@@ -45,6 +45,11 @@ def exists(val):
def identity(t, *args, **kwargs):
    return t

def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

def maybe(fn):
    @wraps(fn)
    def inner(x):

@@ -58,11 +63,16 @@ def default(val, d):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = 1):
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    return val if isinstance(val, tuple) else ((val,) * length)
    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(out) == length

    return out
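The updated `cast_tuple` now also converts lists to tuples and asserts the final length whenever `length` is given, rather than silently broadcasting to the wrong size. A small usage sketch of the new behavior (values are illustrative only):

```python
# assuming the cast_tuple defined above
cast_tuple(3)           # (3,)
cast_tuple(3, 4)        # (3, 3, 3, 3)
cast_tuple([1, 2], 2)   # (1, 2) - lists are converted to tuples
cast_tuple((1, 2), 3)   # AssertionError - an explicit length must match
```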
def module_device(module):
    return next(module.parameters()).device

@@ -115,14 +125,19 @@ def log(t, eps = 1e-12):
def l2norm(t):
    return F.normalize(t, dim = -1)

def resize_image_to(image, target_image_size):
def resize_image_to(image, target_image_size, clamp_range = None):
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    scale_factors = target_image_size / orig_image_size
    return resize(image, scale_factors = scale_factors)
    out = resize(image, scale_factors = scale_factors)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

# image normalization functions
# ddpms expect images to be in the range of -1 to 1
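The new `clamp_range` argument exists so callers (such as the low resolution conditioner further down, which passes the decoder's `input_image_range`) can keep the resized image inside the range the model expects, since the resize kernel can otherwise overshoot it. A sketch of the intended use, with a placeholder target size:

```python
# hypothetical use of the new clamp_range argument
lowres_cond_img = resize_image_to(
    lowres_cond_img,
    target_image_size = 256,   # placeholder target resolution
    clamp_range = (0., 1.)     # or (-1., 1.) if auto_normalize_img is disabled
)
```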
@@ -325,21 +340,25 @@ def approx_standard_normal_cdf(x):
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
|
||||
# attempting to correct nan gradients when learned variance is turned on
|
||||
# in the setting of deepspeed fp16
|
||||
eps = 1e-12 if x.dtype == torch.float32 else 1e-3
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = log(cdf_plus)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min)
|
||||
log_cdf_plus = log(cdf_plus, eps = eps)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(x < -thres,
|
||||
log_cdf_plus,
|
||||
torch.where(x > thres,
|
||||
log_one_minus_cdf_min,
|
||||
log(cdf_delta)))
|
||||
log(cdf_delta, eps = eps)))
|
||||
|
||||
return log_probs
|
||||
|
||||
@@ -351,7 +370,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
|
||||
@@ -480,14 +499,16 @@ class NoiseScheduler(nn.Module):
|
||||
# diffusion prior
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
x = x / x.amax(dim = -1, keepdim = True).detach()
|
||||
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = -1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
@@ -496,10 +517,10 @@ class ChanLayerNorm(nn.Module):
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = x / x.amax(dim = 1, keepdim = True).detach()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -619,10 +640,13 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
rotary_emb = None
|
||||
rotary_emb = None,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
|
||||
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -686,7 +710,10 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
@@ -1088,8 +1115,16 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
# decoder
|
||||
|
||||
def Upsample(dim):
|
||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
def ConvTransposeUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
def NearestUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
)
|
||||
|
||||
def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
@@ -1101,11 +1136,12 @@ class SinusoidalPosEmb(nn.Module):
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
dtype, device = x.dtype, x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
|
||||
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
|
||||
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
@@ -1166,7 +1202,7 @@ class ResnetBlock(nn.Module):
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
def forward(self, x, time_emb = None, cond = None):
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
@@ -1192,10 +1228,12 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
norm_context = False,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -1241,26 +1279,15 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class GridAttention(nn.Module):
|
||||
def __init__(self, *args, window_size = 8, **kwargs):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.attn = Attention(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
wsz = self.window_size
|
||||
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
|
||||
out = self.attn(x)
|
||||
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
|
||||
return out
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1342,6 +1369,7 @@ class Unet(nn.Module):
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
channels_out = None,
|
||||
self_attn = False,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
@@ -1359,6 +1387,9 @@ class Unet(nn.Module):
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
scale_skip_connection = False,
|
||||
nearest_upsample = False,
|
||||
final_conv_kernel_size = 1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1385,6 +1416,8 @@ class Unet(nn.Module):
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
num_stages = len(in_out)
|
||||
|
||||
# time, image embeddings, and optional text encoding
|
||||
|
||||
cond_dim = default(cond_dim, dim)
|
||||
@@ -1440,16 +1473,24 @@ class Unet(nn.Module):
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# whether to scale skip connection, adopted in Imagen
|
||||
|
||||
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
|
||||
|
||||
# attention related params
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
self_attn = cast_tuple(self_attn, num_stages)
|
||||
|
||||
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
|
||||
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
resnet_groups = cast_tuple(resnet_groups, num_stages)
|
||||
top_level_resnet_group = first(resnet_groups)
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
|
||||
|
||||
# downsample klass
|
||||
|
||||
@@ -1457,46 +1498,71 @@ class Unet(nn.Module):
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# upsample klass
|
||||
|
||||
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
||||
|
||||
# give memory efficient unet an initial resnet block
|
||||
|
||||
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
skip_connect_dims = [] # keeping track of skip connection dimensions
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
dim_layer = dim_out if memory_efficient else dim_in
|
||||
skip_connect_dims.append(dim_layer)
|
||||
|
||||
attention = nn.Identity()
|
||||
if layer_self_attn:
|
||||
attention = create_self_attn(dim_layer)
|
||||
elif sparse_attn:
|
||||
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
|
||||
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
attention,
|
||||
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_attn = create_self_attn(mid_dim)
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
skip_connect_dim = skip_connect_dims.pop()
|
||||
|
||||
attention = nn.Identity()
|
||||
if layer_self_attn:
|
||||
attention = create_self_attn(dim_out)
|
||||
elif sparse_attn:
|
||||
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
attention,
|
||||
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
]))
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
@@ -1570,6 +1636,7 @@ class Unet(nn.Module):
|
||||
|
||||
# time conditioning
|
||||
|
||||
time = time.type_as(x)
|
||||
time_hiddens = self.to_time_hiddens(time)
|
||||
|
||||
time_tokens = self.to_time_tokens(time_hiddens)
|
||||
@@ -1660,44 +1727,55 @@ class Unet(nn.Module):
|
||||
c = self.norm_cond(c)
|
||||
mid_c = self.norm_mid_cond(mid_c)
|
||||
|
||||
# initial resnet block
|
||||
|
||||
if exists(self.init_resnet_block):
|
||||
x = self.init_resnet_block(x, t)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
|
||||
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
|
||||
if exists(pre_downsample):
|
||||
x = pre_downsample(x)
|
||||
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = init_block(x, t, c)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
x = resnet_block(x, t, c)
|
||||
hiddens.append(x)
|
||||
|
||||
x = attn(x)
|
||||
hiddens.append(x)
|
||||
|
||||
if exists(post_downsample):
|
||||
x = post_downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
x = self.mid_block1(x, t, mid_c)
|
||||
|
||||
if exists(self.mid_attn):
|
||||
x = self.mid_attn(x)
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
x = self.mid_block2(x, t, mid_c)
|
||||
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
|
||||
|
||||
for init_block, resnet_blocks, attn, upsample in self.ups:
|
||||
x = connect_skip(x)
|
||||
x = init_block(x, t, c)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
x = connect_skip(x)
|
||||
x = resnet_block(x, t, c)
|
||||
|
||||
x = attn(x)
|
||||
x = upsample(x)
|
||||
|
||||
x = torch.cat((x, r), dim = 1)
|
||||
return self.final_conv(x)
|
||||
|
||||
x = self.final_resnet_block(x, t)
|
||||
return self.to_out(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
@@ -1705,9 +1783,12 @@ class LowresConditioner(nn.Module):
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.6,
|
||||
blur_kernel_size = 3,
|
||||
input_image_range = None
|
||||
):
|
||||
super().__init__()
|
||||
self.downsample_first = downsample_first
|
||||
self.input_image_range = input_image_range
|
||||
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
|
||||
@@ -1721,7 +1802,7 @@ class LowresConditioner(nn.Module):
|
||||
blur_kernel_size = None
|
||||
):
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
@@ -1741,7 +1822,7 @@ class LowresConditioner(nn.Module):
|
||||
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
@@ -1764,16 +1845,15 @@ class Decoder(nn.Module):
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
learned_variance_constrain_frac = False,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False,
|
||||
unconditional = False, # set to True for generating images without conditioning
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
@@ -1782,13 +1862,6 @@ class Decoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
@@ -1820,12 +1893,16 @@ class Decoder(nn.Module):
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
# verify conditioning method
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
@@ -1852,8 +1929,8 @@ class Decoder(nn.Module):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
cond_on_image_embeds = not unconditional and is_first,
|
||||
cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
@@ -1899,6 +1976,10 @@ class Decoder(nn.Module):
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# input image range
|
||||
|
||||
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
@@ -1908,6 +1989,7 @@ class Decoder(nn.Module):
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
input_image_range = self.input_image_range
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
@@ -1939,6 +2021,10 @@ class Decoder(nn.Module):
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
@property
|
||||
def condition_on_text_encodings(self):
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
@@ -2194,7 +2280,8 @@ class Decoder(nn.Module):
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
unet_number = None
|
||||
unet_number = None,
|
||||
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
@@ -2244,7 +2331,12 @@ class Decoder(nn.Module):
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
|
||||
if not return_lowres_cond_image:
|
||||
return losses
|
||||
|
||||
return losses, lowres_cond_img
|
||||
|
||||
# main class
|
||||
|
||||
@@ -2292,6 +2384,6 @@ class DALLE2(nn.Module):
|
||||
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
return first(images)
|
||||
|
||||
return images
|
||||
|
||||
@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
|
||||
"""
|
||||
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
||||
|
||||
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
|
||||
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
|
||||
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
||||
previous_tar_url = None
|
||||
current_embeddings = None
|
||||
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
|
||||
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
||||
if torch.count_nonzero(embedding) == 0:
|
||||
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
||||
sample["npy"] = embedding
|
||||
sample[sample_key] = embedding
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
||||
|
||||
def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
|
||||
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
|
||||
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
|
||||
sample['emb'] = {}
|
||||
if 'text_emb' in sample:
|
||||
sample['emb']['text'] = sample['text_emb']
|
||||
if 'img_emb' in sample:
|
||||
sample['emb']['img'] = sample['img_emb']
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
||||
else:
|
||||
break
|
||||
|
||||
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
|
||||
"""
|
||||
Requires that both the image and embedding are present in the sample
|
||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||
"""
|
||||
for sample in samples:
|
||||
try:
|
||||
for key in required_keys:
|
||||
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
|
||||
yield sample
|
||||
except Exception as exn: # From wds implementation
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
key_verifier = wds.filters.pipelinefilter(verify_keys)
|
||||
|
||||
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
"""
|
||||
A fluid interface wrapper for DataPipline that returns image embedding pairs
|
||||
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
embedding_folder_url=None,
|
||||
img_embedding_folder_url=None,
|
||||
text_embedding_folder_url=None,
|
||||
index_width=None,
|
||||
img_preproc=None,
|
||||
extra_keys=[],
|
||||
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
keys = ["jpg", "npy"] + extra_keys
|
||||
keys = ["jpg", "emb"] + extra_keys
|
||||
# if img_embedding_folder_url is not None:
|
||||
# keys.append("img_emb")
|
||||
# if text_embedding_folder_url is not None:
|
||||
# keys.append("text_emb")
|
||||
# keys.extend(extra_keys)
|
||||
self.key_map = {key: i for i, key in enumerate(keys)}
|
||||
self.resampling = resample
|
||||
self.img_preproc = img_preproc
|
||||
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
# Then this has an s3 link for the webdataset and we need extra packages
|
||||
if shutil.which("s3cmd") is None:
|
||||
raise RuntimeError("s3cmd is required for s3 webdataset")
|
||||
if "s3:" in embedding_folder_url:
|
||||
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
|
||||
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
||||
try:
|
||||
import s3fs
|
||||
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||
if shuffle_shards:
|
||||
self.append(wds.filters.shuffle(1000))
|
||||
|
||||
if embedding_folder_url is not None:
|
||||
if img_embedding_folder_url is not None:
|
||||
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
||||
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
|
||||
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
|
||||
if text_embedding_folder_url is not None:
|
||||
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
|
||||
|
||||
self.append(wds.tarfile_to_samples(handler=handler))
|
||||
self.append(wds.decode("pilrgb", handler=handler))
|
||||
if embedding_folder_url is not None:
|
||||
# Then we are loading embeddings for a remote source
|
||||
if img_embedding_folder_url is not None:
|
||||
# Then we are loading image embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
|
||||
self.append(verify_keys)
|
||||
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
|
||||
if text_embedding_folder_url is not None:
|
||||
# Then we are loading image embeddings for a remote source
|
||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
|
||||
self.append(join_embeddings)
|
||||
self.append(key_verifier(required_keys=keys, handler=handler))
|
||||
# Apply preprocessing
|
||||
self.append(wds.map(self.preproc))
|
||||
self.append(wds.to_tuple(*keys))
|
||||
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
|
||||
tar_url,
|
||||
num_workers,
|
||||
batch_size,
|
||||
embeddings_url=None,
|
||||
img_embeddings_url=None,
|
||||
text_embeddings_url=None,
|
||||
index_width=None,
|
||||
shuffle_num = None,
|
||||
shuffle_shards = True,
|
||||
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
|
||||
"""
|
||||
ds = ImageEmbeddingDataset(
|
||||
tar_url,
|
||||
embeddings_url,
|
||||
img_embedding_folder_url=img_embeddings_url,
|
||||
text_embedding_folder_url=text_embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_shards=shuffle_shards,
|
||||
resample=resample_shards,
|
||||
@@ -228,4 +262,4 @@ def create_image_embedding_dataloader(
|
||||
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
|
||||
pin_memory=True,
|
||||
shuffle=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import urllib.request
|
||||
import os
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
# constants
|
||||
|
||||
@@ -27,126 +30,484 @@ def load_wandb_file(run_path, file_path, **kwargs):
|
||||
def load_local_file(file_path, **kwargs):
|
||||
return file_path
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||
super().__init__()
|
||||
class BaseLogger:
|
||||
"""
|
||||
An abstract class representing an object that can log data.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
verbose (bool): Whether of not to always print logs to the console.
|
||||
"""
|
||||
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||
self.verbose = verbose
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_state_dict(self, recall_source, *args, **kwargs):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
"""
|
||||
Loads a state dict from any source.
|
||||
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
|
||||
this should not be linked to any individual tracker.
|
||||
Initializes the logger.
|
||||
Errors if the logger is invalid.
|
||||
"""
|
||||
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
|
||||
if recall_source == 'wandb':
|
||||
return torch.load(load_wandb_file(*args, **kwargs))
|
||||
elif recall_source == 'local':
|
||||
return torch.load(load_local_file(*args, **kwargs))
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_file(self, recall_source, *args, **kwargs):
|
||||
if recall_source == 'wandb':
|
||||
return load_wandb_file(*args, **kwargs)
|
||||
elif recall_source == 'local':
|
||||
return load_local_file(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
def log(self, log, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Tracker that no-ops all calls except for recall
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class DummyTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def log_file(self, file_path, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
pass
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
pass
|
||||
class ConsoleLogger(BaseLogger):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
print("Logging to console")
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
pass
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
pass
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
def log(self, log, **kwargs) -> None:
|
||||
print(log)
|
||||
|
||||
def log_images(self, images, **kwargs): # noop for logging images
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
torch.save(state_dict, str(self.data_path / relative_path))
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
# This is a no-op for local file systems since it is already saved locally
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
# basic wandb class
|
||||
def log_file(self, file_path, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
print(error_string)
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
"""
|
||||
Logs to a wandb run.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
wandb_entity (str): The wandb entity to log to.
|
||||
wandb_project (str): The wandb project to log to.
|
||||
wandb_run_id (str): The wandb run id to resume.
|
||||
wandb_run_name (str): The wandb run name to use.
|
||||
wandb_resume (bool): Whether to resume a wandb run.
|
||||
"""
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
wandb_entity: str,
|
||||
wandb_project: str,
|
||||
wandb_run_id: Optional[str] = None,
|
||||
wandb_run_name: Optional[str] = None,
|
||||
wandb_resume: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.entity = wandb_entity
|
||||
self.project = wandb_project
|
||||
self.run_id = wandb_run_id
|
||||
self.run_name = wandb_run_name
|
||||
self.resume = wandb_resume
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
|
||||
assert self.project is not None, "wandb_project must be specified for wandb logger"
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# Initializes the wandb run
|
||||
init_object = {
|
||||
"entity": self.entity,
|
||||
"project": self.project,
|
||||
"config": {**full_config.dict(), **extra_config}
|
||||
}
|
||||
if self.run_name is not None:
|
||||
init_object['name'] = self.run_name
|
||||
if self.resume:
|
||||
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
|
||||
if self.run_name is not None:
|
||||
print("You are renaming a run. I hope that is what you intended.")
|
||||
init_object['resume'] = 'must'
|
||||
init_object['id'] = self.run_id
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
self.wandb.init(**init_object)
|
||||
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
|
||||
|
||||
def log(self, log, verbose=False, **kwargs):
|
||||
if verbose:
|
||||
def log(self, log, **kwargs) -> None:
|
||||
if self.verbose:
|
||||
print(log)
|
||||
self.wandb.log(log, **kwargs)
|
||||
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
"""
|
||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||
"""
|
||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||
self.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
"""
|
||||
Saves a state_dict to disk and uploads it
|
||||
"""
|
||||
full_path = str(self.data_path / relative_path)
|
||||
torch.save(state_dict, full_path)
|
||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_file(self, file_path, base_path=None, **kwargs):
|
||||
"""
|
||||
Uploads a file from disk to wandb
|
||||
"""
|
||||
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
|
||||
if base_path is None:
|
||||
base_path = self.data_path
|
||||
# Then we take the basepath as the parent of the file_path
|
||||
base_path = Path(file_path).parent
|
||||
self.wandb.save(str(file_path), base_path = str(base_path))
|
||||
|
||||
def log_error(self, error_string, step=None, **kwargs) -> None:
|
||||
if self.verbose:
|
||||
print(error_string)
|
||||
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||
|
||||
logger_type_map = {
|
||||
'console': ConsoleLogger,
|
||||
'wandb': WandbLogger,
|
||||
}
|
||||
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
|
||||
if logger_type == 'custom':
|
||||
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
|
||||
try:
|
||||
logger_class = logger_type_map[logger_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
|
||||
return logger_class(data_path, **kwargs)
|
||||
|
||||
class BaseLoader:
|
||||
"""
|
||||
An abstract class representing an object that can load a model checkpoint.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
"""
|
||||
def __init__(self, data_path: str, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def recall() -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
class UrlLoader(BaseLoader):
|
||||
"""
|
||||
A loader that downloads the file from a url and loads it
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
url (str): The url to download the file from.
|
||||
"""
|
||||
def __init__(self, data_path: str, url: str, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.url = url
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be downloaded
|
||||
pass # TODO: Actually implement that
|
||||
|
||||
def recall(self) -> dict:
|
||||
# Download the file
|
||||
save_path = self.data_path / 'loaded_checkpoint.pth'
|
||||
urllib.request.urlretrieve(self.url, str(save_path))
|
||||
# Load the file
|
||||
return torch.load(str(save_path), map_location='cpu')
|
||||
|
||||
|
||||
class LocalLoader(BaseLoader):
|
||||
"""
|
||||
A loader that loads a file from a local path
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
file_path (str): The path to the file to load.
|
||||
"""
|
||||
def __init__(self, data_path: str, file_path: str, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.file_path = Path(file_path)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be loaded
|
||||
if not self.file_path.exists():
|
||||
raise FileNotFoundError(f'Model not found at {self.file_path}')
|
||||
|
||||
def recall(self) -> dict:
|
||||
# Load the file
|
||||
return torch.load(str(self.file_path), map_location='cpu')
|
||||
|
||||
class WandbLoader(BaseLoader):
|
||||
"""
|
||||
A loader that loads a model from an existing wandb run
|
||||
"""
|
||||
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.run_path = wandb_run_path
|
||||
self.file_path = wandb_file_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
# Make sure the file can be downloaded
|
||||
if self.wandb.run is not None and self.run_path is None:
|
||||
self.run_path = self.wandb.run.path
|
||||
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
|
||||
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
|
||||
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
|
||||
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
pass # TODO: Actually implement that
|
||||
|
||||
def recall(self) -> dict:
|
||||
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
|
||||
return torch.load(file_reference.name, map_location='cpu')
|
||||
|
||||
loader_type_map = {
|
||||
'url': UrlLoader,
|
||||
'local': LocalLoader,
|
||||
'wandb': WandbLoader,
|
||||
}
|
||||
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
|
||||
if loader_type == 'custom':
|
||||
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
|
||||
try:
|
||||
loader_class = loader_type_map[loader_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
|
||||
return loader_class(data_path, **kwargs)
|
||||
|
||||
class BaseSaver:
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
|
||||
save_best_to: Optional[Union[str, bool]] = 'best.pth',
|
||||
save_meta_to: str = './',
|
||||
save_type: str = 'checkpoint',
|
||||
**kwargs
|
||||
):
|
||||
self.data_path = Path(data_path)
|
||||
self.save_latest_to = save_latest_to
|
||||
self.saving_latest = save_latest_to is not None and save_latest_to is not False
|
||||
self.save_best_to = save_best_to
|
||||
self.saving_best = save_best_to is not None and save_best_to is not False
|
||||
self.save_meta_to = save_meta_to
|
||||
self.save_type = save_type
|
||||
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
|
||||
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
|
||||
"""
|
||||
Save a general file under save_meta_to
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class LocalSaver(BaseSaver):
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the directory exists to be saved to
|
||||
print(f"Saving {self.save_type} locally")
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
|
||||
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
|
||||
# Copy the file to save_path
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
|
||||
shutil.copy(local_path, save_path)
|
||||
|
||||
class WandbSaver(BaseSaver):
|
||||
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.run_path = wandb_run_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# Makes sure that the user can upload to this run
|
||||
if self.run_path is not None:
|
||||
entity, project, run_id = self.run_path.split("/")
|
||||
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
|
||||
else:
|
||||
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
|
||||
self.run = self.wandb.run
|
||||
# TODO: Now actually check if upload is possible
|
||||
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
|
||||
# In order to log something in the correct place in wandb, we need to have the same file structure here
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
|
||||
save_path = Path(self.data_path) / save_path
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(local_path, save_path)
|
||||
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
|
||||
|
||||
class HuggingfaceSaver(BaseSaver):
|
||||
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.huggingface_repo = huggingface_repo
|
||||
self.token_path = token_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs):
|
||||
# Makes sure this user can upload to the repo
|
||||
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
|
||||
try:
|
||||
identity = self.hub.whoami() # Errors if not logged in
|
||||
# Then we are logged in
|
||||
except:
|
||||
# We are not logged in. Use the token_path to set the token.
|
||||
if not os.path.exists(self.token_path):
|
||||
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
|
||||
with open(self.token_path, "r") as f:
|
||||
token = f.read().strip()
|
||||
self.hub.HfApi.set_access_token(token)
|
||||
identity = self.hub.whoami()
|
||||
print(f"Saving to huggingface repo {self.huggingface_repo}")
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
|
||||
# Saving to huggingface is easy: we just need to upload the file with the correct name
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
|
||||
self.hub.upload_file(
|
||||
path_or_fileobj=str(local_path),
|
||||
path_in_repo=str(save_path),
|
||||
repo_id=self.huggingface_repo
|
||||
)
|
||||
|
||||
saver_type_map = {
|
||||
'local': LocalSaver,
|
||||
'wandb': WandbSaver,
|
||||
'huggingface': HuggingfaceSaver
|
||||
}
|
||||
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
|
||||
if saver_type == 'custom':
|
||||
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
|
||||
try:
|
||||
saver_class = saver_type_map[saver_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
|
||||
return saver_class(data_path, **kwargs)
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
|
||||
self.data_path = Path(data_path)
|
||||
if not dummy_mode:
|
||||
if overwrite_data_path:
|
||||
if self.data_path.exists():
|
||||
shutil.rmtree(self.data_path)
|
||||
self.data_path.mkdir(parents=True)
|
||||
else:
|
||||
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
self.logger: BaseLogger = None
|
||||
self.loader: Optional[BaseLoader] = None
|
||||
self.savers: List[BaseSaver]= []
|
||||
self.dummy_mode = dummy_mode
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict):
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
return
|
||||
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
|
||||
self.logger.init(full_config, extra_config)
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
for saver in self.savers:
|
||||
saver.init(self.logger)
|
||||
|
||||
def add_logger(self, logger: BaseLogger):
|
||||
self.logger = logger
|
||||
|
||||
def add_loader(self, loader: BaseLoader):
|
||||
self.loader = loader
|
||||
|
||||
def add_saver(self, saver: BaseSaver):
|
||||
self.savers.append(saver)
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log(*args, **kwargs)
|
||||
|
||||
def log_images(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log_images(*args, **kwargs)
|
||||
|
||||
def log_file(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log_file(*args, **kwargs)
|
||||
|
||||
def save_config(self, current_config_path: str, config_name = 'config.json'):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
# Save the config under config_name in the root folder of data_path
|
||||
shutil.copy(current_config_path, self.data_path / config_name)
|
||||
for saver in self.savers:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
|
||||
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
|
||||
"""
|
||||
Gets the state dict to be saved and writes it to file_path.
|
||||
If save_type is 'checkpoint', we save the entire trainer state dict.
|
||||
If save_type is 'model', we save only the model state dict.
|
||||
"""
|
||||
assert save_type in ['checkpoint', 'model']
|
||||
if save_type == 'checkpoint':
|
||||
trainer.save(file_path, overwrite=True, **kwargs)
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
state_dict = trainer.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
if trainer.use_ema:
|
||||
trainable_unets = decoder.unets
|
||||
decoder.unets = trainer.unets # Swap EMA unets in
|
||||
state_dict = decoder.state_dict()
|
||||
decoder.unets = trainable_unets # Swap back
|
||||
else:
|
||||
state_dict = decoder.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
else:
|
||||
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
|
||||
return Path(file_path)
|
||||
|
||||
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
if not is_best and not is_latest:
|
||||
# Nothing to do
|
||||
return
|
||||
# Save the checkpoint and model to data_path
|
||||
checkpoint_path = self.data_path / 'checkpoint.pth'
|
||||
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
|
||||
model_path = self.data_path / 'model.pth'
|
||||
self._save_state_dict(trainer, 'model', model_path, **kwargs)
|
||||
print("Saved cached models")
|
||||
# Call the save methods on the savers
|
||||
for saver in self.savers:
|
||||
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
|
||||
if saver.saving_latest and is_latest:
|
||||
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
|
||||
try:
|
||||
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
if saver.saving_best and is_best:
|
||||
best_checkpoint_path = saver.save_best_to.format(**kwargs)
|
||||
try:
|
||||
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
|
||||
def recall(self):
|
||||
if self.loader is not None:
|
||||
return self.loader.recall()
|
||||
else:
|
||||
raise ValueError('No loader specified')
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,9 @@ from dalle2_pytorch.dalle2_pytorch import (
|
||||
Decoder,
|
||||
DiffusionPrior,
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter,
|
||||
XClipAdapter
|
||||
)
|
||||
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
|
||||
return fields
|
||||
|
||||
class TrackerLogConfig(BaseModel):
|
||||
log_type: str = 'console'
|
||||
verbose: bool = False
|
||||
|
||||
class Config:
|
||||
# Each individual log type has its own arguments that will be passed through the config
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
return create_logger(self.log_type, data_path, **kwargs)
|
||||
|
||||
class TrackerLoadConfig(BaseModel):
|
||||
load_from: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
if self.load_from is None:
|
||||
return None
|
||||
return create_loader(self.load_from, data_path, **kwargs)
|
||||
|
||||
class TrackerSaveConfig(BaseModel):
|
||||
save_to: str = 'local'
|
||||
save_all: bool = False
|
||||
save_latest: bool = True
|
||||
save_best: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
return create_saver(self.save_to, data_path, **kwargs)
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
data_path: str = '.tracker_data'
|
||||
overwrite_data_path: bool = False
|
||||
log: TrackerLogConfig
|
||||
load: Optional[TrackerLoadConfig]
|
||||
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
|
||||
|
||||
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
|
||||
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
|
||||
# Add the logger
|
||||
tracker.add_logger(self.log.create(self.data_path))
|
||||
# Add the loader
|
||||
if self.load is not None:
|
||||
tracker.add_loader(self.load.create(self.data_path))
|
||||
# Add the saver or savers
|
||||
if isinstance(self.save, list):
|
||||
for save_config in self.save:
|
||||
tracker.add_saver(save_config.create(self.data_path))
|
||||
else:
|
||||
tracker.add_saver(self.save.create(self.data_path))
|
||||
# Initialize all the components and verify that all data is valid
|
||||
tracker.init(full_config, extra_config)
|
||||
return tracker
|
||||
|
||||
# diffusion prior pydantic classes
|
||||
|
||||
@@ -158,8 +212,11 @@ class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: ListOrTuple(int)
|
||||
image_embed_dim: int = None
|
||||
text_embed_dim: int = None
|
||||
cond_on_text_encodings: bool = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
self_attn: ListOrTuple(int)
|
||||
attn_dim_head: int = 32
|
||||
attn_heads: int = 16
|
||||
|
||||
@@ -170,6 +227,7 @@ class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
image_size: int = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
loss_type: str = 'l2'
|
||||
@@ -180,9 +238,16 @@ class DecoderConfig(BaseModel):
|
||||
|
||||
def create(self):
|
||||
decoder_kwargs = self.dict()
|
||||
|
||||
unet_configs = decoder_kwargs.pop('unets')
|
||||
unets = [Unet(**config) for config in unet_configs]
|
||||
return Decoder(unets, **decoder_kwargs)
|
||||
|
||||
has_clip = exists(decoder_kwargs.pop('clip'))
|
||||
clip = None
|
||||
if has_clip:
|
||||
clip = self.clip.create()
|
||||
|
||||
return Decoder(unets, clip=clip, **decoder_kwargs)
|
||||
|
||||
@validator('image_sizes')
|
||||
def check_image_sizes(cls, image_sizes, values):
|
||||
@@ -194,8 +259,9 @@ class DecoderConfig(BaseModel):
|
||||
extra = "allow"
|
||||
|
||||
class DecoderDataConfig(BaseModel):
|
||||
webdataset_base_url: str # path to a webdataset with jpg images
|
||||
embeddings_url: str # path to .npy files with embeddings
|
||||
webdataset_base_url: str # path to a webdataset with jpg images
|
||||
img_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||
text_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||
num_workers: int = 4
|
||||
batch_size: int = 64
|
||||
start_shard: int = 0
|
||||
@@ -227,6 +293,8 @@ class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
warmup_steps: Optional[SingularOrIterable(int)] = None
|
||||
find_unused_parameters: bool = True
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
@@ -236,9 +304,6 @@ class DecoderTrainConfig(BaseModel):
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
@@ -260,7 +325,6 @@ class TrainDecoderConfig(BaseModel):
|
||||
train: DecoderTrainConfig
|
||||
evaluate: DecoderEvaluateConfig
|
||||
tracker: TrackerConfig
|
||||
load: DecoderLoadConfig
|
||||
seed: int = 0
|
||||
|
||||
@classmethod
|
||||
@@ -268,3 +332,32 @@ class TrainDecoderConfig(BaseModel):
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
@root_validator
|
||||
def check_has_embeddings(cls, values):
|
||||
# Makes sure that enough information is provided to get the embeddings specified for training
|
||||
data_config, decoder_config = values.get('data'), values.get('decoder')
|
||||
|
||||
if not exists(data_config) or not exists(decoder_config):
|
||||
# Then something else errored and we should just pass through
|
||||
return values
|
||||
|
||||
using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
|
||||
using_clip = exists(decoder_config.clip)
|
||||
img_emb_url = data_config.img_embeddings_url
|
||||
text_emb_url = data_config.text_embeddings_url
|
||||
|
||||
if using_text_embeddings:
|
||||
# Then we need some way to get the embeddings
|
||||
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
|
||||
|
||||
if using_clip:
|
||||
if using_text_embeddings:
|
||||
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
|
||||
else:
|
||||
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
|
||||
if text_emb_url:
|
||||
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
|
||||
return values
|
||||
|
||||
@@ -3,10 +3,13 @@ import copy
|
||||
from pathlib import Path
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from contextlib import nullcontext
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -14,6 +17,10 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
import pytorch_warmup as warmup
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
import numpy as np
|
||||
@@ -62,16 +69,6 @@ def num_to_groups(num, divisor):
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
def clamp(value, min_value = None, max_value = None):
|
||||
assert exists(min_value) or exists(max_value)
|
||||
if exists(min_value):
|
||||
value = max(value, min_value)
|
||||
|
||||
if exists(max_value):
|
||||
value = min(value, max_value)
|
||||
|
||||
return value
|
||||
|
||||
# decorators
|
||||
|
||||
def cast_torch_tensor(fn):
|
||||
@@ -145,146 +142,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
chunk_size_frac = chunk_size / batch_size
|
||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
"""
|
||||
Implements exponential moving average shadowing for your model.
|
||||
|
||||
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||
good values for models you plan to train for a million or more steps (reaches decay
|
||||
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||
215.4k steps).
|
||||
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.9999,
|
||||
update_after_step = 100,
|
||||
update_every = 10,
|
||||
inv_gamma = 1.0,
|
||||
power = 2/3,
|
||||
min_value = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.update_every = update_every
|
||||
self.update_after_step = update_after_step
|
||||
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
self.min_value = min_value
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def copy_params_from_model_to_ema(self):
|
||||
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||
ma_param.data.copy_(current_param.data)
|
||||
|
||||
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||
ma_buffer.data.copy_(current_buffer.data)
|
||||
|
||||
def get_current_decay(self):
|
||||
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||
|
||||
if epoch <= 0:
|
||||
return 0.
|
||||
|
||||
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||
|
||||
def update(self):
|
||||
step = self.step.item()
|
||||
self.step += 1
|
||||
|
||||
if (step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted.item():
|
||||
self.copy_params_from_model_to_ema()
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
@torch.no_grad()
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
current_decay = self.get_current_decay()
|
||||
|
||||
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||
difference = ma_params.data - current_params.data
|
||||
difference.mul_(1.0 - current_decay)
|
||||
ma_params.sub_(difference)
|
||||
|
||||
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||
difference = ma_buffer - current_buffer
|
||||
difference.mul_(1.0 - current_decay)
|
||||
ma_buffer.sub_(difference)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@@ -310,19 +167,32 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
verbose = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# verbosity
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# setting the device
|
||||
|
||||
if not exists(accelerator) and not exists(device):
|
||||
diffusion_prior_device = next(diffusion_prior.parameters()).device
|
||||
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
|
||||
self.device = diffusion_prior_device
|
||||
else:
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
@@ -358,11 +228,14 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
self.register_buffer('step', torch.tensor([0], device = self.device))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
@@ -505,26 +378,20 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.sample(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
|
||||
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
|
||||
return model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@@ -582,6 +449,7 @@ class DecoderTrainer(nn.Module):
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -603,11 +471,15 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
@@ -619,6 +491,13 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizers.append(optimizer)
|
||||
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
|
||||
schedulers.append(scheduler)
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
|
||||
@@ -626,15 +505,27 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
schedulers = list(self.accelerator.prepare(*schedulers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
# store optimizers
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
setattr(self, f'optim{opt_ind}', optimizer)
|
||||
|
||||
# store schedulers
|
||||
|
||||
for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
|
||||
setattr(self, f'sched{sched_ind}', scheduler)
|
||||
|
||||
# store warmup schedulers
|
||||
|
||||
self.warmup_schedulers = warmup_schedulers
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
@@ -643,7 +534,7 @@ class DecoderTrainer(nn.Module):
|
||||
save_obj = dict(
|
||||
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
steps = self.steps.cpu(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -657,30 +548,38 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.accelerator.save(save_obj, str(path))
|
||||
|
||||
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
|
||||
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.steps.copy_(loaded_obj['steps'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location = 'cpu')
|
||||
|
||||
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@@ -688,6 +587,12 @@ class DecoderTrainer(nn.Module):
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def increment_step(self, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
|
||||
unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
|
||||
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
@@ -696,17 +601,25 @@ class DecoderTrainer(nn.Module):
|
||||
index = unet_number - 1
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scheduler = getattr(self, f'sched{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[index]
|
||||
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
|
||||
|
||||
with scheduler_context():
|
||||
scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
self.step += 1
|
||||
self.increment_step(unet_number)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@@ -730,6 +643,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_image(self, *args, **kwargs):
|
||||
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
self,
|
||||
@@ -744,7 +669,6 @@ class DecoderTrainer(nn.Module):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
# with autocast(enabled = self.amp):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import time
|
||||
import importlib
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# time helpers
|
||||
|
||||
class Timer:
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.11.2'
|
||||
__version__ = '0.16.16'
|
||||
|
||||
@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_after_step = 500,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
|
||||
183
prior.md
Normal file
183
prior.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# Diffusion Prior
|
||||
This readme serves as an introduction to the diffusion prior.
|
||||
|
||||
## Intro
|
||||
|
||||
A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.
|
||||
|
||||
### Motivation
|
||||
|
||||
Before we dive into the model, let’s look at a quick example of where the model may be helpful.
|
||||
|
||||
For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
|
||||
|
||||
> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption; however, there is no guarantee that these embeddings are in the same space. While the embeddings generated are ***close***, the image and text embeddings occupy two disjoint sets.
|
||||
|
||||
```python
|
||||
# Load Models
|
||||
clip_model = clip.load("ViT-L/14")
|
||||
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
|
||||
|
||||
# Retrieve prompt from user and encode with CLIP
|
||||
prompt = "A corgi wearing sunglasses"
|
||||
tokenized_text = tokenize(prompt)
|
||||
text_embedding = clip_model.encode_text(tokenized_text)
|
||||
|
||||
# Now, pass the text embedding to the decoder
|
||||
predicted_image = decoder.sample(text_embedding)
|
||||
```
|
||||
|
||||
> **Question**: *Can you spot the issue here?*
|
||||
>
|
||||
> **Answer**: *We’re trying to generate an image from a text embedding!*
|
||||
|
||||
Unfortunately, we run into the issue previously mentioned--the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
|
||||
|
||||
```python
|
||||
# Load Models
|
||||
prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
|
||||
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
|
||||
|
||||
# Retrieve prompt from user and encode with a prior
|
||||
prompt = "A corgi wearing sunglasses"
|
||||
tokenized_text = tokenize(prompt)
|
||||
text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
|
||||
|
||||
# Now, pass the predicted image embedding to the decoder
|
||||
predicted_image = decoder.sample(text_embedding)
|
||||
```
|
||||
|
||||
With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
|
||||
|
||||
> **You may be asking yourself the following question:**
|
||||
>
|
||||
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
|
||||
>
|
||||
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
|
||||
|
||||
## Usage
|
||||
|
||||
Utilizing a pre-trained prior is quite simple.
|
||||
|
||||
### Loading Checkpoints
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim=768,
|
||||
depth=24,
|
||||
dim_head=64,
|
||||
heads=32,
|
||||
normformer=True,
|
||||
attn_dropout=5e-2,
|
||||
ff_dropout=5e-2,
|
||||
num_time_embeds=1,
|
||||
num_image_embeds=1,
|
||||
num_text_embeds=1,
|
||||
num_timesteps=1000,
|
||||
ff_mult=4
|
||||
)
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net=prior_network,
|
||||
clip=OpenAIClipAdapter("ViT-L/14"),
|
||||
image_embed_dim=768,
|
||||
timesteps=1000,
|
||||
cond_drop_prob=0.1,
|
||||
loss_type="l2",
|
||||
condition_on_text_encodings=True,
|
||||
|
||||
)
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior=diffusion_prior,
|
||||
lr=1.1e-4,
|
||||
wd=6.02e-2,
|
||||
max_grad_norm=0.5,
|
||||
amp=False,
|
||||
group_wd_params=True,
|
||||
use_ema=True,
|
||||
device=device,
|
||||
accelerator=None,
|
||||
)
|
||||
|
||||
trainer.load(dprior_path)
|
||||
|
||||
return trainer
|
||||
```
|
||||
|
||||
Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
|
||||
|
||||
### Sampling
|
||||
Once we have a pre-trained model, generating embeddings is quite simple!
|
||||
```python
|
||||
# tokenize the text
|
||||
tokenized_text = clip.tokenize("<your amazing prompt>")
|
||||
# predict an embedding
|
||||
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
|
||||
```
|
||||
|
||||
The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
|
||||
|
||||
> For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenized_text) as a drop-in replacement for clip.encode_text().
|
||||
|
||||
**Some things to note:**
|
||||
* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two and selecting the one with the higher cosine similarity with the prompt (see the sketch below).
|
||||
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.
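
To make the reranking idea above concrete, here is a minimal hand-rolled sketch of sampling two candidate embeddings and keeping the better one. It assumes `tokenize`, `clip_model`, and a loaded `prior` as in the earlier snippets; in practice, `n_samples_per_batch=2` does this for you inside `.sample()`, so this is purely illustrative.

```python
import torch
import torch.nn.functional as F

prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)                        # assumed CLIP tokenizer from the snippets above
text_embedding = clip_model.encode_text(tokenized_text)  # assumed loaded CLIP model

# Draw two candidate image embeddings for the same prompt
candidates = [prior.sample(tokenized_text, n_samples_per_batch=1) for _ in range(2)]

# Keep the candidate whose cosine similarity with the prompt embedding is highest
similarities = torch.cat([F.cosine_similarity(c, text_embedding, dim=-1) for c in candidates])
best_embedding = candidates[int(similarities.argmax())]
```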
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
### Overview
|
||||
|
||||
Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
|
||||
|
||||
## Dataset
|
||||
|
||||
To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.
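
As a rough sketch of what the precomputed data can look like on disk: embeddings are typically stored as numbered `.npy` shards that line up with the image metadata. The layout and file names below are assumptions for illustration, not a required format.

```python
from pathlib import Path
import numpy as np

def iter_embedding_shards(img_emb_dir: str, text_emb_dir: str):
    # Assumed layout: aligned shard files, e.g. img_emb/0000.npy and text_emb/0000.npy
    img_files = sorted(Path(img_emb_dir).glob("*.npy"))
    text_files = sorted(Path(text_emb_dir).glob("*.npy"))
    for img_file, text_file in zip(img_files, text_files):
        img_embs = np.load(img_file)    # e.g. shape (n, 768) for ViT-L/14
        text_embs = np.load(text_file)  # e.g. shape (n, 768)
        assert len(img_embs) == len(text_embs), "shards must be aligned"
        yield img_embs, text_embs
```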
|
||||
|
||||
## Configuration
|
||||
|
||||
The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics, please see the configuration README.
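
As a purely hypothetical illustration of the kind of settings such a file groups together (the real schema lives in the configuration README; the keys below are not authoritative):

```python
# Hypothetical shape of a prior training config -- consult the configuration README for the real schema
example_config = {
    "prior": {"clip": "ViT-L/14", "net": {"dim": 768, "depth": 24, "heads": 32}},
    "data": {"image_embeddings_url": "s3://my-bucket/img_emb/", "batch_size": 64},  # made-up path
    "train": {"epochs": 20, "lr": 1.1e-4, "wd": 6.02e-2, "use_ema": True},
    "tracker": {"log_type": "wandb", "save_to": "local"},
}
```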
|
||||
|
||||
## Distributed Training
|
||||
|
||||
If you would like to train in a distributed manner, we have opted to leverage huggingface's Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to follow its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
|
||||
|
||||
## Evaluation
|
||||
|
||||
There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
|
||||
| Metric | Description | Comments |
| ------ | ----------- | -------- |
| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 billion samples seen. |
| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform it in the long term. |
| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
| Similarity With Original Image | This metric measures the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, then you are likely suffering from some kind of training error or inefficiency (e.g. not using EMA). |
| Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric shows how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within +/-`0.01` of `0.0`. If this climbs too high (~>`0.02`), then this may be a sign that your model is overfitting somehow. |
| Similarity With Text | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and it is probably the second most important metric behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed an early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
| Similarity With Unrelated Caption | This metric attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images was high (when in fact the caption and image were completely unrelated). With this in mind, a low value is ideal; anything below `0.1` is probably safe. |
|
||||
|
||||
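Most of the similarity metrics above reduce to plain cosine similarity between L2-normalized CLIP embeddings. The helper below is a small sketch of that computation; the tensor names are placeholders rather than variables from the training script.

```python
import torch.nn.functional as F

def cosine_report(text_embed, predicted_image_embed, true_image_embed):
    # normalize so that the dot product equals cosine similarity
    text_embed, predicted_image_embed, true_image_embed = (
        F.normalize(t, dim=-1) for t in (text_embed, predicted_image_embed, true_image_embed)
    )
    return {
        "Baseline Similarity": (text_embed * true_image_embed).sum(-1).mean().item(),
        "Similarity With Text": (text_embed * predicted_image_embed).sum(-1).mean().item(),
        "Similarity With Original Image": (predicted_image_embed * true_image_embed).sum(-1).mean().item(),
    }
```
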
## Launching the script

Now that you've done all the prep, it's time for the easy part! 🚀

To actually launch the script, use `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to launch with distributed training and huggingface accelerate, or `python train_diffusion_prior.py` if you would like to train on your GPU/CPU without huggingface accelerate.

## Checkpointing

Checkpoints will be saved to the directory specified in your configuration file.

Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled "latest.pth". This avoids problems where your `save_every` configuration does not line up with the number of steps required to do a complete pass through the data.

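If you want to inspect a saved checkpoint outside of the training script, a plain `torch.load` is enough to poke at it. The exact keys stored in the file depend on the trainer version, so the snippet below only prints whatever is there rather than assuming a schema.

```python
import torch

# hypothetical path; use the directory from your configuration file
checkpoint = torch.load("checkpoints/latest.pth", map_location="cpu")

# the checkpoint is typically a dictionary; list what the trainer actually stored
if isinstance(checkpoint, dict):
    print(list(checkpoint.keys()))
```
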
## Things To Keep In Mind

The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least not yet.

As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.

With that in mind, you are more or less a pioneer in embedding translation if you are reading this and attempting something you don't see documentation for!

**setup.py** (2 changed lines)

```
@@ -28,6 +28,7 @@ setup(
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
@@ -36,6 +37,7 @@ setup(
'packaging',
'pillow',
'pydantic',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
```

**train_decoder.py** (282 changed lines)

```
@@ -1,11 +1,13 @@
from pathlib import Path
from typing import List

from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize

import torchvision
import torch
@@ -33,7 +35,8 @@ def exists(val):
def create_dataloaders(
available_shards,
webdataset_base_url,
embeddings_url,
img_embeddings_url=None,
text_embeddings_url=None,
shard_width=6,
num_workers=4,
batch_size=32,
@@ -63,14 +66,15 @@ def create_dataloaders(
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]

create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
embeddings_url=embeddings_url,
img_embeddings_url=img_embeddings_url,
text_embeddings_url=text_embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"] if with_text else [],
extra_keys= ["txt"],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
@@ -79,8 +83,8 @@ def create_dataloaders(

train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
val_dataloader = create_dataloader(val_urls, shuffle=False)
test_dataloader = create_dataloader(test_urls, shuffle=False)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
@@ -104,42 +108,65 @@ def get_example_data(dataloader, device, n=5):
Samples the dataloader and returns a zipped list of examples
"""
images = []
embeddings = []
img_embeddings = []
text_embeddings = []
captions = []
dataset_keys = get_dataset_keys(dataloader)
has_caption = "txt" in dataset_keys
for data in dataloader:
if has_caption:
img, emb, txt = data
for img, emb, txt in dataloader:
img_emb, text_emb = emb.get('img'), emb.get('text')
if img_emb is not None:
img_emb = img_emb.to(device=device, dtype=torch.float)
img_embeddings.extend(list(img_emb))
else:
img, emb = data
txt = [""] * emb.shape[0]
# Then we add None img.shape[0] times
img_embeddings.extend([None]*img.shape[0])
if text_emb is not None:
text_emb = text_emb.to(device=device, dtype=torch.float)
text_embeddings.extend(list(text_emb))
else:
# Then we add None img.shape[0] times
text_embeddings.extend([None]*img.shape[0])
img = img.to(device=device, dtype=torch.float)
emb = emb.to(device=device, dtype=torch.float)
images.extend(list(img))
embeddings.extend(list(emb))
captions.extend(list(txt))
if len(images) >= n:
break
return list(zip(images[:n], embeddings[:n], captions[:n]))
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

def generate_samples(trainer, example_data, text_prepend=""):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, embeddings, txts = zip(*example_data)
embeddings_tensor = torch.stack(embeddings)
samples = trainer.sample(embeddings_tensor)
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
sample_params = {}
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
img_embeddings = torch.stack(img_embeddings)
sample_params["image_embed"] = img_embeddings
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
tokenized_texts = tokenize(txts, truncate=True)
sample_params["text"] = tokenized_texts
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, text_prepend=""):
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)

real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
@@ -151,7 +178,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions

def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -161,7 +188,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -213,43 +240,35 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
metrics[metric_name] = metrics_tensor[i].item()
return metrics

def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
for relative_path in relative_paths:
local_path = str(tracker.data_path / relative_path)
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
tracker.save_file(local_path)
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)

def recall_trainer(tracker, trainer, recall_source=None, **load_config):
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
local_filepath = tracker.recall_file(recall_source, **load_config)
state_dict = trainer.load(local_filepath)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
state_dict = tracker.recall()
trainer.load_state_dict(state_dict, only_model=False, strict=True)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)

def train(
dataloaders,
decoder,
accelerator,
tracker,
decoder: Decoder,
accelerator: Accelerator,
tracker: Tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
):
"""
@@ -258,8 +277,8 @@ def train(
is_master = accelerator.process_index == 0

trainer = DecoderTrainer(
accelerator,
decoder,
decoder=decoder,
accelerator=accelerator,
**kwargs
)

@@ -268,16 +287,17 @@ def train(
validation_losses = []
next_task = 'train'
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())

if exists(load_config) and exists(load_config.source):
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
if tracker.loader is not None:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)

@@ -306,13 +326,22 @@ def train(
last_snapshot = sample

if next_task == 'train':
for i, (img, emb) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
img, emb = send_to_device((img, emb))
samples_seen += total_samples
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))

trainer.train()
for unet in range(1, trainer.num_unets+1):
@@ -320,7 +349,20 @@ def train(
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue

loss = trainer.forward(img, image_embed=emb, unet_number=unet)
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss

@@ -334,31 +376,32 @@ def train(
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }

# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}

log_data = {
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec,
"Samples Seen": samples_seen,
**ema_decay_list,
**loss_map
}
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")

if is_master:
tracker.log(log_data, step=step(), verbose=True)
tracker.log(log_data, step=step())

if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

if epoch_samples is not None and sample >= epoch_samples:
@@ -381,14 +424,35 @@ def train(
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
val_sample += total_samples
img, emb = send_to_device((img, emb))
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))

for unet in range(1, len(decoder.unets)+1):
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
# No need to evaluate an unchanging unet
continue

loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)

forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss

if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -409,15 +473,15 @@ def train(
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step(), verbose=True)
tracker.log(val_loss_map, step=step())
next_task = 'eval'

if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
tracker.log(evaluation, step=step())
next_task = 'sample'
val_sample = 0

@@ -426,28 +490,22 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
next_task = 'train'

def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
@@ -455,40 +513,16 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
init_config = { "config": {**config.dict(), **accelerator_config} }
data_path = data_path or tracker_config.data_path
tracker_type = tracker_type or tracker_config.tracker_type

if tracker_type == "dummy":
tracker = DummyTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "console":
tracker = ConsoleTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"

init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
return tracker

def initialize_training(config, config_path):
def initialize_training(config: TrainDecoderConfig, config_path):
# Make sure if we are not loading, distributed models are initialized to the same values
torch.manual_seed(config.seed)

# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# Set up data
@@ -515,16 +549,38 @@ def initialize_training(config, config_path):
num_parameters = sum(p.numel() for p in decoder.parameters())

# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)

has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])

has_clip_model = config.decoder.clip is not None
data_source_string = ""

if has_img_embeddings:
data_source_string += "precomputed image embeddings"
elif has_clip_model:
data_source_string += "clip image embeddings generation"
else:
raise ValueError("No image embeddings source specified")
if conditioning_on_text:
if has_text_embeddings:
data_source_string += " and precomputed text embeddings"
elif has_clip_model:
data_source_string += " and clip text encoding generation"
else:
raise ValueError("No text embeddings source specified")

accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.dict(),
)
```