Compare commits

...

12 Commits
0.3.2 ... 0.3.9

Author SHA1 Message Date
Phil Wang
6ea65f59cc move config parsing logic to own file, consider whether to find an off-the-shelf solution at future date 2022-05-21 10:25:15 -07:00
Phil Wang
0064661729 small cleanup of decoder train script 2022-05-21 10:17:13 -07:00
Phil Wang
b895f52843 appreciation section 2022-05-21 08:32:12 -07:00
Phil Wang
80497e9839 accept unets as list for decoder 2022-05-20 20:31:26 -07:00
Phil Wang
f526f14d7c bump 2022-05-20 20:20:40 -07:00
Phil Wang
8997f178d6 small cleanup with timer 2022-05-20 20:05:01 -07:00
Aidan Dempster
022c94e443 Added single GPU training script for decoder (#108)
Added config files for training

Changed example image generation to be more efficient

Added configuration description to README

Removed unused import
2022-05-20 19:46:19 -07:00
Phil Wang
430961cb97 it was correct the first time, my bad 2022-05-20 18:05:15 -07:00
Phil Wang
721f9687c1 fix wandb logging in tracker, and do some cleanup 2022-05-20 17:27:43 -07:00
Aidan Dempster
e0524a6aff Implemented the wandb tracker (#106)
Added a base_path parameter to all trackers for storing any local information they need to
2022-05-20 16:39:23 -07:00
Aidan Dempster
c85e0d5c35 Update decoder dataloader (#105)
* Updated the decoder dataloader
Removed unnecessary logging for required packages
Transferred to using index width instead of shard width
Added the ability to select extra keys to return from the webdataset

* Added README for decoder loader
2022-05-20 16:38:55 -07:00
Phil Wang
db0642c4cd quick fix for @marunine 2022-05-18 20:22:52 -07:00
15 changed files with 1055 additions and 45 deletions

9
.gitignore vendored
View File

@@ -1,3 +1,12 @@
# default experiment tracker data
.tracker-data/
# Configuration Files
configs/*
!configs/*.example
!configs/*_defaults.py
!configs/README.md
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -1034,6 +1034,18 @@ Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
... and many others. Thank you! 🙏
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon

109
configs/README.md Normal file
View File

@@ -0,0 +1,109 @@
## DALLE2 Training Configurations
For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
### Decoder Trainer
The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
**<ins>Unets</ins>:**
Each member of this array defines a single unet that will be added to the decoder.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `dim` | Yes | N/A | The starting channels of the unet. |
| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
Any parameter from the `Unet` constructor can also be given here.
**<ins>Decoder</ins>:**
Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
| `image_size` | Yes | N/A | Not used. Can be any number. |
| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
Any parameter from the `Decoder` constructor can also be given here.
**<ins>Data</ins>:**
Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.
**<ins>Train</ins>:**
Settings for controlling the training hyperparameters.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `epochs` | No | `20` | The number of epochs in the training run. |
| `lr` | No | `1e-4` | The learning rate. |
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
**<ins>Tracker</ins>:**
Selects which tracker to use and configures it.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
**<ins>Load</ins>:**
Selects where to load a pretrained model from.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).

View File

@@ -0,0 +1,82 @@
"""
Defines the default values for the decoder config
"""
from enum import Enum
class ConfigField(Enum):
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
default_config = {
"unets": ConfigField.REQUIRED,
"decoder": {
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": True
},
"data": {
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": True,
"resample_train": False,
"preprocessing": {
"ToTensor": True
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
"device": "cuda:0",
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
"validation_samples": None, # Same as above but for validation.
"use_ema": True,
"ema_beta": 0.99,
"amp": False,
"save_all": False, # Whether to preserve all checkpoints
"save_latest": True, # Whether to always save the latest checkpoint
"save_best": True, # Whether to save the best checkpoint
"unet_training_mask": None # If None, use all unets
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": None,
"IS": None,
"KID": None,
"LPIPS": None
},
"tracker": {
"tracker_type": "console", # Decoder currently supports console and wandb
"data_path": "./models", # The path where files will be saved locally
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
"wandb_project": "",
"verbose": False # Whether to print console logging for non-console trackers
},
"load": {
"source": None, # Supports file and wandb
"run_path": "", # Used only if source is wandb
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
"resume": False # If using wandb, whether to resume the run
}
}

View File

@@ -0,0 +1,100 @@
{
"unets": [
{
"dim": 128,
"image_embed_dim": 768,
"cond_dim": 64,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 32,
"attn_heads": 16
}
],
"decoder": {
"image_sizes": [64],
"image_size": [64],
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": true
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": true,
"resample_train": false,
"preprocessing": {
"RandomResizedCrop": {
"size": [128, 128],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6,
"device": "cuda:0",
"epoch_samples": null,
"validation_samples": null,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 10
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"tracker_type": "console",
"data_path": "./models",
"wandb_entity": "",
"wandb_project": "",
"verbose": false
},
"load": {
"source": null,
"run_path": "",
"file_path": "",
"resume": false
}
}

View File

@@ -59,6 +59,9 @@ def default(val, d):
return d() if isfunction(d) else d
def cast_tuple(val, length = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else ((val,) * length)
def module_device(module):
@@ -1697,7 +1700,8 @@ class Decoder(BaseGaussianDiffusion):
clip_adapter_overrides = dict(),
learned_variance = True,
vb_loss_weight = 0.001,
unconditional = False
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1806,6 +1810,10 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1877,7 +1885,7 @@ class Decoder(BaseGaussianDiffusion):
img = torch.randn(shape, device = device)
if not is_latent_diffusion:
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
@@ -1894,7 +1902,7 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = clip_denoised
)
unnormalize_img = unnormalize_zero_to_one(img)
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
@@ -1903,8 +1911,8 @@ class Decoder(BaseGaussianDiffusion):
# normalize to [-1, 1]
if not is_latent_diffusion:
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
x_start = self.normalize_img(x_start)
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
# get x_t

View File

@@ -0,0 +1,41 @@
## Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
Usage:
```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
shuffle_shards=True, # Shuffle the order the shards are read in
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
embedding_folder_url="path/or/url/to/embeddings/folder",
shard_width=4,
shuffle_shards=True,
resample=False
)
```

View File

@@ -3,6 +3,7 @@ import webdataset as wds
import torch
import numpy as np
import fsspec
import shutil
def get_shard(filename):
"""
@@ -20,7 +21,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -50,8 +51,12 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[shard_width:])
sample["npy"] = current_embeddings[embedding_index]
embedding_index = int(key[-index_width:])
embedding = current_embeddings[embedding_index]
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample["npy"] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -60,6 +65,28 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
embedding_files = embeddings_fs.ls(embeddings_path)
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
for tarfile in tarfiles:
try:
webdataset_shard = get_tar_shard(tarfile["url"])
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
if webdataset_shard in embedding_shards:
yield tarfile
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
@@ -86,7 +113,9 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self,
urls,
embedding_folder_url=None,
shard_width=None,
index_width=None,
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
@@ -97,13 +126,31 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
:param handler: A webdataset handler.
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
"""
super().__init__()
keys = ["jpg", "npy"] + extra_keys
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if "s3:" in embedding_folder_url:
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
except ImportError:
raise RuntimeError("s3fs is required to load embeddings from s3")
# Add the shardList and randomize or resample if requested
if resample:
assert not shuffle_shards, "Cannot both resample and shuffle"
@@ -112,28 +159,43 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self.append(wds.SimpleShardList(urls))
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("torchrgb"))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
# Then we are loading embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
self.append(verify_keys)
self.append(wds.to_tuple("jpg", "npy"))
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
def preproc(self, sample):
"""Applies the preprocessing for images"""
if self.img_preproc is not None:
sample["jpg"] = self.img_preproc(sample["jpg"])
return sample
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
shard_width=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
handler=wds.handlers.warn_and_continue
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception#warn_and_continue
):
"""
Convenience function to create an image embedding dataseta and dataloader in one line
@@ -143,8 +205,8 @@ def create_image_embedding_dataloader(
:param batch_size: The batch size to use for the dataloader
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
@@ -153,9 +215,11 @@ def create_image_embedding_dataloader(
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
shard_width=shard_width,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
extra_keys=extra_keys,
img_preproc=img_preproc,
handler=handler
)
if shuffle_num is not None and shuffle_num > 0:

View File

@@ -11,7 +11,8 @@ def get_optimizer(
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False
filter_by_requires_grad = False,
**kwargs
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))

View File

@@ -1,17 +1,45 @@
import os
from pathlib import Path
import importlib
from itertools import zip_longest
import torch
from torch import nn
# constants
DEFAULT_DATA_PATH = './.tracker-data'
# helper functions
def exists(val):
return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
# base class
class BaseTracker(nn.Module):
def __init__(self):
def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__()
self.data_path = Path(data_path)
self.data_path.mkdir(parents = True, exist_ok = True)
def init(self, config, **kwargs):
raise NotImplementedError
@@ -19,6 +47,27 @@ class BaseTracker(nn.Module):
def log(self, log, **kwargs):
raise NotImplementedError
def log_images(self, images, **kwargs):
raise NotImplementedError
def save_state_dict(self, state_dict, relative_path, **kwargs):
raise NotImplementedError
def recall_state_dict(self, recall_source, *args, **kwargs):
"""
Loads a state dict from any source.
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
this should not be linked to any individual tracker.
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
# basic stdout class
class ConsoleTracker(BaseTracker):
@@ -28,22 +77,39 @@ class ConsoleTracker(BaseTracker):
def log(self, log, **kwargs):
print(log)
def log_images(self, images, **kwargs): # noop for logging images
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
# basic wandb class
class WandbTracker(BaseTracker):
def __init__(self):
super().__init__()
try:
import wandb
except ImportError as e:
print('`pip install wandb` to use the wandb experiment tracker')
raise e
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
os.environ["WANDB_SILENT"] = "true"
self.wandb = wandb
def init(self, **config):
self.wandb.init(**config)
def log(self, log, **kwargs):
def log(self, log, verbose=False, **kwargs):
if verbose:
print(log)
self.wandb.log(log, **kwargs)
def log_images(self, images, captions=[], image_section="images", **kwargs):
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs):
"""
Saves a state_dict to disk and uploads it
"""
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path

View File

@@ -0,0 +1,62 @@
from torchvision import transforms as T
from configs.decoder_defaults import default_config, ConfigField
class TrainDecoderConfig:
def __init__(self, config):
self.config = self.map_config(config, default_config)
def map_config(self, config, defaults):
"""
Returns a dictionary containing all config options in the union of config and defaults.
If the config value is an array, apply the default value to each element.
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
"""
def _check_option(option, option_config, option_defaults):
for key, value in option_defaults.items():
if key not in option_config:
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
option_config[key] = value
for key, value in defaults.items():
if key not in config:
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
if value == ConfigField.REQUIRED:
raise RuntimeError("Required config value '{}' not supplied".format(key))
elif isinstance(value, dict):
config[key] = {}
elif isinstance(value, list):
config[key] = [{}]
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
# If it is a list, then we need to check each element
if isinstance(value, list):
assert isinstance(config[key], list)
for element in config[key]:
_check_option(key, element, value[0])
elif isinstance(value, dict):
_check_option(key, config[key], value)
# This object does not support checking
return config
def get_preprocessing(self):
"""
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
"""
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transformations = []
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
if isinstance(transformation_kwargs, dict):
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
else:
transformations.append(_get_transformation(transformation_name))
return T.Compose(transformations)
def __getitem__(self, key):
return self.config[key]

11
dalle2_pytorch/utils.py Normal file
View File

@@ -0,0 +1,11 @@
import time
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.3.2',
version = '0.3.9',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -41,7 +41,8 @@ setup(
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0'
'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'
],
classifiers=[
'Development Status :: 4 - Beta',

454
train_decoder.py Normal file
View File

@@ -0,0 +1,454 @@
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer
import json
import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click
# constants
TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10
# helpers functions
def exists(val):
return val is not None
# main functions
def create_dataloaders(
available_shards,
webdataset_base_url,
embeddings_url,
shard_width=6,
num_workers=4,
batch_size=32,
n_sample_images=6,
shuffle_train=True,
resample_train=False,
img_preproc = None,
index_width=4,
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
**kwargs
):
"""
Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
"""
assert train_prop + test_prop + val_prop == 1
num_train = round(train_prop*len(available_shards))
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
embeddings_url=embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"] if with_text else [],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
handler=wds.handlers.warn_and_continue
)
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
"train_sampling": train_sampling_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"test_sampling": test_sampling_dataloader
}
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config
)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
"""
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
"""
# If the dataloader is actually a WebLoader, we need to extract the real dataloader
if isinstance(dataloader, wds.WebLoader):
dataloader = dataloader.pipeline[0]
return dataloader.dataset.key_map
def get_example_data(dataloader, device, n=5):
"""
Samples the dataloader and returns a zipped list of examples
"""
images = []
embeddings = []
captions = []
dataset_keys = get_dataset_keys(dataloader)
has_caption = "txt" in dataset_keys
for data in dataloader:
if has_caption:
img, emb, txt = data
else:
img, emb = data
txt = [""] * emb.shape[0]
img = img.to(device=device, dtype=torch.float)
emb = emb.to(device=device, dtype=torch.float)
images.extend(list(img))
embeddings.extend(list(emb))
captions.extend(list(txt))
if len(images) >= n:
break
print("Generated {} examples".format(len(images)))
return list(zip(images[:n], embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, embeddings, txts = zip(*example_data)
embeddings_tensor = torch.stack(embeddings)
samples = trainer.sample(embeddings_tensor)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evalation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
if exists(FID):
fid = FrechetInceptionDistance(**FID)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if exists(IS):
inception = InceptionScore(**IS)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
return metrics
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
trainer_state_dict = {}
trainer_state_dict["trainer"] = trainer.state_dict()
trainer_state_dict['epoch'] = epoch
trainer_state_dict['step'] = step
trainer_state_dict['validation_losses'] = validation_losses
for relative_path in relative_paths:
tracker.save_state_dict(trainer_state_dict, relative_path)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
"""
Loads the model with an appropriate method depending on the tracker
"""
print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config)
trainer.load_state_dict(state_dict["trainer"])
print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
def train(
dataloaders,
decoder,
tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
**kwargs
):
"""
Trains a decoder on a dataset.
"""
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
decoder,
**kwargs
)
# Set up starting model and parameters based on a recalled state dict
start_step = 0
start_epoch = 0
validation_losses = []
if exists(load_config) and exists(load_config["source"]):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
print(print_ribbon("Generating Example Data", repeat=40))
print("This can take a while to load the shard lists...")
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
trainer.train()
timer = Timer()
sample = 0
last_sample = 0
last_snapshot = 0
losses = []
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
if i % CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec
}
tracker.log(log_data, step=step, verbose=True)
losses = []
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
trainer.train()
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if exists(epoch_samples) and sample >= epoch_samples:
break
trainer.eval()
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
with torch.no_grad():
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, len(decoder.unets)+1):
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if exists(validation_samples) and sample >= validation_samples:
break
average_loss /= i+1
log_data = {
"Validation loss": average_loss
}
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
trainer.eval()
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
tracker.log(evaluation, step=step, verbose=True)
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config["tracker"]
init_config = {}
init_config["config"] = config.config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config["load"]
if load_config["source"] == "wandb" and load_config["resume"]:
# Then we are resuming the run load_config["run_path"]
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"]
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
return tracker
def initialize_training(config):
# Create the save path
if "cuda" in config["train"]["device"]:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config["train"]["device"])
torch.cuda.set_device(device)
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.get_preprocessing(),
train_prop = config["data"]["splits"]["train"],
val_prop = config["data"]["splits"]["val"],
test_prop = config["data"]["splits"]["test"],
n_sample_images=config["train"]["n_sample_images"],
**config["data"]
)
decoder = create_decoder(device, config["decoder"], config["unets"])
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config["tracker"])
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config["load"],
evaluate_config=config["evaluate"],
**config["train"],
)
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(config)
initialize_training(config)
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,6 @@
from pathlib import Path
import click
import math
import time
import numpy as np
import torch
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer
from embedding_reader import EmbeddingReader
@@ -29,16 +29,6 @@ tracker = WandbTracker()
def exists(val):
val is not None
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):