mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 21:34:19 +01:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49de72040c | ||
|
|
271a376eaf | ||
|
|
e527002472 | ||
|
|
c12e067178 | ||
|
|
c6629c431a | ||
|
|
7ac2fc79f2 | ||
|
|
a1ef023193 | ||
|
|
d49eca62fa | ||
|
|
8aab69b91e | ||
|
|
b432df2f7b | ||
|
|
ebaa0d28c2 | ||
|
|
8b0d459b25 | ||
|
|
0064661729 | ||
|
|
b895f52843 | ||
|
|
80497e9839 | ||
|
|
f526f14d7c | ||
|
|
8997f178d6 | ||
|
|
022c94e443 | ||
|
|
430961cb97 | ||
|
|
721f9687c1 |
9
.gitignore
vendored
9
.gitignore
vendored
@@ -1,3 +1,12 @@
|
|||||||
|
# default experiment tracker data
|
||||||
|
.tracker-data/
|
||||||
|
|
||||||
|
# Configuration Files
|
||||||
|
configs/*
|
||||||
|
!configs/*.example
|
||||||
|
!configs/*_defaults.py
|
||||||
|
!configs/README.md
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
15
README.md
15
README.md
@@ -1034,6 +1034,18 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
|
|
||||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||||
|
|
||||||
|
## Appreciation
|
||||||
|
|
||||||
|
This library would not have gotten to this working state without the help of
|
||||||
|
|
||||||
|
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
|
||||||
|
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
|
||||||
|
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||||
|
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||||
|
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||||
|
|
||||||
|
... and many others. Thank you! 🙏
|
||||||
|
|
||||||
## Todo
|
## Todo
|
||||||
|
|
||||||
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
||||||
@@ -1064,6 +1076,8 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||||
- [x] cross embed layers for downsampling, as an option
|
- [x] cross embed layers for downsampling, as an option
|
||||||
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||||
|
- [x] use pydantic for config drive training
|
||||||
|
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
@@ -1078,7 +1092,6 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] decoder needs one day worth of refactor for tech debt
|
- [ ] decoder needs one day worth of refactor for tech debt
|
||||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||||
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
|
||||||
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
|
||||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|||||||
109
configs/README.md
Normal file
109
configs/README.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
## DALLE2 Training Configurations
|
||||||
|
|
||||||
|
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
|
||||||
|
|
||||||
|
### Decoder Trainer
|
||||||
|
|
||||||
|
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
|
||||||
|
|
||||||
|
**<ins>Unets</ins>:**
|
||||||
|
|
||||||
|
Each member of this array defines a single unet that will be added to the decoder.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `dim` | Yes | N/A | The starting channels of the unet. |
|
||||||
|
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
|
||||||
|
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
|
||||||
|
|
||||||
|
Any parameter from the `Unet` constructor can also be given here.
|
||||||
|
|
||||||
|
**<ins>Decoder</ins>:**
|
||||||
|
|
||||||
|
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
|
||||||
|
| `image_size` | Yes | N/A | Not used. Can be any number. |
|
||||||
|
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
|
||||||
|
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
|
||||||
|
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
|
||||||
|
| `learned_variance` | No | `True` | Whether to learn the variance. |
|
||||||
|
|
||||||
|
Any parameter from the `Decoder` constructor can also be given here.
|
||||||
|
|
||||||
|
**<ins>Data</ins>:**
|
||||||
|
|
||||||
|
Settings for creation of the dataloaders.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
|
||||||
|
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
|
||||||
|
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
|
||||||
|
| `batch_size` | No | `64` | The batch size. |
|
||||||
|
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
|
||||||
|
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
|
||||||
|
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
|
||||||
|
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
|
||||||
|
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
|
||||||
|
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
|
||||||
|
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
|
||||||
|
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
|
||||||
|
|
||||||
|
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
|
||||||
|
|
||||||
|
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
|
||||||
|
|
||||||
|
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
|
||||||
|
|
||||||
|
**<ins>Train</ins>:**
|
||||||
|
|
||||||
|
Settings for controlling the training hyperparameters.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `epochs` | No | `20` | The number of epochs in the training run. |
|
||||||
|
| `lr` | No | `1e-4` | The learning rate. |
|
||||||
|
| `wd` | No | `0.01` | The weight decay. |
|
||||||
|
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
|
||||||
|
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
|
||||||
|
| `device` | No | `cuda:0` | The device to train on. |
|
||||||
|
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
|
||||||
|
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
|
||||||
|
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
|
||||||
|
| `ema_beta` | No | `0.99` | The ema coefficient. |
|
||||||
|
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
|
||||||
|
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
|
||||||
|
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
|
||||||
|
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
|
||||||
|
|
||||||
|
**<ins>Evaluate</ins>:**
|
||||||
|
|
||||||
|
Defines which evaluation metrics will be used to test the model.
|
||||||
|
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||||
|
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||||
|
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||||
|
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||||
|
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
|
||||||
|
|
||||||
|
**<ins>Tracker</ins>:**
|
||||||
|
|
||||||
|
Selects which tracker to use and configures it.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
|
||||||
|
| `data_path` | No | `./models` | Where the tracker will store local data. |
|
||||||
|
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
|
||||||
|
|
||||||
|
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
|
||||||
|
|
||||||
|
**<ins>Load</ins>:**
|
||||||
|
|
||||||
|
Selects where to load a pretrained model from.
|
||||||
|
| Option | Required | Default | Description |
|
||||||
|
| ------ | -------- | ------- | ----------- |
|
||||||
|
| `source` | No | `None` | Supports `file` or `wandb`. |
|
||||||
|
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
|
||||||
|
|
||||||
|
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
|
||||||
99
configs/train_decoder_config.example.json
Normal file
99
configs/train_decoder_config.example.json
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
{
|
||||||
|
"unets": [
|
||||||
|
{
|
||||||
|
"dim": 128,
|
||||||
|
"image_embed_dim": 768,
|
||||||
|
"cond_dim": 64,
|
||||||
|
"channels": 3,
|
||||||
|
"dim_mults": [1, 2, 4, 8],
|
||||||
|
"attn_dim_head": 32,
|
||||||
|
"attn_heads": 16
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"decoder": {
|
||||||
|
"image_sizes": [64],
|
||||||
|
"channels": 3,
|
||||||
|
"timesteps": 1000,
|
||||||
|
"loss_type": "l2",
|
||||||
|
"beta_schedule": "cosine",
|
||||||
|
"learned_variance": true
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
|
||||||
|
"embeddings_url": "s3://bucket/embeddings/path/",
|
||||||
|
"num_workers": 4,
|
||||||
|
"batch_size": 64,
|
||||||
|
"start_shard": 0,
|
||||||
|
"end_shard": 9999999,
|
||||||
|
"shard_width": 6,
|
||||||
|
"index_width": 4,
|
||||||
|
"splits": {
|
||||||
|
"train": 0.75,
|
||||||
|
"val": 0.15,
|
||||||
|
"test": 0.1
|
||||||
|
},
|
||||||
|
"shuffle_train": true,
|
||||||
|
"resample_train": false,
|
||||||
|
"preprocessing": {
|
||||||
|
"RandomResizedCrop": {
|
||||||
|
"size": [128, 128],
|
||||||
|
"scale": [0.75, 1.0],
|
||||||
|
"ratio": [1.0, 1.0]
|
||||||
|
},
|
||||||
|
"ToTensor": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"train": {
|
||||||
|
"epochs": 20,
|
||||||
|
"lr": 1e-4,
|
||||||
|
"wd": 0.01,
|
||||||
|
"max_grad_norm": 0.5,
|
||||||
|
"save_every_n_samples": 100000,
|
||||||
|
"n_sample_images": 6,
|
||||||
|
"device": "cuda:0",
|
||||||
|
"epoch_samples": null,
|
||||||
|
"validation_samples": null,
|
||||||
|
"use_ema": true,
|
||||||
|
"ema_beta": 0.99,
|
||||||
|
"amp": false,
|
||||||
|
"save_all": false,
|
||||||
|
"save_latest": true,
|
||||||
|
"save_best": true,
|
||||||
|
"unet_training_mask": [true]
|
||||||
|
},
|
||||||
|
"evaluate": {
|
||||||
|
"n_evaluation_samples": 1000,
|
||||||
|
"FID": {
|
||||||
|
"feature": 64
|
||||||
|
},
|
||||||
|
"IS": {
|
||||||
|
"feature": 64,
|
||||||
|
"splits": 10
|
||||||
|
},
|
||||||
|
"KID": {
|
||||||
|
"feature": 64,
|
||||||
|
"subset_size": 10
|
||||||
|
},
|
||||||
|
"LPIPS": {
|
||||||
|
"net_type": "vgg",
|
||||||
|
"reduction": "mean"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tracker": {
|
||||||
|
"tracker_type": "console",
|
||||||
|
"data_path": "./models",
|
||||||
|
|
||||||
|
"wandb_entity": "",
|
||||||
|
"wandb_project": "",
|
||||||
|
|
||||||
|
"verbose": false
|
||||||
|
},
|
||||||
|
"load": {
|
||||||
|
"source": null,
|
||||||
|
|
||||||
|
"run_path": "",
|
||||||
|
"file_path": "",
|
||||||
|
|
||||||
|
"resume": false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,6 +59,9 @@ def default(val, d):
|
|||||||
return d() if isfunction(d) else d
|
return d() if isfunction(d) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
|
if isinstance(val, list):
|
||||||
|
val = tuple(val)
|
||||||
|
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
def module_device(module):
|
def module_device(module):
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ def get_optimizer(
|
|||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.999),
|
||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False
|
filter_by_requires_grad = False,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
if filter_by_requires_grad:
|
if filter_by_requires_grad:
|
||||||
params = list(filter(lambda t: t.requires_grad, params))
|
params = list(filter(lambda t: t.requires_grad, params))
|
||||||
|
|||||||
@@ -1,20 +1,32 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import importlib
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from enum import Enum
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
DEFAULT_DATA_PATH = './.tracker-data'
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
def import_or_print_error(pkg_name, err_str = None):
|
||||||
try:
|
try:
|
||||||
import wandb
|
return importlib.import_module(pkg_name)
|
||||||
except ImportError as e:
|
except ModuleNotFoundError as e:
|
||||||
print('`pip install wandb` to use the wandb recall function')
|
if exists(err_str):
|
||||||
raise e
|
print(err_str)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# load state dict functions
|
||||||
|
|
||||||
|
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||||
|
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||||
return torch.load(file_reference.name)
|
return torch.load(file_reference.name)
|
||||||
|
|
||||||
@@ -24,11 +36,10 @@ def load_local_state_dict(file_path, **kwargs):
|
|||||||
# base class
|
# base class
|
||||||
|
|
||||||
class BaseTracker(nn.Module):
|
class BaseTracker(nn.Module):
|
||||||
def __init__(self, data_path):
|
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert data_path is not None, "Tracker must have a data_path to save local content"
|
self.data_path = Path(data_path)
|
||||||
self.data_path = os.path.abspath(data_path)
|
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||||
os.makedirs(self.data_path, exist_ok=True)
|
|
||||||
|
|
||||||
def init(self, config, **kwargs):
|
def init(self, config, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -66,28 +77,19 @@ class ConsoleTracker(BaseTracker):
|
|||||||
def log(self, log, **kwargs):
|
def log(self, log, **kwargs):
|
||||||
print(log)
|
print(log)
|
||||||
|
|
||||||
def log_images(self, images, **kwargs):
|
def log_images(self, images, **kwargs): # noop for logging images
|
||||||
"""
|
|
||||||
Currently, do nothing with console logged images
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||||
torch.save(state_dict, os.path.join(self.data_path, relative_path))
|
torch.save(state_dict, str(self.data_path / relative_path))
|
||||||
|
|
||||||
# basic wandb class
|
# basic wandb class
|
||||||
|
|
||||||
class WandbTracker(BaseTracker):
|
class WandbTracker(BaseTracker):
|
||||||
def __init__(self, data_path):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(data_path)
|
super().__init__(*args, **kwargs)
|
||||||
try:
|
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||||
import wandb
|
|
||||||
except ImportError as e:
|
|
||||||
print('`pip install wandb` to use the wandb experiment tracker')
|
|
||||||
raise e
|
|
||||||
|
|
||||||
os.environ["WANDB_SILENT"] = "true"
|
os.environ["WANDB_SILENT"] = "true"
|
||||||
self.wandb = wandb
|
|
||||||
|
|
||||||
def init(self, **config):
|
def init(self, **config):
|
||||||
self.wandb.init(**config)
|
self.wandb.init(**config)
|
||||||
@@ -108,6 +110,6 @@ class WandbTracker(BaseTracker):
|
|||||||
"""
|
"""
|
||||||
Saves a state_dict to disk and uploads it
|
Saves a state_dict to disk and uploads it
|
||||||
"""
|
"""
|
||||||
full_path = os.path.join(self.data_path, relative_path)
|
full_path = str(self.data_path / relative_path)
|
||||||
torch.save(state_dict, full_path)
|
torch.save(state_dict, full_path)
|
||||||
self.wandb.save(full_path, base_path=self.data_path) # Upload and keep relative to data_path
|
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||||
|
|||||||
135
dalle2_pytorch/train_configs.py
Normal file
135
dalle2_pytorch/train_configs.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import json
|
||||||
|
from torchvision import transforms as T
|
||||||
|
from pydantic import BaseModel, validator, root_validator
|
||||||
|
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
class UnetConfig(BaseModel):
|
||||||
|
dim: int
|
||||||
|
dim_mults: List[int]
|
||||||
|
image_embed_dim: int = None
|
||||||
|
cond_dim: int = None
|
||||||
|
channels: int = 3
|
||||||
|
attn_dim_head: int = 32
|
||||||
|
attn_heads: int = 16
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
|
||||||
|
class DecoderConfig(BaseModel):
|
||||||
|
image_size: int = None
|
||||||
|
image_sizes: Union[List[int], Tuple[int]] = None
|
||||||
|
channels: int = 3
|
||||||
|
timesteps: int = 1000
|
||||||
|
loss_type: str = 'l2'
|
||||||
|
beta_schedule: str = 'cosine'
|
||||||
|
learned_variance: bool = True
|
||||||
|
|
||||||
|
@validator('image_sizes')
|
||||||
|
def check_image_sizes(cls, image_sizes, values):
|
||||||
|
if exists(values.get('image_size')) ^ exists(image_sizes):
|
||||||
|
return image_sizes
|
||||||
|
raise ValueError('either image_size or image_sizes is required, but not both')
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
|
||||||
|
class TrainSplitConfig(BaseModel):
|
||||||
|
train: float = 0.75
|
||||||
|
val: float = 0.15
|
||||||
|
test: float = 0.1
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_all(cls, fields):
|
||||||
|
if sum([*fields.values()]) != 1.:
|
||||||
|
raise ValueError(f'{fields.keys()} must sum to 1.0')
|
||||||
|
return fields
|
||||||
|
|
||||||
|
class DecoderDataConfig(BaseModel):
|
||||||
|
webdataset_base_url: str # path to a webdataset with jpg images
|
||||||
|
embeddings_url: str # path to .npy files with embeddings
|
||||||
|
num_workers: int = 4
|
||||||
|
batch_size: int = 64
|
||||||
|
start_shard: int = 0
|
||||||
|
end_shard: int = 9999999
|
||||||
|
shard_width: int = 6
|
||||||
|
index_width: int = 4
|
||||||
|
splits: TrainSplitConfig
|
||||||
|
shuffle_train: bool = True
|
||||||
|
resample_train: bool = False
|
||||||
|
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
||||||
|
|
||||||
|
class DecoderTrainConfig(BaseModel):
|
||||||
|
epochs: int = 20
|
||||||
|
lr: float = 1e-4
|
||||||
|
wd: float = 0.01
|
||||||
|
max_grad_norm: float = 0.5
|
||||||
|
save_every_n_samples: int = 100000
|
||||||
|
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||||
|
device: str = 'cuda:0'
|
||||||
|
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||||
|
validation_samples: int = None # Same as above but for validation.
|
||||||
|
use_ema: bool = True
|
||||||
|
ema_beta: float = 0.99
|
||||||
|
amp: bool = False
|
||||||
|
save_all: bool = False # Whether to preserve all checkpoints
|
||||||
|
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||||
|
save_best: bool = True # Whether to save the best checkpoint
|
||||||
|
unet_training_mask: List[bool] = None # If None, use all unets
|
||||||
|
|
||||||
|
class DecoderEvaluateConfig(BaseModel):
|
||||||
|
n_evaluation_samples: int = 1000
|
||||||
|
FID: Dict[str, Any] = None
|
||||||
|
IS: Dict[str, Any] = None
|
||||||
|
KID: Dict[str, Any] = None
|
||||||
|
LPIPS: Dict[str, Any] = None
|
||||||
|
|
||||||
|
class TrackerConfig(BaseModel):
|
||||||
|
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||||
|
data_path: str = './models' # The path where files will be saved locally
|
||||||
|
init_config: Dict[str, Any] = None
|
||||||
|
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||||
|
wandb_project: str = ''
|
||||||
|
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||||
|
|
||||||
|
class DecoderLoadConfig(BaseModel):
|
||||||
|
source: str = None # Supports file and wandb
|
||||||
|
run_path: str = '' # Used only if source is wandb
|
||||||
|
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
||||||
|
resume: bool = False # If using wandb, whether to resume the run
|
||||||
|
|
||||||
|
class TrainDecoderConfig(BaseModel):
|
||||||
|
unets: List[UnetConfig]
|
||||||
|
decoder: DecoderConfig
|
||||||
|
data: DecoderDataConfig
|
||||||
|
train: DecoderTrainConfig
|
||||||
|
evaluate: DecoderEvaluateConfig
|
||||||
|
tracker: TrackerConfig
|
||||||
|
load: DecoderLoadConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_path(cls, json_path):
|
||||||
|
with open(json_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return cls(**config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def img_preproc(self):
|
||||||
|
def _get_transformation(transformation_name, **kwargs):
|
||||||
|
if transformation_name == "RandomResizedCrop":
|
||||||
|
return T.RandomResizedCrop(**kwargs)
|
||||||
|
elif transformation_name == "RandomHorizontalFlip":
|
||||||
|
return T.RandomHorizontalFlip()
|
||||||
|
elif transformation_name == "ToTensor":
|
||||||
|
return T.ToTensor()
|
||||||
|
|
||||||
|
transforms = []
|
||||||
|
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
|
||||||
|
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
||||||
|
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
||||||
|
return T.Compose(transforms)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
from pathlib import Path
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
|
|||||||
arr.append(remainder)
|
arr.append(remainder)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
def get_pkg_version():
|
||||||
|
from pkg_resources import get_distribution
|
||||||
|
return get_distribution('dalle2_pytorch').version
|
||||||
|
|
||||||
# decorators
|
# decorators
|
||||||
|
|
||||||
def cast_torch_tensor(fn):
|
def cast_torch_tensor(fn):
|
||||||
@@ -191,7 +196,7 @@ class EMA(nn.Module):
|
|||||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
||||||
|
|
||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
|
|
||||||
def restore_ema_model_device(self):
|
def restore_ema_model_device(self):
|
||||||
device = self.initted.device
|
device = self.initted.device
|
||||||
@@ -287,7 +292,47 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
|
|
||||||
|
def save(self, path, overwrite = True):
|
||||||
|
path = Path(path)
|
||||||
|
assert not (path.exists() and not overwrite)
|
||||||
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
save_obj = dict(
|
||||||
|
scaler = self.scaler.state_dict(),
|
||||||
|
optimizer = self.optimizer.state_dict(),
|
||||||
|
model = self.diffusion_prior.state_dict(),
|
||||||
|
version = get_pkg_version(),
|
||||||
|
step = self.step.item()
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
|
||||||
|
|
||||||
|
torch.save(save_obj, str(path))
|
||||||
|
|
||||||
|
def load(self, path, only_model = False, strict = True):
|
||||||
|
path = Path(path)
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
|
if get_pkg_version() != loaded_obj['version']:
|
||||||
|
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
|
||||||
|
|
||||||
|
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|
||||||
|
if only_model:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||||
|
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
assert 'ema' in loaded_obj
|
||||||
|
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
@@ -410,6 +455,57 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
|
def save(self, path, overwrite = True):
|
||||||
|
path = Path(path)
|
||||||
|
assert not (path.exists() and not overwrite)
|
||||||
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
save_obj = dict(
|
||||||
|
model = self.decoder.state_dict(),
|
||||||
|
version = get_pkg_version(),
|
||||||
|
step = self.step.item()
|
||||||
|
)
|
||||||
|
|
||||||
|
for ind in range(0, self.num_unets):
|
||||||
|
scaler_key = f'scaler{ind}'
|
||||||
|
optimizer_key = f'scaler{ind}'
|
||||||
|
scaler = getattr(self, scaler_key)
|
||||||
|
optimizer = getattr(self, optimizer_key)
|
||||||
|
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||||
|
|
||||||
|
torch.save(save_obj, str(path))
|
||||||
|
|
||||||
|
def load(self, path, only_model = False, strict = True):
|
||||||
|
path = Path(path)
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
|
if get_pkg_version() != loaded_obj['version']:
|
||||||
|
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
||||||
|
|
||||||
|
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|
||||||
|
if only_model:
|
||||||
|
return
|
||||||
|
|
||||||
|
for ind in range(0, self.num_unets):
|
||||||
|
scaler_key = f'scaler{ind}'
|
||||||
|
optimizer_key = f'scaler{ind}'
|
||||||
|
scaler = getattr(self, scaler_key)
|
||||||
|
optimizer = getattr(self, optimizer_key)
|
||||||
|
|
||||||
|
scaler.load_state_dict(loaded_obj[scaler_key])
|
||||||
|
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
assert 'ema' in loaded_obj
|
||||||
|
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unets(self):
|
def unets(self):
|
||||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
|
|||||||
11
dalle2_pytorch/utils.py
Normal file
11
dalle2_pytorch/utils.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.last_time = time.time()
|
||||||
|
|
||||||
|
def elapsed(self):
|
||||||
|
return time.time() - self.last_time
|
||||||
6
setup.py
6
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.3.3',
|
version = '0.4.5',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -32,6 +32,7 @@ setup(
|
|||||||
'kornia>=0.5.4',
|
'kornia>=0.5.4',
|
||||||
'numpy',
|
'numpy',
|
||||||
'pillow',
|
'pillow',
|
||||||
|
'pydantic',
|
||||||
'resize-right>=0.0.2',
|
'resize-right>=0.0.2',
|
||||||
'rotary-embedding-torch',
|
'rotary-embedding-torch',
|
||||||
'torch>=1.10',
|
'torch>=1.10',
|
||||||
@@ -41,7 +42,8 @@ setup(
|
|||||||
'x-clip>=0.4.4',
|
'x-clip>=0.4.4',
|
||||||
'youtokentome',
|
'youtokentome',
|
||||||
'webdataset>=0.2.5',
|
'webdataset>=0.2.5',
|
||||||
'fsspec>=2022.1.0'
|
'fsspec>=2022.1.0',
|
||||||
|
'torchmetrics[image]>=0.8.0'
|
||||||
],
|
],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Development Status :: 4 - Beta',
|
'Development Status :: 4 - Beta',
|
||||||
|
|||||||
456
train_decoder.py
Normal file
456
train_decoder.py
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
from dalle2_pytorch import Unet, Decoder
|
||||||
|
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||||
|
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||||
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
|
from dalle2_pytorch.utils import Timer
|
||||||
|
|
||||||
|
import torchvision
|
||||||
|
import torch
|
||||||
|
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||||
|
from torchmetrics.image.inception import InceptionScore
|
||||||
|
from torchmetrics.image.kid import KernelInceptionDistance
|
||||||
|
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||||
|
import webdataset as wds
|
||||||
|
import click
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
TRAIN_CALC_LOSS_EVERY_ITERS = 10
|
||||||
|
VALID_CALC_LOSS_EVERY_ITERS = 10
|
||||||
|
|
||||||
|
# helpers functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
# main functions
|
||||||
|
|
||||||
|
def create_dataloaders(
|
||||||
|
available_shards,
|
||||||
|
webdataset_base_url,
|
||||||
|
embeddings_url,
|
||||||
|
shard_width=6,
|
||||||
|
num_workers=4,
|
||||||
|
batch_size=32,
|
||||||
|
n_sample_images=6,
|
||||||
|
shuffle_train=True,
|
||||||
|
resample_train=False,
|
||||||
|
img_preproc = None,
|
||||||
|
index_width=4,
|
||||||
|
train_prop = 0.75,
|
||||||
|
val_prop = 0.15,
|
||||||
|
test_prop = 0.10,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
|
||||||
|
"""
|
||||||
|
assert train_prop + test_prop + val_prop == 1
|
||||||
|
num_train = round(train_prop*len(available_shards))
|
||||||
|
num_test = round(test_prop*len(available_shards))
|
||||||
|
num_val = len(available_shards) - num_train - num_test
|
||||||
|
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
|
||||||
|
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
|
||||||
|
|
||||||
|
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
|
||||||
|
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
|
||||||
|
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
||||||
|
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
||||||
|
|
||||||
|
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
|
||||||
|
tar_url=tar_urls,
|
||||||
|
num_workers=num_workers,
|
||||||
|
batch_size=batch_size if not for_sampling else n_sample_images,
|
||||||
|
embeddings_url=embeddings_url,
|
||||||
|
index_width=index_width,
|
||||||
|
shuffle_num = None,
|
||||||
|
extra_keys= ["txt"] if with_text else [],
|
||||||
|
shuffle_shards = shuffle,
|
||||||
|
resample_shards = resample,
|
||||||
|
img_preproc=img_preproc,
|
||||||
|
handler=wds.handlers.warn_and_continue
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
||||||
|
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
||||||
|
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
|
||||||
|
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
|
||||||
|
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
||||||
|
return {
|
||||||
|
"train": train_dataloader,
|
||||||
|
"train_sampling": train_sampling_dataloader,
|
||||||
|
"val": val_dataloader,
|
||||||
|
"test": test_dataloader,
|
||||||
|
"test_sampling": test_sampling_dataloader
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_decoder(device, decoder_config, unets_config):
|
||||||
|
"""Creates a sample decoder"""
|
||||||
|
|
||||||
|
unets = [Unet(**config.dict()) for config in unets_config]
|
||||||
|
|
||||||
|
decoder = Decoder(
|
||||||
|
unet=unets,
|
||||||
|
**decoder_config.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder.to(device=device)
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
def get_dataset_keys(dataloader):
|
||||||
|
"""
|
||||||
|
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||||
|
"""
|
||||||
|
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
|
||||||
|
if isinstance(dataloader, wds.WebLoader):
|
||||||
|
dataloader = dataloader.pipeline[0]
|
||||||
|
return dataloader.dataset.key_map
|
||||||
|
|
||||||
|
def get_example_data(dataloader, device, n=5):
|
||||||
|
"""
|
||||||
|
Samples the dataloader and returns a zipped list of examples
|
||||||
|
"""
|
||||||
|
images = []
|
||||||
|
embeddings = []
|
||||||
|
captions = []
|
||||||
|
dataset_keys = get_dataset_keys(dataloader)
|
||||||
|
has_caption = "txt" in dataset_keys
|
||||||
|
for data in dataloader:
|
||||||
|
if has_caption:
|
||||||
|
img, emb, txt = data
|
||||||
|
else:
|
||||||
|
img, emb = data
|
||||||
|
txt = [""] * emb.shape[0]
|
||||||
|
img = img.to(device=device, dtype=torch.float)
|
||||||
|
emb = emb.to(device=device, dtype=torch.float)
|
||||||
|
images.extend(list(img))
|
||||||
|
embeddings.extend(list(emb))
|
||||||
|
captions.extend(list(txt))
|
||||||
|
if len(images) >= n:
|
||||||
|
break
|
||||||
|
print("Generated {} examples".format(len(images)))
|
||||||
|
return list(zip(images[:n], embeddings[:n], captions[:n]))
|
||||||
|
|
||||||
|
def generate_samples(trainer, example_data, text_prepend=""):
|
||||||
|
"""
|
||||||
|
Takes example data and generates images from the embeddings
|
||||||
|
Returns three lists: real images, generated images, and captions
|
||||||
|
"""
|
||||||
|
real_images, embeddings, txts = zip(*example_data)
|
||||||
|
embeddings_tensor = torch.stack(embeddings)
|
||||||
|
samples = trainer.sample(embeddings_tensor)
|
||||||
|
generated_images = list(samples)
|
||||||
|
captions = [text_prepend + txt for txt in txts]
|
||||||
|
return real_images, generated_images, captions
|
||||||
|
|
||||||
|
def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||||
|
"""
|
||||||
|
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||||
|
"""
|
||||||
|
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||||
|
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||||
|
return grid_images, captions
|
||||||
|
|
||||||
|
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||||
|
"""
|
||||||
|
Computes evaluation metrics for the decoder
|
||||||
|
"""
|
||||||
|
metrics = {}
|
||||||
|
# Prepare the data
|
||||||
|
examples = get_example_data(dataloader, device, n_evaluation_samples)
|
||||||
|
real_images, generated_images, captions = generate_samples(trainer, examples)
|
||||||
|
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||||
|
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||||
|
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||||
|
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||||
|
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||||
|
if exists(FID):
|
||||||
|
fid = FrechetInceptionDistance(**FID)
|
||||||
|
fid.to(device=device)
|
||||||
|
fid.update(int_real_images, real=True)
|
||||||
|
fid.update(int_generated_images, real=False)
|
||||||
|
metrics["FID"] = fid.compute().item()
|
||||||
|
if exists(IS):
|
||||||
|
inception = InceptionScore(**IS)
|
||||||
|
inception.to(device=device)
|
||||||
|
inception.update(int_real_images)
|
||||||
|
is_mean, is_std = inception.compute()
|
||||||
|
metrics["IS_mean"] = is_mean.item()
|
||||||
|
metrics["IS_std"] = is_std.item()
|
||||||
|
if exists(KID):
|
||||||
|
kernel_inception = KernelInceptionDistance(**KID)
|
||||||
|
kernel_inception.to(device=device)
|
||||||
|
kernel_inception.update(int_real_images, real=True)
|
||||||
|
kernel_inception.update(int_generated_images, real=False)
|
||||||
|
kid_mean, kid_std = kernel_inception.compute()
|
||||||
|
metrics["KID_mean"] = kid_mean.item()
|
||||||
|
metrics["KID_std"] = kid_std.item()
|
||||||
|
if exists(LPIPS):
|
||||||
|
# Convert from [0, 1] to [-1, 1]
|
||||||
|
renorm_real_images = real_images.mul(2).sub(1)
|
||||||
|
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||||
|
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
|
||||||
|
lpips.to(device=device)
|
||||||
|
lpips.update(renorm_real_images, renorm_generated_images)
|
||||||
|
metrics["LPIPS"] = lpips.compute().item()
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
|
||||||
|
"""
|
||||||
|
Logs the model with an appropriate method depending on the tracker
|
||||||
|
"""
|
||||||
|
if isinstance(relative_paths, str):
|
||||||
|
relative_paths = [relative_paths]
|
||||||
|
trainer_state_dict = {}
|
||||||
|
trainer_state_dict["trainer"] = trainer.state_dict()
|
||||||
|
trainer_state_dict['epoch'] = epoch
|
||||||
|
trainer_state_dict['step'] = step
|
||||||
|
trainer_state_dict['validation_losses'] = validation_losses
|
||||||
|
for relative_path in relative_paths:
|
||||||
|
tracker.save_state_dict(trainer_state_dict, relative_path)
|
||||||
|
|
||||||
|
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||||
|
"""
|
||||||
|
Loads the model with an appropriate method depending on the tracker
|
||||||
|
"""
|
||||||
|
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||||
|
state_dict = tracker.recall_state_dict(recall_source, **load_config)
|
||||||
|
trainer.load_state_dict(state_dict["trainer"])
|
||||||
|
print("Model loaded")
|
||||||
|
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||||
|
|
||||||
|
def train(
|
||||||
|
dataloaders,
|
||||||
|
decoder,
|
||||||
|
tracker,
|
||||||
|
inference_device,
|
||||||
|
load_config=None,
|
||||||
|
evaluate_config=None,
|
||||||
|
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||||
|
validation_samples = None,
|
||||||
|
epochs = 20,
|
||||||
|
n_sample_images = 5,
|
||||||
|
save_every_n_samples = 100000,
|
||||||
|
save_all=False,
|
||||||
|
save_latest=True,
|
||||||
|
save_best=True,
|
||||||
|
unet_training_mask=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Trains a decoder on a dataset.
|
||||||
|
"""
|
||||||
|
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
|
||||||
|
decoder,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
# Set up starting model and parameters based on a recalled state dict
|
||||||
|
start_step = 0
|
||||||
|
start_epoch = 0
|
||||||
|
validation_losses = []
|
||||||
|
|
||||||
|
if exists(load_config) and exists(load_config.source):
|
||||||
|
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
|
||||||
|
trainer.to(device=inference_device)
|
||||||
|
|
||||||
|
if not exists(unet_training_mask):
|
||||||
|
# Then the unet mask should be true for all unets in the decoder
|
||||||
|
unet_training_mask = [True] * trainer.num_unets
|
||||||
|
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||||
|
|
||||||
|
print(print_ribbon("Generating Example Data", repeat=40))
|
||||||
|
print("This can take a while to load the shard lists...")
|
||||||
|
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
|
||||||
|
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
|
||||||
|
|
||||||
|
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
||||||
|
step = start_step
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, epochs):
|
||||||
|
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
|
sample = 0
|
||||||
|
last_sample = 0
|
||||||
|
last_snapshot = 0
|
||||||
|
|
||||||
|
losses = []
|
||||||
|
|
||||||
|
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||||
|
step += 1
|
||||||
|
sample += img.shape[0]
|
||||||
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
for unet in range(1, trainer.num_unets+1):
|
||||||
|
# Check if this is a unet we are training
|
||||||
|
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||||
|
trainer.update(unet_number=unet)
|
||||||
|
losses.append(loss)
|
||||||
|
|
||||||
|
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||||
|
|
||||||
|
timer.reset()
|
||||||
|
last_sample = sample
|
||||||
|
|
||||||
|
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
|
||||||
|
average_loss = sum(losses) / len(losses)
|
||||||
|
log_data = {
|
||||||
|
"Training loss": average_loss,
|
||||||
|
"Epoch": epoch,
|
||||||
|
"Sample": sample,
|
||||||
|
"Step": i,
|
||||||
|
"Samples per second": samples_per_sec
|
||||||
|
}
|
||||||
|
tracker.log(log_data, step=step, verbose=True)
|
||||||
|
losses = []
|
||||||
|
|
||||||
|
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||||
|
last_snapshot = sample
|
||||||
|
# We need to know where the model should be saved
|
||||||
|
save_paths = []
|
||||||
|
if save_latest:
|
||||||
|
save_paths.append("latest.pth")
|
||||||
|
if save_all:
|
||||||
|
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
||||||
|
|
||||||
|
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||||
|
|
||||||
|
if exists(n_sample_images) and n_sample_images > 0:
|
||||||
|
trainer.eval()
|
||||||
|
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||||
|
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||||
|
|
||||||
|
if exists(epoch_samples) and sample >= epoch_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
trainer.eval()
|
||||||
|
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
|
||||||
|
with torch.no_grad():
|
||||||
|
sample = 0
|
||||||
|
average_loss = 0
|
||||||
|
timer = Timer()
|
||||||
|
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||||
|
sample += img.shape[0]
|
||||||
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
|
for unet in range(1, len(decoder.unets)+1):
|
||||||
|
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||||
|
average_loss += loss
|
||||||
|
|
||||||
|
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||||
|
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
||||||
|
print(f"Loss: {average_loss / (i+1)}")
|
||||||
|
print("")
|
||||||
|
|
||||||
|
if exists(validation_samples) and sample >= validation_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
average_loss /= i+1
|
||||||
|
log_data = {
|
||||||
|
"Validation loss": average_loss
|
||||||
|
}
|
||||||
|
tracker.log(log_data, step=step, verbose=True)
|
||||||
|
|
||||||
|
# Compute evaluation metrics
|
||||||
|
if exists(evaluate_config):
|
||||||
|
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||||
|
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||||
|
tracker.log(evaluation, step=step, verbose=True)
|
||||||
|
|
||||||
|
# Generate sample images
|
||||||
|
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||||
|
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
||||||
|
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||||
|
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
|
||||||
|
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||||
|
|
||||||
|
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||||
|
# Get the same paths
|
||||||
|
save_paths = []
|
||||||
|
if save_latest:
|
||||||
|
save_paths.append("latest.pth")
|
||||||
|
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||||
|
save_paths.append("best.pth")
|
||||||
|
validation_losses.append(average_loss)
|
||||||
|
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||||
|
|
||||||
|
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a tracker of the specified type and initializes special features based on the full config
|
||||||
|
"""
|
||||||
|
tracker_config = config.tracker
|
||||||
|
init_config = {}
|
||||||
|
|
||||||
|
if exists(tracker_config.init_config):
|
||||||
|
init_config["config"] = tracker_config.init_config
|
||||||
|
|
||||||
|
if tracker_type == "console":
|
||||||
|
tracker = ConsoleTracker(**init_config)
|
||||||
|
elif tracker_type == "wandb":
|
||||||
|
# We need to initialize the resume state here
|
||||||
|
load_config = config.load
|
||||||
|
if load_config.source == "wandb" and load_config.resume:
|
||||||
|
# Then we are resuming the run load_config["run_path"]
|
||||||
|
run_id = load_config.run_path.split("/")[-1]
|
||||||
|
init_config["id"] = run_id
|
||||||
|
init_config["resume"] = "must"
|
||||||
|
|
||||||
|
init_config["entity"] = tracker_config.wandb_entity
|
||||||
|
init_config["project"] = tracker_config.wandb_project
|
||||||
|
tracker = WandbTracker(data_path)
|
||||||
|
tracker.init(**init_config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
|
||||||
|
return tracker
|
||||||
|
|
||||||
|
def initialize_training(config):
|
||||||
|
# Create the save path
|
||||||
|
if "cuda" in config.train.device:
|
||||||
|
assert torch.cuda.is_available(), "CUDA is not available"
|
||||||
|
device = torch.device(config.train.device)
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||||
|
|
||||||
|
dataloaders = create_dataloaders (
|
||||||
|
available_shards=all_shards,
|
||||||
|
img_preproc = config.img_preproc,
|
||||||
|
train_prop = config.data.splits.train,
|
||||||
|
val_prop = config.data.splits.val,
|
||||||
|
test_prop = config.data.splits.test,
|
||||||
|
n_sample_images=config.train.n_sample_images,
|
||||||
|
**config.data.dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = create_decoder(device, config.decoder, config.unets)
|
||||||
|
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||||
|
print(print_ribbon("Loaded Config", repeat=40))
|
||||||
|
print(f"Number of parameters: {num_parameters}")
|
||||||
|
|
||||||
|
tracker = create_tracker(config, **config.tracker.dict())
|
||||||
|
|
||||||
|
train(dataloaders, decoder,
|
||||||
|
tracker=tracker,
|
||||||
|
inference_device=device,
|
||||||
|
load_config=config.load,
|
||||||
|
evaluate_config=config.evaluate,
|
||||||
|
**config.train.dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a simple click command line interface to load the config and start the training
|
||||||
|
@click.command()
|
||||||
|
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||||
|
def main(config_file):
|
||||||
|
print("Recalling config from {}".format(config_file))
|
||||||
|
config = TrainDecoderConfig.from_json_path(config_file)
|
||||||
|
initialize_training(config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import click
|
import click
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
|
|||||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||||
|
|
||||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
|
from dalle2_pytorch.utils import Timer
|
||||||
|
|
||||||
from embedding_reader import EmbeddingReader
|
from embedding_reader import EmbeddingReader
|
||||||
|
|
||||||
@@ -29,16 +29,6 @@ tracker = WandbTracker()
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
val is not None
|
val is not None
|
||||||
|
|
||||||
class Timer:
|
|
||||||
def __init__(self):
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.last_time = time.time()
|
|
||||||
|
|
||||||
def elapsed(self):
|
|
||||||
return time.time() - self.last_time
|
|
||||||
|
|
||||||
# functions
|
# functions
|
||||||
|
|
||||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||||
|
|||||||
Reference in New Issue
Block a user