Compare commits

...

39 Commits

Author SHA1 Message Date
Phil Wang
b7e22f7da0 complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157 2022-07-09 17:25:34 -07:00
Romain Beaumont
68de937aac Fix decoder test by fixing the resizing output size (#197) 2022-07-09 07:48:07 -07:00
Phil Wang
097afda606 0.18.0 2022-07-08 18:18:38 -07:00
Aidan Dempster
5c520db825 Added deepspeed support (#195) 2022-07-08 18:18:08 -07:00
Phil Wang
3070610231 just force it so researcher can never pass in an image that is less than the size that is required for CLIP or CoCa 2022-07-08 18:17:29 -07:00
Aidan Dempster
870aeeca62 Fixed issue where evaluation would error when large image was loaded (#194) 2022-07-08 17:11:34 -07:00
Romain Beaumont
f28dc6dc01 setup simple ci (#193) 2022-07-08 16:51:56 -07:00
Phil Wang
081d8d3484 0.17.0 2022-07-08 13:36:26 -07:00
Aidan Dempster
a71f693a26 Add the ability to auto restart the last run when started after a crash (#191)
* Added autoresume after crash functionality to the trackers

* Updated documentation

* Clarified what goes in the autorestart object

* Fixed style issues

Unraveled conditional block

Changed to using helper function to get step count
2022-07-08 13:35:40 -07:00
Phil Wang
d7bc5fbedd expose num_steps_taken helper method on trainer to retrieve number of training steps of each unet 2022-07-08 13:00:56 -07:00
Phil Wang
8c823affff allow for control over use of nearest interp method of downsampling low res conditioning, in addition to being able to turn it off 2022-07-08 11:44:43 -07:00
Phil Wang
ec7cab01d9 extra insurance that diffusion prior is on the correct device, when using trainer with accelerator or device was given 2022-07-07 10:08:33 -07:00
Phil Wang
46be8c32d3 fix a potential issue in the low resolution conditioner, when downsampling and then upsampling using resize right, thanks to @marunine 2022-07-07 09:41:49 -07:00
Phil Wang
900f086a6d fix condition_on_text_encodings in dalle2 orchestrator class, fix readme 2022-07-07 07:43:41 -07:00
zion
b3e646fd3b add readme for prior (#159)
* add readme for prior

* offload prior info in main readme

* typos
2022-07-06 20:50:52 -07:00
Phil Wang
6a59c7093d more shots in the dark regarding fp16 with learned variance for deepspeed issue 2022-07-06 19:05:50 -07:00
Phil Wang
a6cdbe0b9c relax learning rate constraint, as @rom1504 wants to try a higher one 2022-07-06 18:09:11 -07:00
Phil Wang
e928ae5c34 default the device to the device that the diffusion prior parameters are on, if the trainer was never given the accelerator nor device 2022-07-06 12:47:48 -07:00
Phil Wang
1bd8a7835a attempting to fix issue with deepspeed fp16 seeing overflowing gradient 2022-07-06 08:27:34 -07:00
Phil Wang
f33453df9f debugging with Aidan 2022-07-05 18:22:43 -07:00
Phil Wang
1e4bb2bafb cast long as float before deriving sinusoidal pos emb 2022-07-05 18:01:22 -07:00
Phil Wang
ee75515c7d remove forcing of softmax in f32, in case it is interfering with deepspeed 2022-07-05 16:53:58 -07:00
Phil Wang
ec68243479 set ability to do warmup steps for each unet during training 2022-07-05 16:24:16 -07:00
Phil Wang
3afdcdfe86 need to keep track of training steps separately for each unet in decoder trainer 2022-07-05 15:17:59 -07:00
Phil Wang
b9a908ff75 bring in two tricks from the cogview paper for reducing the chances of overflow, for attention and layernorm 2022-07-05 14:27:04 -07:00
Phil Wang
e1fe3089df do bias-less layernorm manually 2022-07-05 13:09:58 -07:00
Phil Wang
6d477d7654 link to dalle2 laion 2022-07-05 11:43:07 -07:00
Phil Wang
531fe4b62f status 2022-07-05 10:46:55 -07:00
Phil Wang
ec5a77fc55 0.15.4 2022-07-02 08:56:34 -07:00
Aidan Dempster
fac63c61bc Fixed variable naming issue (#183) 2022-07-02 08:56:03 -07:00
Phil Wang
3d23ba4aa5 add ability to specify full self attention on specific stages in the unet 2022-07-01 10:22:07 -07:00
Phil Wang
282c35930f 0.15.2 2022-07-01 09:40:11 -07:00
Aidan Dempster
27b0f7ca0d Overhauled the tracker system (#172)
* Overhauled the tracker system
Separated the logging and saving capabilities
Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script
Added class separation between different types of loaders and savers to make the system more verbose

* Changed the saver system to only save the checkpoint once

* Added better error handling for saving checkpoints

* Fixed an error where wandb would error when passed arbitrary kwargs

* Fixed variable naming issues for improved saver
Added more logging during long pauses

* Fixed which methods need to be dummy to immediately return
Added the ability to set whether you find unused parameters

* Added more logging for when a wandb loader fails
2022-07-01 09:39:40 -07:00
Phil Wang
7b0edf9e42 allow for returning low resolution conditioning image on forward through decoder with return_lowres_cond_image flag 2022-07-01 09:35:39 -07:00
Phil Wang
a922a539de bring back convtranspose2d upsampling, allow for nearest upsample with hyperparam, change kernel size of last conv to 1, make configurable, cleanup 2022-07-01 09:21:47 -07:00
Phil Wang
8f2466f1cd blur sigma for upsampling training was 0.6 in the paper, make that the default value 2022-06-30 17:03:16 -07:00
Phil Wang
908ab83799 add skip connections for all intermediate resnet blocks, also add an extra resnet block for memory efficient version of unet, time condition for both initial resnet block and last one before output 2022-06-29 08:16:58 -07:00
Phil Wang
46a2558d53 bug in pydantic decoder config class 2022-06-29 07:17:35 -07:00
yytdfc
86109646e3 fix a bug of name error (#179) 2022-06-29 07:16:44 -07:00
27 changed files with 1613 additions and 411 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Continuous integration
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install
        run: |
          python3 -m venv .env
          source .env/bin/activate
          make install
      - name: Tests
        run: |
          source .env/bin/activate
          make test

2
.gitignore vendored
View File

@@ -136,3 +136,5 @@ dmypy.json
# Pyre type checker
.pyre/
.tracker_data
*.pth

6
Makefile Normal file
View File

@@ -0,0 +1,6 @@
install:
	pip install -U pip
	pip install -e .
test:
	CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json

View File

@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
<img src="./samples/oxford.png" width="600px" />
<img src="./samples/oxford.png" width="450px" />
*ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
## Appreciation
@@ -42,6 +44,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -579,7 +582,9 @@ unet1 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
text_embed_dim = 512,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
unet2 = Unet(
@@ -594,14 +599,14 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = (250, 27),
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
text_cond_drop_prob = 0.5
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
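Given the `sample_timesteps = (250, 27)` values above, each unet will now use a shortened DDIM schedule at sampling time. A minimal sketch of generating images after training, assuming the `text` tensor from the training snippet; the `decoder.sample` call and its keyword arguments are an assumption based on this diff rather than something shown in it:
```python
mock_image_embed = torch.randn(4, 512).cuda()                         # stand-in for a diffusion prior output
images = decoder.sample(image_embed = mock_image_embed, text = text)  # (4, 3, 256, 256)
```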
@@ -988,34 +993,7 @@ dataset = ImageEmbeddingDataset(
#### `train_diffusion_prior.py`
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
#### Usage
```bash
$ python train_diffusion_prior.py
```
The most significant parameters for the script are as follows:
- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT-L/14 embedding size; change it to match the embedding size of your chosen ViT
- `learning-rate`, default = `1.1e-4`
- `weight-decay`, default = `6.02e-2`
- `max-grad-norm`, default = `0.5`
- `batch-size`, default = `10 ** 4`
- `num-epochs`, default = `5`
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
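To make the pre-computed-embedding path concrete, the sketch below feeds `text_embed` and `image_embed` directly to a `DiffusionPrior`, as the note above says the script does internally; the network hyperparameters and random tensors are illustrative stand-ins, not values taken from the script:
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(dim = 768, depth = 6, dim_head = 64, heads = 8)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 768,               # matches image-embed-dim above (ViT-L/14)
    timesteps = 1000,
    condition_on_text_encodings = False  # embeddings only, no raw text encodings
)

# stand-ins for embedding shards loaded from text-embed-url / image-embed-url
text_embed  = torch.randn(4, 768)
image_embed = torch.randn(4, 768)

loss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)
loss.backward()
```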
## CLI (wip)
@@ -1112,15 +1090,6 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022},
url = {https://arxiv.org/abs/2204.01697}
}
```
```bibtex
@article{Yu2021VectorquantizedIM,
title = {Vector-quantized Image Modeling with Improved VQGAN},

View File

@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
Any parameter from the `Decoder` constructor can also be given here.
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -91,21 +93,95 @@ Each metric can be enabled by setting its configuration. The configuration keys
**<ins>Tracker</ins>:**
Selects which tracker to use and configures it.
Selects how the experiment will be tracked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
| `log` | Yes | N/A | Logging configuration. |
| `load` | No | `None` | Checkpoint loading configuration. |
| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
Tracking is split up into three sections:
* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
**Logging:**
**<ins>Load</ins>:**
Selects where to load a pretrained model from.
All loggers have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
| `log_type` | Yes | N/A | The type of logger class to use. |
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
If using `console`, there is no further configuration beyond setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | Must be `console`. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | Must be `wandb`. |
| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
**Loading:**
All loaders have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | The type of loader class to use. |
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `local`. |
| `file_path` | Yes | N/A | The path to the checkpoint file. |
If using `url`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `url`. |
| `url` | Yes | N/A | The url of the checkpoint file. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
**Saving:**
Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`. |
If using `huggingface`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |

View File

@@ -20,7 +20,7 @@
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"img_embeddings_url": "s3://bucket/img_embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
@@ -80,20 +80,32 @@
}
},
"tracker": {
"tracker_type": "console",
"data_path": "./models",
"overwrite_data_path": true,
"wandb_entity": "",
"wandb_project": "",
"log": {
"log_type": "wandb",
"verbose": false
},
"load": {
"source": null,
"wandb_entity": "your_wandb",
"wandb_project": "your_project",
"run_path": "",
"file_path": "",
"verbose": true
},
"resume": false
"load": {
"load_from": null
},
"save": [{
"save_to": "wandb"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_type": "model"
}]
}
}

View File

@@ -0,0 +1,102 @@
{
"decoder": {
"unets": [
{
"dim": 16,
"image_embed_dim": 768,
"cond_dim": 16,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16,
"attn_heads": 4,
"self_attn": [false, true, true, true]
}
],
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"timesteps": 10,
"image_sizes": [64],
"channels": 3,
"loss_type": "l2",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {
"webdataset_base_url": "test_data/{}.tar",
"num_workers": 4,
"batch_size": 4,
"start_shard": 0,
"end_shard": 9,
"shard_width": 1,
"index_width": 1,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": false,
"resample_train": true,
"preprocessing": {
"RandomResizedCrop": {
"size": [224, 224],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 1,
"lr": 1e-16,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100,
"n_sample_images": 1,
"device": "cpu",
"epoch_samples": 50,
"validation_samples": 5,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 2,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 2
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"overwrite_data_path": true,
"log": {
"log_type": "console"
},
"load": {
"load_from": null
},
"save": [{
"save_to": "local"
}]
}
}

View File

@@ -45,6 +45,11 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def first(arr, d = None):
if len(arr) == 0:
return d
return arr[0]
def maybe(fn):
@wraps(fn)
def inner(x):
@@ -58,11 +63,16 @@ def default(val, d):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
def cast_tuple(val, length = None):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else ((val,) * length)
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if exists(length):
assert len(out) == length
return out
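# illustrative usage, not part of the commit: with the optional length argument, mismatches now fail loudly
#   cast_tuple(3, 2)       -> (3, 3)
#   cast_tuple([4, 8], 2)  -> (4, 8)   (lists are coerced to tuples first)
#   cast_tuple((4, 8), 3)  -> AssertionError, since the length check no longer passes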
def module_device(module):
return next(module.parameters()).device
@@ -115,14 +125,28 @@ def log(t, eps = 1e-12):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(image, target_image_size):
def resize_image_to(
image,
target_image_size,
clamp_range = None,
nearest = False,
**kwargs
):
orig_image_size = image.shape[-1]
if orig_image_size == target_image_size:
return image
scale_factors = target_image_size / orig_image_size
return resize(image, scale_factors = scale_factors)
if not nearest:
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest')
if exists(clamp_range):
out = out.clamp(*clamp_range)
return out
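# illustrative usage, not part of the commit: the tensor below is a stand-in batch of normalized images,
# downsampled with the new nearest-neighbor path and clamped back into the expected value range
#   low_res = resize_image_to(torch.randn(4, 3, 256, 256), 64, clamp_range = (-1., 1.), nearest = True)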
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
@@ -145,6 +169,11 @@ class BaseClipAdapter(nn.Module):
self.clip = clip
self.overrides = kwargs
def validate_and_resize_image(self, image):
image_size = image.shape[-1]
assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
return resize_image_to(image, self.image_size)
@property
def dim_latent(self):
raise NotImplementedError
@@ -195,7 +224,7 @@ class XClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -230,7 +259,7 @@ class CoCaAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
@@ -291,7 +320,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -325,21 +354,25 @@ def approx_standard_normal_cdf(x):
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
# attempting to correct nan gradients when learned variance is turned on
# in the setting of deepspeed fp16
eps = 1e-12 if x.dtype == torch.float32 else 1e-3
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
log_cdf_plus = log(cdf_plus, eps = eps)
log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
log(cdf_delta, eps = eps)))
return log_probs
@@ -351,7 +384,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
@@ -472,6 +505,12 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
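# note: this helper is the algebraic inverse of predict_start_from_noise above: given
# x0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * noise, it solves for noise; the new
# DDIM sampling loops further down rely on it when the network is trained to predict x0 directly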
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
@@ -480,14 +519,16 @@ class NoiseScheduler(nn.Module):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
@@ -496,10 +537,10 @@ class ChanLayerNorm(nn.Module):
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
return (x - mean) * (var + self.eps).rsqrt() * self.g
class Residual(nn.Module):
def __init__(self, fn):
@@ -619,10 +660,13 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
rotary_emb = None
rotary_emb = None,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads
@@ -686,7 +730,10 @@ class Attention(nn.Module):
# attention
attn = sim.softmax(dim = -1, dtype = torch.float32)
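# note: the replacement below applies the PB-relax trick from CogView: subtract the per-row max and
# multiply back the alpha factor that was divided out of the query scale in __init__, so that fp16
# attention logits stay in range without forcing the softmax into float32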
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate values
@@ -870,6 +917,7 @@ class DiffusionPrior(nn.Module):
image_size = None,
image_channels = 3,
timesteps = 1000,
sample_timesteps = None,
cond_drop_prob = 0.,
loss_type = "l2",
predict_x_start = True,
@@ -883,6 +931,8 @@ class DiffusionPrior(nn.Module):
):
super().__init__()
self.sample_timesteps = sample_timesteps
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
@@ -937,8 +987,6 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start:
x_recon = pred
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -961,21 +1009,75 @@ class DiffusionPrior(nn.Module):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
return image_embed
@torch.no_grad()
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
image_embed = torch.randn(shape, device = device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
pred_noise = pred
if not self.predict_x_start:
x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = l2norm(x_start) * self.image_embed_scale
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
new_noise = torch.randn_like(image_embed)
image_embed = x_start * alpha_next.sqrt() + \
c1 * new_noise + \
c2 * pred_noise
return image_embed
@torch.no_grad()
def p_sample_loop(self, *args, timesteps = None, **kwargs):
timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
assert timesteps <= self.noise_scheduler.num_timesteps
is_ddim = timesteps < self.noise_scheduler.num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, **kwargs)
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1010,7 +1112,15 @@ class DiffusionPrior(nn.Module):
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
def sample(
self,
text,
num_samples_per_batch = 2,
cond_scale = 1.,
timesteps = None
):
timesteps = default(timesteps, self.sample_timesteps)
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1025,7 +1135,7 @@ class DiffusionPrior(nn.Module):
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
# retrieve original unscaled image embed
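# illustrative usage, not part of the commit: `text` is a stand-in batch of tokenized text; passing a
# timesteps value (or setting sample_timesteps on the prior) below the training timesteps routes
# sampling through the new DDIM loop
#   image_embed = diffusion_prior.sample(text, timesteps = 64)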
@@ -1088,8 +1198,16 @@ class DiffusionPrior(nn.Module):
# decoder
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def ConvTransposeUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, dim_out, 3, padding = 1)
)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
@@ -1101,11 +1219,12 @@ class SinusoidalPosEmb(nn.Module):
self.dim = dim
def forward(self, x):
dtype, device = x.dtype, x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
class Block(nn.Module):
def __init__(
@@ -1166,7 +1285,7 @@ class ResnetBlock(nn.Module):
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
def forward(self, x, time_emb = None, cond = None):
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
@@ -1192,10 +1311,12 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
norm_context = False,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads
@@ -1241,26 +1362,15 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class GridAttention(nn.Module):
def __init__(self, *args, window_size = 8, **kwargs):
super().__init__()
self.window_size = window_size
self.attn = Attention(*args, **kwargs)
def forward(self, x):
h, w = x.shape[-2:]
wsz = self.window_size
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = self.attn(x)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class LinearAttention(nn.Module):
def __init__(
self,
@@ -1342,6 +1452,7 @@ class Unet(nn.Module):
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
self_attn = False,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1360,6 +1471,8 @@ class Unet(nn.Module):
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
nearest_upsample = False,
final_conv_kernel_size = 1,
**kwargs
):
super().__init__()
@@ -1386,6 +1499,8 @@ class Unet(nn.Module):
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
num_stages = len(in_out)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
@@ -1449,12 +1564,16 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
self_attn = cast_tuple(self_attn, num_stages)
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
# resnet block klass
resnet_groups = cast_tuple(resnet_groups, len(in_out))
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
resnet_groups = cast_tuple(resnet_groups, num_stages)
top_level_resnet_group = first(resnet_groups)
assert len(resnet_groups) == len(in_out)
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
# downsample klass
@@ -1462,46 +1581,71 @@ class Unet(nn.Module):
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# upsample klass
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
# give memory efficient unet an initial resnet block
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
skip_connect_dims = [] # keeping track of skip connection dimensions
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
dim_layer = dim_out if memory_efficient else dim_in
skip_connect_dims.append(dim_layer)
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_layer)
elif sparse_attn:
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
downsample_klass(dim_out) if not is_last and not memory_efficient else None
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
skip_connect_dim = skip_connect_dims.pop()
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_out)
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
@@ -1575,6 +1719,7 @@ class Unet(nn.Module):
# time conditioning
time = time.type_as(x)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)
@@ -1665,56 +1810,71 @@ class Unet(nn.Module):
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# initial resnet block
if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
# go through the layers of the unet, down and up
hiddens = []
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, c, t)
x = sparse_attn(x)
x = init_block(x, t, c)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
x = resnet_block(x, t, c)
hiddens.append(x)
x = attn(x)
hiddens.append(x)
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, mid_c, t)
x = self.mid_block1(x, t, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c, t)
x = self.mid_block2(x, t, mid_c)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
skip_connect = hiddens.pop() * self.skip_connect_scale
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
x = torch.cat((x, skip_connect), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
x = init_block(x, t, c)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
x = connect_skip(x)
x = resnet_block(x, t, c)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim = 1)
return self.final_conv(x)
x = self.final_resnet_block(x, t)
return self.to_out(x)
class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
downsample_mode_nearest = False,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
@@ -1728,7 +1888,7 @@ class LowresConditioner(nn.Module):
blur_kernel_size = None
):
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
if self.training:
# when training, blur the low resolution conditional image
@@ -1748,7 +1908,7 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size)
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
return cond_fmap
@@ -1762,6 +1922,7 @@ class Decoder(nn.Module):
channels = 3,
vae = tuple(),
timesteps = 1000,
sample_timesteps = None,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
@@ -1771,7 +1932,8 @@ class Decoder(nn.Module):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
clip_denoised = True,
clip_x_start = True,
@@ -1784,7 +1946,8 @@ class Decoder(nn.Module):
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
):
super().__init__()
@@ -1864,9 +2027,10 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
# sampling timesteps, defaults to non-ddim with full timesteps sampling
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
self.ddim_sampling_eta = ddim_sampling_eta
# create noise schedulers per unet
@@ -1878,7 +2042,9 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
@@ -1906,6 +2072,10 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -1913,8 +2083,10 @@ class Decoder(nn.Module):
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range
)
# classifier free guidance
@@ -1946,6 +2118,10 @@ class Decoder(nn.Module):
def device(self):
return self._dummy.device
@property
def condition_on_text_encodings(self):
return any([unet.cond_on_text_encodings for unet in self.unets])
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1969,6 +2145,26 @@ class Decoder(nn.Module):
for unet, device in zip(self.unets, devices):
unet.to(device)
def dynamic_threshold(self, x):
""" proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -1983,21 +2179,7 @@ class Decoder(nn.Module):
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
x_recon = self.dynamic_threshold(x_recon)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -2027,7 +2209,7 @@ class Decoder(nn.Module):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
@@ -2055,6 +2237,61 @@ class Decoder(nn.Module):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
@torch.no_grad()
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
img = torch.randn(shape, device = device)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
img = x_start * alpha_next.sqrt() + \
c1 * torch.randn_like(img) + \
c2 * pred_noise
img = self.unnormalize_img(img)
return img
@torch.no_grad()
def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
num_timesteps = noise_scheduler.num_timesteps
timesteps = default(timesteps, num_timesteps)
assert timesteps <= num_timesteps
is_ddim = timesteps < num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
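# note: mirrors the DDIM/DDPM dispatcher added to DiffusionPrior; the sampling loop further down passes
# each unet's sample_timesteps (set in the Decoder constructor above) into this method, and any value
# below the scheduler's training timesteps selects the DDIM path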
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -2155,7 +2392,7 @@ class Decoder(nn.Module):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2184,7 +2421,8 @@ class Decoder(nn.Module):
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
noise_scheduler = noise_scheduler,
timesteps = sample_timesteps
)
img = vae.decode(img)
@@ -2201,7 +2439,8 @@ class Decoder(nn.Module):
image_embed = None,
text_encodings = None,
text_mask = None,
unet_number = None
unet_number = None,
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
@@ -2251,7 +2490,12 @@ class Decoder(nn.Module):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
if not return_lowres_cond_image:
return losses
return losses, lowres_cond_img
# main class
@@ -2299,6 +2543,6 @@ class DALLE2(nn.Module):
images = list(map(self.to_pil, images.unbind(dim = 0)))
if one_text:
return images[0]
return first(images)
return images

View File

@@ -1,6 +1,7 @@
import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return wds.WebLoader(
return DataLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,

View File

@@ -1,12 +1,16 @@
import urllib.request
import os
import json
from pathlib import Path
import importlib
import shutil
from itertools import zip_longest
from typing import Optional, List, Union
from pydantic import BaseModel
import torch
from torch import nn
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
# constants
@@ -27,126 +31,553 @@ def load_wandb_file(run_path, file_path, **kwargs):
def load_local_file(file_path, **kwargs):
return file_path
# base class
class BaseTracker(nn.Module):
def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__()
class BaseLogger:
"""
An abstract class representing an object that can log data.
Parameters:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether or not to always print logs to the console.
"""
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
self.data_path = Path(data_path)
self.data_path.mkdir(parents = True, exist_ok = True)
self.resume = resume
self.auto_resume = auto_resume
self.verbose = verbose
def init(self, config, **kwargs):
raise NotImplementedError
def log(self, log, **kwargs):
raise NotImplementedError
def log_images(self, images, **kwargs):
raise NotImplementedError
def save_state_dict(self, state_dict, relative_path, **kwargs):
raise NotImplementedError
def recall_state_dict(self, recall_source, *args, **kwargs):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
"""
Loads a state dict from any source.
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
this should not be linked to any individual tracker.
Initializes the logger.
Errors if the logger is invalid.
full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return torch.load(load_wandb_file(*args, **kwargs))
elif recall_source == 'local':
return torch.load(load_local_file(*args, **kwargs))
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def save_file(self, file_path, **kwargs):
raise NotImplementedError
def recall_file(self, recall_source, *args, **kwargs):
if recall_source == 'wandb':
return load_wandb_file(*args, **kwargs)
elif recall_source == 'local':
return load_local_file(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def log(self, log, **kwargs) -> None:
raise NotImplementedError
# Tracker that no-ops all calls except for recall
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
raise NotImplementedError
class DummyTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def log_file(self, file_path, **kwargs) -> None:
raise NotImplementedError
def init(self, config, **kwargs):
pass
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
def log(self, log, **kwargs):
pass
def get_resume_data(self, **kwargs) -> dict:
"""
Returns the data that, along with { "resume": True }, will be used to resume training.
It is assumed that after init is called this data will be complete.
If the logger does not have any resume functionality, it should return an empty dict.
"""
raise NotImplementedError
def log_images(self, images, **kwargs):
pass
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
def save_state_dict(self, state_dict, relative_path, **kwargs):
pass
def save_file(self, file_path, **kwargs):
pass
# basic stdout class
class ConsoleTracker(BaseTracker):
def init(self, **config):
print(config)
def log(self, log, **kwargs):
def log(self, log, **kwargs) -> None:
print(log)
def log_images(self, images, **kwargs): # noop for logging images
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
def save_file(self, file_path, **kwargs):
# This is a no-op for local file systems since it is already saved locally
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
pass
# basic wandb class
def log_file(self, file_path, **kwargs) -> None:
pass
class WandbTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
def get_resume_data(self, **kwargs) -> dict:
return {}
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
Parameters:
data_path (str): A file path for storing temporary data.
wandb_entity (str): The wandb entity to log to.
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
"""
def __init__(self,
data_path: str,
wandb_entity: str,
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
**kwargs
):
super().__init__(data_path, **kwargs)
self.entity = wandb_entity
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
assert self.project is not None, "wandb_project must be specified for wandb logger"
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
os.environ["WANDB_SILENT"] = "true"
# Initializes the wandb run
init_object = {
"entity": self.entity,
"project": self.project,
"config": {**full_config.dict(), **extra_config}
}
if self.run_name is not None:
init_object['name'] = self.run_name
if self.resume:
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
if self.run_name is not None:
print("You are renaming a run. I hope that is what you intended.")
init_object['resume'] = 'must'
init_object['id'] = self.run_id
def init(self, **config):
self.wandb.init(**config)
self.wandb.init(**init_object)
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
def log(self, log, verbose=False, **kwargs):
if verbose:
def log(self, log, **kwargs) -> None:
if self.verbose:
print(log)
self.wandb.log(log, **kwargs)
def log_images(self, images, captions=[], image_section="images", **kwargs):
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs):
"""
Saves a state_dict to disk and uploads it
"""
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
self.wandb.log({ image_section: wandb_images }, **kwargs)
def save_file(self, file_path, base_path=None, **kwargs):
"""
Uploads a file from disk to wandb
"""
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
if base_path is None:
base_path = self.data_path
# Then we take the basepath as the parent of the file_path
base_path = Path(file_path).parent
self.wandb.save(str(file_path), base_path = str(base_path))
def log_error(self, error_string, step=None, **kwargs) -> None:
if self.verbose:
print(error_string)
self.wandb.log({"error": error_string, **kwargs}, step=step)
def get_resume_data(self, **kwargs) -> dict:
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
return {
"entity": self.entity,
"project": self.project,
"run_id": self.wandb.run.id
}
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
}
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
if logger_type == 'custom':
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
try:
logger_class = logger_type_map[logger_type]
except KeyError:
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
return logger_class(data_path, **kwargs)
class BaseLoader:
"""
An abstract class representing an object that can load a model checkpoint.
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
self.data_path = Path(data_path)
self.only_auto_resume = only_auto_resume
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def recall(self) -> dict:
raise NotImplementedError
class UrlLoader(BaseLoader):
"""
A loader that downloads the file from a url and loads it
Parameters:
data_path (str): A file path for storing temporary data.
url (str): The url to download the file from.
"""
def __init__(self, data_path: str, url: str, **kwargs):
super().__init__(data_path, **kwargs)
self.url = url
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be downloaded
pass # TODO: Actually implement that
def recall(self) -> dict:
# Download the file
save_path = self.data_path / 'loaded_checkpoint.pth'
urllib.request.urlretrieve(self.url, str(save_path))
# Load the file
return torch.load(str(save_path), map_location='cpu')
class LocalLoader(BaseLoader):
"""
A loader that loads a file from a local path
Parameters:
data_path (str): A file path for storing temporary data.
file_path (str): The path to the file to load.
"""
def __init__(self, data_path: str, file_path: str, **kwargs):
super().__init__(data_path, **kwargs)
self.file_path = Path(file_path)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists():
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
# Load the file
return torch.load(str(self.file_path), map_location='cpu')
class WandbLoader(BaseLoader):
"""
A loader that loads a model from an existing wandb run
"""
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
self.file_path = wandb_file_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
# Make sure the file can be downloaded
if self.wandb.run is not None and self.run_path is None:
self.run_path = self.wandb.run.path
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
os.environ["WANDB_SILENT"] = "true"
pass # TODO: Actually implement that
def recall(self) -> dict:
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
return torch.load(file_reference.name, map_location='cpu')
loader_type_map = {
'url': UrlLoader,
'local': LocalLoader,
'wandb': WandbLoader,
}
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
if loader_type == 'custom':
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
try:
loader_class = loader_type_map[loader_type]
except KeyError:
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
return loader_class(data_path, **kwargs)
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_type: str = 'checkpoint',
**kwargs
):
self.data_path = Path(data_path)
self.save_latest_to = save_latest_to
self.saving_latest = save_latest_to is not None and save_latest_to is not False
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
"""
Save a general file under save_meta_to
"""
raise NotImplementedError
class LocalSaver(BaseSaver):
def __init__(self,
data_path: str,
**kwargs
):
super().__init__(data_path, **kwargs)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the directory exists to be saved to
print(f"Saving {self.save_type} locally")
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# Copy the file to save_path
save_path_file_name = Path(save_path).name
# Make sure parent directory exists
save_path_parent = Path(save_path).parent
if not save_path_parent.exists():
save_path_parent.mkdir(parents=True)
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
shutil.copy(local_path, save_path)
class WandbSaver(BaseSaver):
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
os.environ["WANDB_SILENT"] = "true"
# Makes sure that the user can upload to this run
if self.run_path is not None:
entity, project, run_id = self.run_path.split("/")
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
else:
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
self.run = self.wandb.run
# TODO: Now actually check if upload is possible
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# In order to log something in the correct place in wandb, we need to have the same file structure here
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
save_path = Path(self.data_path) / save_path
save_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(local_path, save_path)
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
class HuggingfaceSaver(BaseSaver):
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.huggingface_repo = huggingface_repo
self.token_path = token_path
def init(self, logger: BaseLogger, **kwargs):
# Makes sure this user can upload to the repo
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
try:
identity = self.hub.whoami() # Errors if not logged in
# Then we are logged in
except:
# We are not logged in. Use the token_path to set the token.
if not os.path.exists(self.token_path):
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
with open(self.token_path, "r") as f:
token = f.read().strip()
self.hub.HfApi.set_access_token(token)
identity = self.hub.whoami()
print(f"Saving to huggingface repo {self.huggingface_repo}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# Saving to huggingface is easy, we just need to upload the file with the correct name
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
self.hub.upload_file(
path_or_fileobj=str(local_path),
path_in_repo=str(save_path),
repo_id=self.huggingface_repo
)
saver_type_map = {
'local': LocalSaver,
'wandb': WandbSaver,
'huggingface': HuggingfaceSaver
}
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
if saver_type == 'custom':
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
try:
saver_class = saver_type_map[saver_type]
except KeyError:
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
return saver_class(data_path, **kwargs)
class Tracker:
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
self.data_path = Path(data_path)
if not dummy_mode:
if not overwrite_data_path:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
self.logger: BaseLogger = None
self.loader: Optional[BaseLoader] = None
self.savers: List[BaseSaver] = []
self.dummy_mode = dummy_mode
def _load_auto_resume(self) -> bool:
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
if not self.auto_resume_path.exists():
if self.logger.auto_resume:
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
return False
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
if not self.logger.auto_resume:
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
self.auto_resume_path.unlink()
return False
# Otherwise we read the json into a dictionary which will override parts of logger.__dict__
with open(self.auto_resume_path, 'r') as f:
auto_resume_dict = json.load(f)
# Check if the logger is of the same type as the autoresume save
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
# Then we are ready to override the logger with the autoresume save
self.logger.__dict__["resume"] = True
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
self.logger.__dict__.update(auto_resume_dict)
return True
def _save_auto_resume(self):
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
auto_resume_dict = self.logger.get_resume_data()
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
with open(self.auto_resume_path, 'w') as f:
json.dump(auto_resume_dict, f)
def init(self, full_config: BaseModel, extra_config: dict):
self.auto_resume_path = self.data_path / 'auto_resume.json'
# Check for resuming the run
self.did_auto_resume = self._load_auto_resume()
if self.did_auto_resume:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
if self.loader is not None:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
if self.logger.auto_resume:
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
self._save_auto_resume()
def add_logger(self, logger: BaseLogger):
self.logger = logger
def add_loader(self, loader: BaseLoader):
self.loader = loader
def add_saver(self, saver: BaseSaver):
self.savers.append(saver)
def log(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log(*args, **kwargs)
def log_images(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log_images(*args, **kwargs)
def log_file(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log_file(*args, **kwargs)
def save_config(self, current_config_path: str, config_name = 'config.json'):
if self.dummy_mode:
return
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
Gets the state dict to be saved and writes it to file_path.
If save_type is 'checkpoint', we save the entire trainer state dict.
If save_type is 'model', we save only the model state dict.
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
trainer.save(file_path, overwrite=True, **kwargs)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
elif isinstance(trainer, DecoderTrainer):
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
if self.dummy_mode:
return
if not is_best and not is_latest:
# Nothing to do
return
# Save the checkpoint and model to data_path
checkpoint_path = self.data_path / 'checkpoint.pth'
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
model_path = self.data_path / 'model.pth'
self._save_state_dict(trainer, 'model', model_path, **kwargs)
print("Saved cached models")
# Call the save methods on the savers
for saver in self.savers:
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
if saver.saving_latest and is_latest:
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
try:
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
if saver.saving_best and is_best:
best_checkpoint_path = saver.save_best_to.format(**kwargs)
try:
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
@property
def can_recall(self):
# Defines whether a recall can be performed.
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
def recall(self):
if self.can_recall:
return self.loader.recall()
else:
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
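Taken together, the rewritten trackers.py splits the old monolithic tracker into a logger (where metrics go), an optional loader (where a checkpoint is recalled from), and one or more savers (where checkpoints are written). Below is a minimal sketch of wiring these by hand, assuming a hypothetical pydantic config object (`full_config`) and a hypothetical trainer; in practice the `TrackerConfig` in train_configs.py further down performs this wiring from the JSON config.

```python
from dalle2_pytorch.trackers import Tracker, ConsoleLogger, LocalLoader, LocalSaver

tracker = Tracker('.tracker_data', overwrite_data_path = True)
tracker.add_logger(ConsoleLogger('.tracker_data'))
tracker.add_loader(LocalLoader('.tracker_data', file_path = 'old_run/latest.pth'))    # hypothetical checkpoint path
tracker.add_saver(LocalSaver('.tracker_data', save_latest_to = 'run/latest.pth', save_type = 'checkpoint'))

tracker.init(full_config, extra_config = {})    # full_config is the hypothetical pydantic config object
tracker.log({'Training loss': 0.5}, step = 0)

if tracker.can_recall:
    trainer.load_state_dict(tracker.recall())   # trainer is a hypothetical DecoderTrainer
tracker.save(trainer, is_best = False, is_latest = True, epoch = 0, sample = 0)
```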

View File

@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
DiffusionPriorNetwork,
XClipAdapter
)
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
# helper functions
@@ -44,13 +45,69 @@ class TrainSplitConfig(BaseModel):
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
resume: bool = False # For logs that are saved to unique locations, resume a previous run
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
verbose: bool = False
class Config:
# Each individual log type has it's own arguments that will be passed through the config
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
if self.load_from is None:
return None
return create_loader(self.load_from, data_path, **kwargs)
class TrackerSaveConfig(BaseModel):
save_to: str = 'local'
save_all: bool = False
save_latest: bool = True
save_best: bool = True
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_saver(self.save_to, data_path, **kwargs)
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
data_path: str = '.tracker_data'
overwrite_data_path: bool = False
log: TrackerLogConfig
load: Optional[TrackerLoadConfig]
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
# Add the logger
tracker.add_logger(self.log.create(self.data_path))
# Add the loader
if self.load is not None:
tracker.add_loader(self.load.create(self.data_path))
# Add the saver or savers
if isinstance(self.save, list):
for save_config in self.save:
tracker.add_saver(save_config.create(self.data_path))
else:
tracker.add_saver(self.save.create(self.data_path))
# Initialize all the components and verify that all data is valid
tracker.init(full_config, extra_config)
return tracker
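For orientation, the `tracker` section of a training config that exercises these classes might look roughly like the dict below. The entity, project, and file paths are placeholders; any extra keys are forwarded to the chosen logger, loader, or saver because each sub-config allows extra fields.

```python
tracker_section = {
    "data_path": ".tracker_data",
    "overwrite_data_path": True,
    "log": {
        "log_type": "wandb",              # or "console"
        "wandb_entity": "my-entity",      # placeholder
        "wandb_project": "my-project",    # placeholder
        "auto_resume": True,
        "verbose": True
    },
    "load": {
        "load_from": "local",
        "file_path": "checkpoints/latest.pth",   # placeholder, passed through to LocalLoader
        "only_auto_resume": True
    },
    "save": [{
        "save_to": "local",
        "save_type": "checkpoint",
        "save_latest_to": "checkpoints/latest.pth"
    }]
}
```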
# diffusion prior pydantic classes
@@ -97,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
image_size: int
image_channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
@@ -162,6 +220,7 @@ class UnetConfig(BaseModel):
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
@@ -175,6 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
@@ -238,6 +298,8 @@ class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable(int)] = None
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
@@ -247,9 +309,6 @@ class DecoderTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
@@ -271,7 +330,6 @@ class TrainDecoderConfig(BaseModel):
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
seed: int = 0
@classmethod
@@ -289,7 +347,7 @@ class TrainDecoderConfig(BaseModel):
# Then something else errored and we should just pass through
return values
using_text_encodings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
using_clip = exists(decoder_config.clip)
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url

View File

@@ -3,10 +3,13 @@ import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from contextlib import nullcontext
from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -14,9 +17,11 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
import numpy as np
@@ -71,6 +76,7 @@ def cast_torch_tensor(fn):
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
@@ -80,6 +86,21 @@ def cast_torch_tensor(fn):
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if cast_deepspeed_precision:
try:
accelerator = model.accelerator
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
except AttributeError:
# Then this model doesn't have an accelerator
pass
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -162,19 +183,33 @@ class DiffusionPriorTrainer(nn.Module):
group_wd_params = True,
device = None,
accelerator = None,
verbose = True,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# verbosity
self.verbose = verbose
# assign some helpful member vars
self.accelerator = accelerator
self.device = accelerator.device if exists(accelerator) else device
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# setting the device
if not exists(accelerator) and not exists(device):
diffusion_prior_device = next(diffusion_prior.parameters()).device
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
# save model
self.diffusion_prior = diffusion_prior
@@ -210,11 +245,14 @@ class DiffusionPriorTrainer(nn.Module):
# track steps internally
self.register_buffer('step', torch.tensor([0]))
self.register_buffer('step', torch.tensor([0], device = self.device))
# accelerator wrappers
def print(self, msg):
if not self.verbose:
return
if exists(self.accelerator):
self.accelerator.print(msg)
else:
@@ -424,10 +462,12 @@ class DecoderTrainer(nn.Module):
self,
decoder,
accelerator = None,
dataloaders = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -449,13 +489,15 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
@@ -467,6 +509,13 @@ class DecoderTrainer(nn.Module):
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -474,15 +523,58 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader
self.val_loader = val_loader
# store optimizers
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
# store schedulers
for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
setattr(self, f'sched{sched_ind}', scheduler)
# store warmup schedulers
self.warmup_schedulers = warmup_schedulers
def validate_and_return_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
return unet_number
def num_steps_taken(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
return self.steps[unet_number - 1].item()
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -491,7 +583,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
steps = self.steps.cpu(),
**kwargs
)
@@ -505,30 +597,38 @@ class DecoderTrainer(nn.Module):
self.accelerator.save(save_obj, str(path))
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.steps.copy_(loaded_obj['steps'])
if only_model:
return loaded_obj
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
warmup_scheduler = self.warmup_schedulers[ind]
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path), map_location = 'cpu')
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
return loaded_obj
@@ -536,25 +636,36 @@ class DecoderTrainer(nn.Module):
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
def increment_step(self, unet_number):
assert 1 <= unet_number <= self.num_unets
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
def update(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
scheduler = getattr(self, f'sched{index}')
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
optimizer.step()
optimizer.zero_grad()
warmup_scheduler = self.warmup_schedulers[index]
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
with scheduler_context():
scheduler.step()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
self.step += 1
self.increment_step(unet_number)
@torch.no_grad()
@cast_torch_tensor
@@ -598,13 +709,14 @@ class DecoderTrainer(nn.Module):
max_batch_size = None,
**kwargs
):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0.
using_amp = self.accelerator.mixed_precision != 'no'
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
# with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac

View File

@@ -1,6 +1,11 @@
import time
import importlib
# helper functions
def exists(val):
return val is not None
# time helpers
class Timer:

View File

@@ -1 +1 @@
__version__ = '0.12.3'
__version__ = '0.19.1'

183
prior.md Normal file
View File

@@ -0,0 +1,183 @@
# Diffusion Prior
This readme serves as an introduction to the diffusion prior.
## Intro
A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.
### Motivation
Before we dive into the model, let's look at a quick example of where the model may be helpful.
For demonstration purposes, we will imagine that we wish to generate images from text using CLIP and a Decoder.
> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption; however, there is no guarantee that these embeddings are in the same space. While the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.
```python
# Load Models
clip_model = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
# Retrieve prompt from user and encode with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)
# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```
> **Question**: *Can you spot the issue here?*
>
> **Answer**: *We're trying to generate an image from a text embedding!*
Unfortunately, we run into the issue previously mentioned: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
```python
# Load Models
prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
# Retrieve prompt from user and encode with a prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
predicted_image_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(predicted_image_embedding)
```
With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
> **You may be asking yourself the following question:**
>
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
>
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also, it's just an example :smile:
## Usage
Utilizing a pre-trained prior is quite simple.
### Loading Checkpoints
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer
def load_diffusion_model(dprior_path, device):
prior_network = DiffusionPriorNetwork(
dim=768,
depth=24,
dim_head=64,
heads=32,
normformer=True,
attn_dropout=5e-2,
ff_dropout=5e-2,
num_time_embeds=1,
num_image_embeds=1,
num_text_embeds=1,
num_timesteps=1000,
ff_mult=4
)
diffusion_prior = DiffusionPrior(
net=prior_network,
clip=OpenAIClipAdapter("ViT-L/14"),
image_embed_dim=768,
timesteps=1000,
cond_drop_prob=0.1,
loss_type="l2",
condition_on_text_encodings=True,
)
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=1.1e-4,
wd=6.02e-2,
max_grad_norm=0.5,
amp=False,
group_wd_params=True,
use_ema=True,
device=device,
accelerator=None,
)
trainer.load(dprior_path)
return trainer
```
Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
### Sampling
Once we have a pre-trained model, generating embeddings is quite simple!
```python
# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```
The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
> For CLIP priors, this is quite handy as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.
**Some things to note:**
* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two and selecting the one with the higher cosine similarity with the prompt (see the sketch after this list).
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0`, but *ymmv*.
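To make the first point concrete, here is a rough, hand-written sketch of what sampling multiple candidates buys you. This is conceptually what `n_samples_per_batch` automates inside the library, not its actual implementation; `clip`, `clip_model`, and `prior` are assumed to be loaded as in the earlier snippets.

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    tokenized_text = clip.tokenize("<your amazing prompt>")
    text_embedding = clip_model.encode_text(tokenized_text)                    # (1, 768) for ViT-L/14
    candidates = torch.cat([prior.sample(tokenized_text) for _ in range(2)])   # (2, 768) candidate image embeddings
    scores = F.cosine_similarity(candidates.float(), text_embedding.float())   # similarity of each candidate to the prompt
    best_embedding = candidates[scores.argmax()].unsqueeze(0)                  # keep the closer candidate, back to (1, 768)
```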
---
## Training
### Overview
Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
## Dataset
To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.
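As a rough sketch of what the resulting data looks like, the precomputed embeddings can be read back with the [embedding-reader](https://github.com/rom1504/embedding-reader) package. The parameter names below follow its README at the time of writing and may need adjusting to your setup; the folder path is a placeholder.

```python
from embedding_reader import EmbeddingReader

# Iterate over precomputed CLIP image embeddings stored as .npy shards.
reader = EmbeddingReader(embeddings_folder = "my_dataset/img_emb", file_format = "npy")  # placeholder path
print(f"{reader.count} embeddings of dimension {reader.dimension}")
for embeddings, metadata in reader(batch_size = 10 ** 5, start = 0, end = reader.count):
    print(embeddings.shape)  # (batch_size, 768) for ViT-L/14 embeddings
    break
```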
## Configuration
The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics, please see the configuration README.
## Distributed Training
If you would like to train in a distributed manner, we have opted to leverage Hugging Face's new Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to follow its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
## Evaluation
There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
| Metric | Description | Comments |
| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |
| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform in the long-term. |
| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
| Similarity With Original Image | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, then you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA) |
| Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric will show you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01` +/- of `0.0`. If this climbs too high (~>`0.02`), then this may be a sign that your model is overfitting somehow. |
| Similarity With Text | This metric is your bread and butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and is probably the second most important behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
| Similarity With Unrelated Caption | This metric will attempt to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of this predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images was high (when in fact the caption and image were completely unrelated). With this in mind, a low value is ideal; anything below `0.1` is probably safe. |
## Launching the script
Now that you've done all the prep, it's time for the easy part! 🚀
To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to launch with distributed training & Hugging Face Accelerate, or `python train_diffusion_prior.py` if you would like to train on your GPU/CPU without Accelerate.
## Checkpointing
Checkpoints will be saved to the directory specified in your configuration file.
Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled “latest.pth”. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.
## Things To Keep In Mind
The prior has not been trained for tasks other than the traditional CLIP embedding translation… at least not yet.
As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don't see documentation for!

View File

@@ -37,6 +37,7 @@ setup(
'packaging',
'pillow',
'pydantic',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',

BIN
test_data/0.tar Normal file

Binary file not shown.

BIN
test_data/1.tar Normal file

Binary file not shown.

BIN
test_data/2.tar Normal file

Binary file not shown.

BIN
test_data/3.tar Normal file

Binary file not shown.

BIN
test_data/4.tar Normal file

Binary file not shown.

BIN
test_data/5.tar Normal file

Binary file not shown.

BIN
test_data/6.tar Normal file

Binary file not shown.

BIN
test_data/7.tar Normal file

Binary file not shown.

BIN
test_data/8.tar Normal file

Binary file not shown.

BIN
test_data/9.tar Normal file

Binary file not shown.

View File

@@ -1,11 +1,12 @@
from pathlib import Path
from typing import List
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize
import torchvision
@@ -131,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -159,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:
generated_image_size = generated_images[0].shape[-1]
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
@@ -166,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False,
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -239,42 +235,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
metrics[metric_name] = metrics_tensor[i].item()
return metrics
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
for relative_path in relative_paths:
local_path = str(tracker.data_path / relative_path)
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
tracker.save_file(local_path)
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
local_filepath = tracker.recall_file(recall_source, **load_config)
state_dict = trainer.load(local_filepath)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
state_dict = tracker.recall()
trainer.load_state_dict(state_dict, only_model=False, strict=True)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
def train(
dataloaders,
decoder,
accelerator,
tracker,
decoder: Decoder,
accelerator: Accelerator,
tracker: Tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
@@ -287,6 +274,7 @@ def train(
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
dataloaders=dataloaders,
**kwargs
)
@@ -297,21 +285,22 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())
if exists(load_config) and exists(load_config.source):
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -334,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce was broken when this was written. If it is fixed, replace this.
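A hedged sketch of the cross-process sample count: every process contributes its local batch size and the gathered tensor is summed. Plain tensors stand in here for what accelerator.gather would return across, say, three processes (all numbers hypothetical):

import torch

local_batch_sizes = [torch.tensor([24]), torch.tensor([24]), torch.tensor([16])]
all_samples = torch.cat(local_batch_sizes)   # shape mirrors a gather over 3 processes
total_samples = all_samples.sum().item()
print(total_samples)  # 64 samples seen globally for this step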
@@ -399,19 +388,14 @@ def train(
}
if is_master:
tracker.log(log_data, step=step(), verbose=True)
tracker.log(log_data, step=step())
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
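A toy illustration of the snapshot cadence above (all numbers hypothetical): a save fires once the global sample counter has advanced past last_snapshot + save_every_n_samples, so it overshoots the exact boundary slightly, as the inline comment admits.

save_every_n_samples = 100_000
last_snapshot = sample = 0
for _ in range(6):
    sample += 40_000                              # pretend each step sees 40k samples
    if last_snapshot + save_every_n_samples < sample:
        print(f"snapshot at sample {sample:,}")   # fires at 120,000 and 240,000
        last_snapshot = sample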
@@ -432,7 +416,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -486,7 +470,7 @@ def train(
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step(), verbose=True)
tracker.log(val_loss_map, step=step())
next_task = 'eval'
if next_task == 'eval':
@@ -494,7 +478,7 @@ def train(
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
tracker.log(evaluation, step=step())
next_task = 'sample'
val_sample = 0
@@ -509,22 +493,16 @@ def train(
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Determine whether this checkpoint is the best so far
save_paths = []
if save_latest:
save_paths.append("latest.pth")
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
next_task = 'train'
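A minimal sketch of the "best checkpoint" rule used above, with hypothetical numbers: a checkpoint is flagged best when its average validation loss beats every loss recorded so far (or when none has been recorded yet).

validation_losses = [0.31, 0.27, 0.29]        # hypothetical history
average_loss = 0.25
is_best = len(validation_losses) == 0 or average_loss < min(validation_losses)
validation_losses.append(average_loss)
assert is_best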
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
@@ -532,41 +510,31 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
init_config = { "config": {**config.dict(), **accelerator_config} }
data_path = data_path or tracker_config.data_path
tracker_type = tracker_type or tracker_config.tracker_type
if tracker_type == "dummy":
tracker = DummyTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "console":
tracker = ConsoleTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run given by load_config["run_path"]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
return tracker
def initialize_training(config, config_path):
def initialize_training(config: TrainDecoderConfig, config_path):
# Make sure that, if we are not loading a checkpoint, distributed models are initialized to the same values
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all processes can connect
accelerator.print("Waiting for all processes to connect...")
accelerator.wait_for_everyone()
accelerator.print("All processes online and connected")
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
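A minimal sketch of why the script seeds torch before building the model (as in the torch.manual_seed call above): with the same seed, every process constructs identical initial weights, which is what distributed data parallel expects when no checkpoint is loaded. The module below is a stand-in, not the decoder itself.

import torch

torch.manual_seed(42)
a = torch.nn.Linear(4, 4)   # hypothetical stand-in for the decoder
torch.manual_seed(42)
b = torch.nn.Linear(4, 4)   # a second process building the same module
assert torch.equal(a.weight, b.weight)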
@@ -592,7 +560,7 @@ def initialize_training(config, config_path):
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker; non-master processes get a dummy tracker
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
@@ -622,7 +590,6 @@ def initialize_training(config, config_path):
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.dict(),