Compare commits

...

35 Commits

Author SHA1 Message Date
Phil Wang
80046334ad make sure entire readme runs without errors 2022-07-28 10:17:43 -07:00
Phil Wang
36fb46a95e fix readme and a small bug in DALLE2 class 2022-07-28 08:33:51 -07:00
Phil Wang
07abfcf45b rescale values in linear attention to mitigate overflows in fp16 setting 2022-07-27 12:27:38 -07:00
Phil Wang
2e35a9967d product management 2022-07-26 11:10:16 -07:00
Phil Wang
406e75043f add upsample combiner feature for the unets 2022-07-26 10:46:04 -07:00
Phil Wang
9646dfc0e6 fix path_or_state bug 2022-07-26 09:47:54 -07:00
Phil Wang
62043acb2f fix repaint 2022-07-24 15:29:06 -07:00
Phil Wang
417ff808e6 1.0.3 2022-07-22 13:16:57 -07:00
Aidan Dempster
f3d7e226ba Changed types to be generic instead of functions (#215)
This allows pylance to do proper type hinting and makes developing
extensions to the package much easier
2022-07-22 13:16:29 -07:00
Phil Wang
48a1302428 1.0.2 2022-07-20 23:01:51 -07:00
Aidan Dempster
ccaa46b81b Re-introduced change that was accidentally rolled back (#212) 2022-07-20 23:01:19 -07:00
Phil Wang
76d08498cc diffusion prior training updates from @nousr 2022-07-20 18:05:27 -07:00
zion
f9423d308b Prior updates (#211)
* update configs for prior

add prior warmup to config

update example prior config

* update prior trainer & script

add deepspeed amp & warmup

adopt full accelerator support

reload at sample point

finish epoch resume code

* update tracker save method for prior

* helper functions for prior_loader
2022-07-20 18:04:26 -07:00
Phil Wang
06c65b60d2 1.0.0 2022-07-19 19:08:17 -07:00
Aidan Dempster
4145474bab Improved upsampler training (#181)
Sampling is now possible without the first decoder unet

Non-training unets are deleted in the decoder trainer since they are never used and it is harder merge the models is they have keys in this state dict

Fixed a mistake where clip was not re-added after saving
2022-07-19 19:07:50 -07:00
Phil Wang
4b912a38c6 0.26.2 2022-07-19 17:50:36 -07:00
Aidan Dempster
f97e55ec6b Quality of life improvements for tracker savers (#210)
The default save location is now none so if keys are not specified the
corresponding checkpoint type is not saved.

Models and checkpoints are now both saved with version number and the
config used to create them in order to simplify loading.

Documentation was fixed to be in line with current usage.
2022-07-19 17:50:18 -07:00
Phil Wang
291377bb9c @jacobwjs reports dynamic thresholding works very well and 0.95 is a better value 2022-07-19 11:31:56 -07:00
Phil Wang
7f120a8b56 cleanup, CLI no longer necessary since Zion + Aidan have https://github.com/LAION-AI/dalle2-laion and colab notebook going 2022-07-19 09:47:44 -07:00
Phil Wang
8c003ab1e1 readme and citation 2022-07-19 09:36:45 -07:00
Phil Wang
723bf0abba complete inpainting ability using inpaint_image and inpaint_mask passed into sample function for decoder 2022-07-19 09:26:55 -07:00
Phil Wang
d88c7ba56c fix a bug with ddim and predict x0 objective 2022-07-18 19:04:26 -07:00
Phil Wang
3676a8ce78 comments 2022-07-18 15:02:04 -07:00
Phil Wang
da8e99ada0 fix sample bug 2022-07-18 13:50:22 -07:00
Phil Wang
6afb886cf4 complete imagen-like noise level conditioning 2022-07-18 13:43:57 -07:00
Phil Wang
c7fe4f2f44 project management 2022-07-17 17:27:44 -07:00
Phil Wang
a2ee3fa3cc offer way to turn off initial cross embed convolutional module, for debugging upsampler artifacts 2022-07-15 17:29:10 -07:00
Phil Wang
a58a370d75 takes care of a grad strides error at https://github.com/lucidrains/DALLE2-pytorch/issues/196 thanks to @YUHANG-Ma 2022-07-14 15:28:34 -07:00
Phil Wang
1662bbf226 protect against random cropping for base unet 2022-07-14 12:49:43 -07:00
Phil Wang
5be1f57448 update 2022-07-14 12:03:42 -07:00
Phil Wang
c52ce58e10 update 2022-07-14 10:54:51 -07:00
Phil Wang
a34f60962a let the neural network peek at the low resolution conditioning one last time before making prediction, for upsamplers 2022-07-14 10:27:04 -07:00
Phil Wang
0b40cbaa54 just always use nearest neighbor interpolation when resizing for low resolution conditioning, for https://github.com/lucidrains/DALLE2-pytorch/pull/181 2022-07-13 20:59:43 -07:00
Phil Wang
f141144a6d allow for using classifier free guidance for some unets but not others, by passing in a tuple of cond_scale during sampling for decoder, just in case it is causing issues for upsamplers 2022-07-13 13:12:30 -07:00
Phil Wang
f988207718 hack around some inplace error, also make sure for openai clip text encoding, only tokens after eos_id is masked out 2022-07-13 12:56:02 -07:00
14 changed files with 1340 additions and 586 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
github: [lucidrains]
github: [nousr, Veldrovive, lucidrains]

125
README.md
View File

@@ -371,6 +371,7 @@ loss.backward()
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
@@ -395,7 +396,7 @@ decoder = Decoder(
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -628,6 +629,82 @@ images = dalle2(
Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting
Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
This repository uses the formulation put forth by <a href="https://arxiv.org/abs/2201.09865">Lugmayr et al. in Repaint</a>
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 1, 1, 1)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
# do the above for many steps for both unets
mock_image_embed = torch.randn(1, 512).cuda()
# then to do inpainting
inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
inpainted_images = decoder.sample(
image_embed = mock_image_embed,
inpaint_image = inpaint_image, # just pass in the inpaint image
inpaint_mask = inpaint_mask # and the mask
)
inpainted_images.shape # (1, 3, 256, 256)
```
## Experimental
### DALL-E2 with Latent Diffusion
@@ -784,25 +861,23 @@ unet1 = Unet(
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True,
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
condition_on_text_encodings = True
timesteps = 1000
).cuda()
decoder_trainer = DecoderTrainer(
@@ -827,8 +902,8 @@ for unet_number in (1, 2):
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
mock_image_embed = torch.randn(32, 512).cuda()
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
@@ -991,26 +1066,12 @@ dataset = ImageEmbeddingDataset(
)
```
### Scripts (wip)
### Scripts
#### `train_diffusion_prior.py`
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1048,11 +1109,11 @@ Once built, images will be saved to the same directory the command is invoked
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [x] speed up inference, read up on papers (ddim)
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
## Citations
@@ -1170,4 +1231,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -69,14 +69,12 @@ Settings for controlling the training hyperparameters.
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
@@ -163,9 +161,10 @@ All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
@@ -177,7 +176,6 @@ If using `huggingface`
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`

View File

@@ -56,9 +56,6 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -96,14 +93,15 @@
},
"save": [{
"save_to": "wandb"
"save_to": "wandb",
"save_latest_to": "latest.pth"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_latest_to": "path/to/model_dir/latest.pth",
"save_best_to": "path/to/model_dir/best.pth",
"save_meta_to": "path/to/directory/for/assorted/files",
"save_type": "model"
}]

View File

@@ -61,9 +61,6 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -96,7 +93,8 @@
},
"save": [{
"save_to": "local"
"save_to": "local",
"save_latest_to": "latest.pth"
}]
}
}

View File

@@ -1,18 +1,14 @@
{
"prior": {
"clip": {
"make": "x-clip",
"model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
"make": "openai",
"model": "ViT-L/14"
},
"net": {
"dim": 768,
"depth": 12,
"num_timesteps": 1000,
"max_text_len": 77,
"num_time_embeds": 1,
"num_image_embeds": 1,
"num_text_embeds": 1,
@@ -20,8 +16,8 @@
"heads": 12,
"ff_mult": 4,
"norm_out": true,
"attn_dropout": 0.0,
"ff_dropout": 0.0,
"attn_dropout": 0.05,
"ff_dropout": 0.05,
"final_proj": true,
"normformer": true,
"rotary_emb": true
@@ -30,6 +26,7 @@
"image_size": 224,
"image_channels": 3,
"timesteps": 1000,
"sample_timesteps": 64,
"cond_drop_prob": 0.1,
"loss_type": "l2",
"predict_x_start": true,
@@ -37,34 +34,48 @@
"condition_on_text_encodings": true
},
"data": {
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"batch_size": 256,
"batch_size": 128,
"num_data_points": 100000,
"eval_every_seconds": 1600,
"image_url": "<path to your images>",
"meta_url": "<path to your metadata>",
"splits": {
"train": 0.9,
"val": 1e-7,
"test": 0.0999999
"train": 0.8,
"val": 0.1,
"test": 0.1
}
},
"train": {
"epochs": 1,
"epochs": 5,
"lr": 1.1e-4,
"wd": 6.02e-2,
"max_grad_norm": 0.5,
"use_ema": true,
"ema_beta": 0.9999,
"ema_update_after_step": 50,
"warmup_steps": 50,
"amp": false,
"save_every": 10000
},
"load": {
"source": null,
"resume": false
"save_every_seconds": 3600,
"eval_timesteps": [64, 1000],
"random_seed": 84513
},
"tracker": {
"tracker_type": "wandb",
"data_path": "./prior_checkpoints",
"wandb_entity": "laion",
"wandb_project": "diffusion-prior",
"verbose": true
"data_path": ".prior",
"overwrite_data_path": true,
"log": {
"log_type": "wandb",
"wandb_entity": "<your wandb username>",
"wandb_project": "prior_debugging",
"wandb_resume": false,
"verbose": true
},
"save": [
{
"save_to": "local",
"save_type": "checkpoint",
"save_latest_to": ".prior/latest_checkpoint.pth",
"save_best_to": ".prior/best_checkpoint.pth"
}
]
}
}

View File

@@ -52,10 +52,10 @@ def first(arr, d = None):
def maybe(fn):
@wraps(fn)
def inner(x):
def inner(x, *args, **kwargs):
if not exists(x):
return x
return fn(x)
return fn(x, *args, **kwargs)
return inner
def default(val, d):
@@ -63,18 +63,20 @@ def default(val, d):
return val
return d() if callable(d) else d
def cast_tuple(val, length = None):
def cast_tuple(val, length = None, validate = True):
if isinstance(val, list):
val = tuple(val)
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if exists(length):
if exists(length) and validate:
assert len(out) == length
return out
def module_device(module):
if isinstance(module, nn.Identity):
return 'cpu' # It doesn't matter
return next(module.parameters()).device
def zero_init_(m):
@@ -146,7 +148,7 @@ def resize_image_to(
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
out = F.interpolate(image, target_image_size, mode = 'nearest')
if exists(clamp_range):
out = out.clamp(*clamp_range)
@@ -278,6 +280,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
@@ -316,7 +319,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -490,6 +496,9 @@ class NoiseScheduler(nn.Module):
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
def sample_random_times(self, batch):
return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -507,6 +516,17 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from))
alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -515,7 +535,7 @@ class NoiseScheduler(nn.Module):
def predict_noise_from_start(self, x_t, t, x0):
return (
(x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
@@ -900,7 +920,7 @@ class DiffusionPriorNetwork(nn.Module):
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1'),
rearrange(mask, 'b n -> b n 1').clone(),
text_encodings,
null_text_embeds
)
@@ -1239,7 +1259,7 @@ class DiffusionPrior(nn.Module):
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
times = self.noise_scheduler.sample_random_times(batch)
# scale image embed (Katherine)
@@ -1483,6 +1503,7 @@ class LinearAttention(nn.Module):
k = k.softmax(dim = -2)
q = q * self.scale
v = v / (x * y)
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
@@ -1518,6 +1539,38 @@ class CrossEmbedLayer(nn.Module):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class UpsampleCombiner(nn.Module):
def __init__(
self,
dim,
*,
enabled = False,
dim_ins = tuple(),
dim_outs = tuple()
):
super().__init__()
assert len(dim_ins) == len(dim_outs)
self.enabled = enabled
if not self.enabled:
self.dim_out = dim
return
self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
def forward(self, x, fmaps = None):
target_size = x.shape[-1]
fmaps = default(fmaps, tuple())
if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
return x
fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
return torch.cat((x, *outs), dim = 1)
class Unet(nn.Module):
def __init__(
self,
@@ -1535,9 +1588,10 @@ class Unet(nn.Module):
self_attn = False,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
sparse_attn = False,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
@@ -1546,6 +1600,7 @@ class Unet(nn.Module):
init_conv_kernel_size = 7,
resnet_groups = 8,
num_resnet_blocks = 2,
init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1553,6 +1608,7 @@ class Unet(nn.Module):
scale_skip_connection = False,
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
**kwargs
):
super().__init__()
@@ -1574,7 +1630,7 @@ class Unet(nn.Module):
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -1624,6 +1680,17 @@ class Unet(nn.Module):
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# low resolution noise conditiong, based on Imagen's upsampler training technique
self.lowres_noise_cond = lowres_noise_cond
self.to_lowres_noise_cond = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, time_cond_dim),
nn.GELU(),
nn.Linear(time_cond_dim, time_cond_dim)
) if lowres_noise_cond else None
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1677,7 +1744,8 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
skip_connect_dims = [] # keeping track of skip connection dimensions
skip_connect_dims = [] # keeping track of skip connection dimensions
upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
@@ -1719,6 +1787,8 @@ class Unet(nn.Module):
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
upsample_combiner_dims.append(dim_out)
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
@@ -1726,8 +1796,22 @@ class Unet(nn.Module):
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
# whether to combine outputs from all upsample blocks for final resnet block
self.upsample_combiner = UpsampleCombiner(
dim = dim,
enabled = combine_upsample_fmaps,
dim_ins = upsample_combiner_dims,
dim_outs = (dim,) * len(upsample_combiner_dims)
)
# a final resnet block
self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0)
self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
@@ -1737,15 +1821,17 @@ class Unet(nn.Module):
self,
*,
lowres_cond,
lowres_noise_cond,
channels,
channels_out,
cond_on_image_embeds,
cond_on_text_encodings
cond_on_text_encodings,
):
if lowres_cond == self.lowres_cond and \
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
lowres_noise_cond == self.lowres_noise_cond and \
channels_out == self.channels_out:
return self
@@ -1754,7 +1840,8 @@ class Unet(nn.Module):
channels = channels,
channels_out = channels_out,
cond_on_image_embeds = cond_on_image_embeds,
cond_on_text_encodings = cond_on_text_encodings
cond_on_text_encodings = cond_on_text_encodings,
lowres_noise_cond = lowres_noise_cond
)
return self.__class__(**{**self._locals, **updated_kwargs})
@@ -1780,6 +1867,7 @@ class Unet(nn.Module):
*,
image_embed,
lowres_cond_img = None,
lowres_noise_level = None,
text_encodings = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
@@ -1808,6 +1896,13 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# low res noise conditioning (similar to time above)
if exists(lowres_noise_level):
assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to conditiong on lowres noise'
lowres_noise_level = lowres_noise_level.type_as(x)
t = t + self.to_lowres_noise_cond(lowres_noise_level)
# conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
@@ -1906,7 +2001,8 @@ class Unet(nn.Module):
# go through the layers of the unet, down and up
hiddens = []
down_hiddens = []
up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
@@ -1916,10 +2012,10 @@ class Unet(nn.Module):
for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)
down_hiddens.append(x.contiguous())
x = attn(x)
hiddens.append(x)
down_hiddens.append(x.contiguous())
if exists(post_downsample):
x = post_downsample(x)
@@ -1931,7 +2027,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
@@ -1942,49 +2038,77 @@ class Unet(nn.Module):
x = resnet_block(x, t, c)
x = attn(x)
up_hiddens.append(x.contiguous())
x = upsample(x)
x = self.upsample_combiner(x, up_hiddens)
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
return self.to_out(x)
class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
downsample_mode_nearest = False,
use_blur = True,
blur_prob = 0.5,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
use_noise = False,
input_image_range = None,
normalize_img_fn = identity,
unnormalize_img_fn = identity
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.use_blur = use_blur
self.blur_prob = blur_prob
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
self.use_noise = use_noise
self.normalize_img = normalize_img_fn
self.unnormalize_img = unnormalize_img_fn
self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None
def noise_image(self, cond_fmap, noise_levels = None):
assert exists(self.noise_scheduler)
batch = cond_fmap.shape[0]
cond_fmap = self.normalize_img(cond_fmap)
random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))
cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))
cond_fmap = self.unnormalize_img(cond_fmap)
return cond_fmap, random_noise_levels
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size = None,
should_blur = True,
blur_sigma = None,
blur_kernel_size = None
):
if self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)
# blur is only applied 50% of the time
# section 3.1 in https://arxiv.org/abs/2106.15282
if random.random() < self.blur_prob:
if self.use_blur and should_blur and random.random() < self.blur_prob:
# when training, blur the low resolution conditional image
@@ -2006,8 +2130,21 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
return cond_fmap
# resize to target image size
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
# noise conditioning, as done in Imagen
# as a replacement for the BSR noising, and potentially replace blurring for first stage too
random_noise_levels = None
if self.use_noise:
cond_fmap, random_noise_levels = self.noise_image(cond_fmap)
# return conditioning feature map, as well as the augmentation noise levels
return cond_fmap, random_noise_levels
class Decoder(nn.Module):
def __init__(
@@ -2028,11 +2165,13 @@ class Decoder(nn.Module):
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning
use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning
clip_denoised = True,
clip_x_start = True,
clip_adapter_overrides = dict(),
@@ -2042,7 +2181,7 @@ class Decoder(nn.Module):
unconditional = False, # set to True for generating images without conditioning
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9,
dynamic_thres_percentile = 0.95,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
@@ -2080,10 +2219,17 @@ class Decoder(nn.Module):
self.channels = channels
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# verify conditioning method
unets = cast_tuple(unet)
num_unets = len(unets)
self.num_unets = num_unets
self.unconditional = unconditional
@@ -2099,12 +2245,28 @@ class Decoder(nn.Module):
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight
# default and validate conditioning parameters
use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)
use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)
if len(use_noise_for_lowres_cond) < num_unets:
use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)
if len(use_blur_for_lowres_cond) < num_unets:
use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)
assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'
assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'
assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))
# construct unets and vaes
self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -2116,6 +2278,7 @@ class Decoder(nn.Module):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
lowres_noise_cond = lowres_noise_cond,
cond_on_image_embeds = not unconditional and is_first,
cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
channels = unet_channels,
@@ -2158,13 +2321,14 @@ class Decoder(nn.Module):
image_sizes = default(image_sizes, (image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
# predict x0 config
@@ -2177,18 +2341,30 @@ class Decoder(nn.Module):
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest
self.lowres_conds = nn.ModuleList([])
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_prob = blur_prob,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range
)
for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):
if unet_index == 0:
self.lowres_conds.append(None)
continue
lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
use_blur = use_blur,
use_noise = use_noise,
blur_prob = blur_prob,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range,
normalize_img_fn = self.normalize_img,
unnormalize_img_fn = self.unnormalize_img
)
self.lowres_conds.append(lowres_cond)
self.lowres_noise_sample_level = lowres_noise_sample_level
# classifier free guidance
@@ -2206,11 +2382,6 @@ class Decoder(nn.Module):
self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# device tracker
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
@@ -2221,10 +2392,10 @@ class Decoder(nn.Module):
@property
def condition_on_text_encodings(self):
return any([unet.cond_on_text_encodings for unet in self.unets])
return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
assert 0 < unet_number <= self.num_unets
index = unet_number - 1
return self.unets[index]
@@ -2266,10 +2437,10 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
@@ -2301,44 +2472,111 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddpm(
self,
unet,
shape,
image_embed,
noise_scheduler,
predict_x_start = False,
learned_variance = False,
clip_denoised = True,
lowres_cond_img = None,
text_encodings = None,
cond_scale = 1,
is_latent_diffusion = False,
lowres_noise_level = None,
inpaint_image = None,
inpaint_mask = None,
inpaint_resample_times = 5
):
device = self.device
b = shape[0]
img = torch.randn(shape, device = device)
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
if is_inpaint:
inpaint_image = self.normalize_img(inpaint_image)
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
inpaint_mask = inpaint_mask.bool()
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
img = self.p_sample(
unet,
img,
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
is_last_timestep = time == 0
for r in reversed(range(0, resample_times)):
is_last_resample_step = r == 0
times = torch.full((b,), time, device = device, dtype = torch.long)
if is_inpaint:
# following the repaint paper
# https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
img = self.p_sample(
unet,
img,
times,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
if is_inpaint and not (is_last_timestep or is_last_resample_step):
# in repaint, you renoise and resample up to 10 times every step
img = noise_scheduler.q_sample_from_to(img, times - 1, times)
if is_inpaint:
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
@torch.no_grad()
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddim(
self,
unet,
shape,
image_embed,
noise_scheduler,
timesteps,
eta = 1.,
predict_x_start = False,
learned_variance = False,
clip_denoised = True,
lowres_cond_img = None,
text_encodings = None,
cond_scale = 1,
is_latent_diffusion = False,
lowres_noise_level = None,
inpaint_image = None,
inpaint_mask = None,
inpaint_resample_times = 5
):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2346,39 +2584,68 @@ class Decoder(nn.Module):
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
if is_inpaint:
inpaint_image = self.normalize_img(inpaint_image)
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
inpaint_mask = inpaint_mask.bool()
img = torch.randn(shape, device = device)
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
is_last_timestep = time_next == 0
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
for r in reversed(range(0, resample_times)):
is_last_resample_step = r == 0
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
alpha = alphas[time]
alpha_next = alphas[time_next]
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
if is_inpaint:
# following the repaint paper
# https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0.
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if is_inpaint and not (is_last_timestep or is_last_resample_step):
# in repaint, you renoise and resample up to 10 times every step
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
if exists(inpaint_image):
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
img = self.unnormalize_img(img)
return img
@@ -2396,7 +2663,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2415,6 +2682,7 @@ class Decoder(nn.Module):
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)
@@ -2471,13 +2739,18 @@ class Decoder(nn.Module):
@eval_decorator
def sample(
self,
image = None,
image_embed = None,
text = None,
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
start_at_unet_number = 1,
stop_at_unet_number = None,
distributed = False,
inpaint_image = None,
inpaint_mask = None,
inpaint_resample_times = 5
):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2491,19 +2764,40 @@ class Decoder(nn.Module):
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
img = None
if start_at_unet_number > 1:
# Then we are not generating the first image and one must have been passed in
assert exists(image), 'image must be passed in if starting at unet number > 1'
assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
img = resize_image_to(image, prev_unet_output_size, nearest = True)
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
num_unets = self.num_unets
cond_scale = cast_tuple(cond_scale, num_unets)
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
if unet_number < start_at_unet_number:
continue # It's the easiest way to do it
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
with context:
lowres_cond_img = None
# prepare low resolution conditioning for upsamplers
lowres_cond_img = lowres_noise_level = None
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest)
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
if lowres_cond.use_noise:
lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
# latent diffusion
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)
@@ -2511,19 +2805,25 @@ class Decoder(nn.Module):
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
# denoising loop for image
img = self.p_sample_loop(
unet,
shape,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
cond_scale = unet_cond_scale,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler,
timesteps = sample_timesteps
timesteps = sample_timesteps,
inpaint_image = inpaint_image,
inpaint_mask = inpaint_mask,
inpaint_resample_times = inpaint_resample_times
)
img = vae.decode(img)
@@ -2542,7 +2842,7 @@ class Decoder(nn.Module):
unet_number = None,
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
unet_index = unet_number - 1
@@ -2550,6 +2850,7 @@ class Decoder(nn.Module):
vae = self.vaes[unet_index]
noise_scheduler = self.noise_schedulers[unet_index]
lowres_conditioner = self.lowres_conds[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
@@ -2572,8 +2873,8 @@ class Decoder(nn.Module):
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)
image = resize_image_to(image, target_image_size, nearest = True)
if exists(random_crop_size):
aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
@@ -2590,7 +2891,7 @@ class Decoder(nn.Module):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
if not return_lowres_cond_image:
return losses
@@ -2637,7 +2938,7 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))

View File

@@ -67,6 +67,15 @@ class PriorEmbeddingDataset(IterableDataset):
def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
def set_start(self, start):
"""
Adjust the starting point within the reader, useful for resuming an epoch
"""
self.start = start
def get_start(self):
return self.start
def get_sample(self):
"""
pre-proocess data from either reader into a common format

View File

@@ -4,13 +4,15 @@ import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Optional, List, Union
from typing import Any, Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# constants
@@ -21,16 +23,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load file functions
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
def load_local_file(file_path, **kwargs):
return file_path
class BaseLogger:
"""
An abstract class representing an object that can log data.
@@ -234,7 +226,7 @@ class LocalLoader(BaseLoader):
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists():
if not self.file_path.exists() and not self.only_auto_resume:
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
@@ -283,9 +275,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_latest_to: Optional[Union[str, bool]] = None,
save_best_to: Optional[Union[str, bool]] = None,
save_meta_to: Optional[str] = None,
save_type: str = 'checkpoint',
**kwargs
):
@@ -295,10 +287,10 @@ class BaseSaver:
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -459,6 +451,11 @@ class Tracker:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -507,8 +504,15 @@ class Tracker:
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
if saver.saving_meta:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
self.save_metadata[state_dict_key] = metadata
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
@@ -518,24 +522,38 @@ class Tracker:
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
trainer.save(file_path, overwrite=True, **kwargs)
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
# Remove CLIP if it is part of the model
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
elif isinstance(trainer, DecoderTrainer):
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
original_clip = decoder.clip
decoder.clip = None
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
state_dict = decoder.state_dict()
model_state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
model_state_dict = decoder.state_dict()
decoder.clip = original_clip
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
torch.save(state_dict, file_path)
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):

View File

@@ -1,7 +1,7 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa
@@ -25,11 +25,9 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
InnerType = TypeVar('InnerType')
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
# general pydantic classes
@@ -145,6 +143,9 @@ class DiffusionPriorNetworkConfig(BaseModel):
normformer: bool = False
rotary_emb: bool = True
class Config:
extra = "allow"
def create(self):
kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs)
@@ -187,23 +188,26 @@ class DiffusionPriorTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_every: int = 10000 # what steps to save on
warmup_steps: int = None # number of warmup steps
save_every_seconds: int = 3600 # how often to save
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
best_validation_loss: float = 1e9 # the current best valudation loss observed
current_epoch: int = 0 # the current epoch
num_samples_seen: int = 0 # the current number of samples seen
random_seed: int = 0 # manual seed for torch
class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig
batch_size: int = 64
class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig # define train, validation, test splits for your dataset
batch_size: int # per-gpu batch size used to train the model
num_data_points: int = 25e7 # total number of datapoints to train on
eval_every_seconds: int = 3600 # validation statistics will be performed this often
class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig
@classmethod
@@ -216,30 +220,31 @@ class TrainDiffusionPriorConfig(BaseModel):
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
dim_mults: ListOrTuple[int]
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
self_attn: ListOrTuple[int]
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
unets: ListOrTuple[UnetConfig]
image_size: int = None
image_sizes: ListOrTuple(int) = None
image_sizes: ListOrTuple[int] = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
sample_timesteps: Optional[SingularOrIterable[int]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
beta_schedule: ListOrTuple[str] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
@@ -298,20 +303,22 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable(int)] = None
lr: SingularOrIterable[float] = 1e-4
wd: SingularOrIterable[float] = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
max_grad_norm: SingularOrIterable[float] = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
save_immediately: bool = False
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
@@ -320,12 +327,6 @@ class DecoderEvaluateConfig(BaseModel):
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
data: DecoderDataConfig

View File

@@ -174,26 +174,24 @@ class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
accelerator = None,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
device = None,
accelerator = None,
verbose = True,
warmup_steps = 1,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
# verbosity
self.verbose = verbose
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# assign some helpful member vars
@@ -202,23 +200,31 @@ class DiffusionPriorTrainer(nn.Module):
# setting the device
if not exists(accelerator) and not exists(device):
diffusion_prior_device = next(diffusion_prior.parameters()).device
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
self.device = accelerator.device
diffusion_prior.to(self.device)
# save model
self.diffusion_prior = diffusion_prior
# optimizer and mixed precision stuff
# mixed precision checks
self.amp = amp
if (
exists(self.accelerator)
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
and self.diffusion_prior.clip is not None
):
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
self.diffusion_prior.clip.to(precision_type)
self.scaler = GradScaler(enabled = amp)
# optimizer stuff
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
@@ -227,17 +233,21 @@ class DiffusionPriorTrainer(nn.Module):
**self.optim_kwargs,
**kwargs
)
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
# distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
@@ -247,67 +257,24 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0], device = self.device))
# accelerator wrappers
def print(self, msg):
if not self.verbose:
return
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
# ensure we sync gradients before continuing
self.wait_for_everyone()
# only save on the main process
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
if self.accelerator.is_main_process:
print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
step = self.step.item(),
step = self.step,
**kwargs
)
@@ -320,14 +287,14 @@ class DiffusionPriorTrainer(nn.Module):
torch.save(save_obj, str(path))
def load(self, path, overwrite_lr = True, strict = True):
def load(self, path_or_state, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
@@ -336,56 +303,56 @@ class DiffusionPriorTrainer(nn.Module):
"""
# all processes need to load checkpoint. no restriction here
path = Path(path)
assert path.exists()
if isinstance(path_or_state, str):
path = Path(path_or_state)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)
loaded_obj = torch.load(str(path), map_location=self.device)
elif isinstance(path_or_state, dict):
loaded_obj = path_or_state
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
# unwrap the model when loading from checkpoint
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.scaler.load_state_dict(loaded_obj['scaler'])
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
# set warmupstep
if exists(self.warmup_scheduler):
self.warmup_scheduler.last_step = self.step.item()
# ensure new lr is used if different from old one
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
with self.warmup_scheduler.dampening():
self.scheduler.step()
if self.use_ema:
self.ema_diffusion_prior.update()
@@ -414,7 +381,7 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -426,16 +393,14 @@ class DiffusionPriorTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.backprop(self.scaler.scale(loss))
self.accelerator.backward(loss)
return total_loss
@@ -498,23 +463,27 @@ class DecoderTrainer(nn.Module):
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
warmup_schedulers.append(None)
else:
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizers.append(optimizer)
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
schedulers.append(scheduler)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -590,7 +559,8 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
state_dict = optimizer.state_dict() if optimizer is not None else None
save_obj = {**save_obj, optimizer_key: state_dict}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -612,8 +582,8 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
warmup_scheduler = self.warmup_schedulers[ind]
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if optimizer is not None:
optimizer.load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
@@ -714,23 +684,32 @@ class DecoderTrainer(nn.Module):
*args,
unet_number = None,
max_batch_size = None,
return_lowres_cond_image=False,
**kwargs
):
unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0.
using_amp = self.accelerator.mixed_precision != 'no'
cond_images = []
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
# loss_obj may be a tuple with loss and cond_image
if return_lowres_cond_image:
loss, cond_image = loss_obj
else:
loss = loss_obj
cond_image = None
loss = loss * chunk_size_frac
if cond_image is not None:
cond_images.append(cond_image)
total_loss += loss.item()
if self.training:
self.accelerator.backward(loss)
return total_loss
if return_lowres_cond_image:
return total_loss, torch.stack(cond_images)
else:
return total_loss

View File

@@ -1 +1 @@
__version__ = '0.23.7'
__version__ = '1.2.2'

View File

@@ -1,5 +1,6 @@
from pathlib import Path
from typing import List
from datetime import timedelta
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -11,11 +12,12 @@ from clip import tokenize
import torchvision
import torch
from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
@@ -132,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -157,6 +159,13 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
sample_params["start_at_unet_number"] = start_unet
sample_params["stop_at_unet_number"] = end_unet
if start_unet > 1:
# If we are only training upsamplers
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
@@ -165,15 +174,15 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -183,7 +192,7 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -259,11 +268,13 @@ def train(
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
save_immediately=False,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
unet_training_mask=None,
condition_on_text_encodings=False,
cond_scale=1.0,
**kwargs
):
"""
@@ -271,6 +282,21 @@ def train(
"""
is_master = accelerator.process_index == 0
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * len(decoder.unets)
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]
last_trainable_unet = trainable_unet_numbers[-1]
def move_unets(unet_training_mask):
for i in range(len(decoder.unets)):
if not unet_training_mask[i]:
# Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
decoder.unets[i] = nn.Identity().to(inference_device)
# Remove non-trainable unets
move_unets(unet_training_mask)
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
@@ -285,6 +311,7 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -296,13 +323,6 @@ def train(
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
@@ -360,7 +380,7 @@ def train(
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -373,10 +393,10 @@ def train(
unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
log_data = {
"Epoch": epoch,
@@ -391,7 +411,7 @@ def train(
if is_master:
tracker.log(log_data, step=step())
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
@@ -399,7 +419,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -449,8 +469,9 @@ def train(
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -477,7 +498,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'
@@ -488,15 +509,15 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
@@ -513,6 +534,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
}
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):
@@ -521,7 +543,8 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect

View File

@@ -1,31 +1,23 @@
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
import click
import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from typing import List
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import DataLoader
from embedding_reader import EmbeddingReader
from accelerate.utils import dataclasses as accelerate_dataclasses
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch import DiffusionPriorTrainer
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.train_configs import (
DiffusionPriorConfig,
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
# helpers
@@ -38,8 +30,19 @@ def exists(val):
return val is not None
def all_between(values: list, lower_bound, upper_bound):
for value in values:
if value < lower_bound or value > upper_bound:
return False
return True
def make_model(
prior_config, train_config, device: str = None, accelerator: Accelerator = None
prior_config: DiffusionPriorConfig,
train_config: DiffusionPriorTrainConfig,
device: str = None,
accelerator: Accelerator = None,
):
# create model from config
diffusion_prior = prior_config.create()
@@ -54,71 +57,214 @@ def make_model(
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
warmup_steps=train_config.warmup_steps,
)
return trainer
def create_tracker(
accelerator: Accelerator,
config: TrainDiffusionPriorConfig,
config_path: str,
dummy: bool = False,
) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type
!= accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision,
}
tracker: Tracker = tracker_config.create(
config, accelerator_config, dummy_mode=dummy
)
tracker.save_config(config_path, config_name="prior_config.json")
return tracker
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
"""
pad a value or tensor across all processes and gather
params:
- trainer: a trainer that carries an accelerator object
- x: a number or torch tensor to reduce
- method: "mean", "sum", "max", "min"
return:
- the average tensor after maskin out 0's
- None if the gather resulted in an empty tensor
"""
assert method in [
"mean",
"sum",
"max",
"min",
], "This function has limited capabilities [sum, mean, max, min]"
assert type(x) is not None, "Cannot reduce a None type object"
# wait for everyone to arrive here before gathering
if type(x) is not torch.Tensor:
x = torch.tensor([x])
# verify that the tensor is on the proper device
x = x.to(trainer.device)
# pad across processes
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
# gather across all procesess
gathered_x = trainer.accelerator.gather(padded_x)
# mask out zeros
masked_x = gathered_x[gathered_x != 0]
# if the tensor is empty, warn and return None
if len(masked_x) == 0:
click.secho(
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
fg="red",
)
return None
if method == "mean":
return torch.mean(masked_x)
elif method == "sum":
return torch.sum(masked_x)
elif method == "max":
return torch.max(masked_x)
elif method == "min":
return torch.min(masked_x)
def save_trainer(
tracker: Tracker,
trainer: DiffusionPriorTrainer,
is_latest: bool,
is_best: bool,
epoch: int,
samples_seen: int,
best_validation_loss: float,
):
"""
Logs the model with an appropriate method depending on the tracker
"""
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
click.secho(
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
fg="magenta",
)
tracker.save(
trainer=trainer,
is_best=is_best,
is_latest=is_latest,
epoch=int(epoch),
samples_seen=int(samples_seen),
best_validation_loss=best_validation_loss,
)
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
if trainer.accelerator.is_main_process:
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
state_dict = tracker.recall()
trainer.load(state_dict, strict=True)
return (
int(state_dict.get("epoch", 0)),
state_dict.get("best_validation_loss", 0),
int(state_dict.get("samples_seen", 0)),
)
# eval functions
def eval_model(
def report_validation_loss(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
use_ema: bool,
tracker: Tracker,
split: str,
tracker_folder: str,
loss_type: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
"""
Compute the validation loss on a given subset of data.
"""
with torch.no_grad():
total_loss = 0.0
total_samples = 0.0
if trainer.accelerator.is_main_process:
click.secho(
f"Measuring performance on {use_ema}-{split} split",
fg="green",
blink=True,
)
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
batches = image_embeddings.shape[0]
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
input_args = dict(image_embed=image_embeddings)
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
if text_conditioned:
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
total_loss += loss * batches
total_samples += batches
total_loss += loss
avg_loss = total_loss / total_samples
# compute the average loss across all processes
stats = {f"{tracker_context}-{loss_type}": avg_loss}
trainer.print(stats)
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
# print and log results on main process
tracker.log(stats, step=trainer.step.item() + 1)
return avg_loss
def report_cosine_sims(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
tracker: BaseTracker,
tracker_context: str = "validation",
tracker: Tracker,
split: str,
timesteps: int,
tracker_folder: str,
):
trainer.eval()
if trainer.is_main_process():
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
if trainer.accelerator.is_main_process:
click.secho(
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
fg="green",
blink=True,
)
for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device)
@@ -127,9 +273,7 @@ def report_cosine_sims(
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings
)
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
@@ -150,8 +294,7 @@ def report_cosine_sims(
text_encodings_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
)
# prepare the text embedding
@@ -164,7 +307,9 @@ def report_cosine_sims(
# predict on the unshuffled text embeddings
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond
test_image_embeddings.shape,
text_cond,
timesteps=timesteps,
)
predicted_image_embeddings = (
@@ -174,7 +319,9 @@ def report_cosine_sims(
# predict on the shuffled embeddings
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled
test_image_embeddings.shape,
text_cond_shuffled,
timesteps=timesteps,
)
predicted_unrelated_embeddings = (
@@ -183,32 +330,97 @@ def report_cosine_sims(
)
# calculate similarities
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
orig_sim = pad_gather_reduce(
trainer, cos(text_embed, test_image_embeddings), method="mean"
)
predicted_img_similarity = (
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
pred_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_image_embeddings), method="mean"
)
unrel_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
)
pred_img_sim = pad_gather_reduce(
trainer,
cos(test_image_embeddings, predicted_image_embeddings),
method="mean",
)
stats = {
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_context}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
- orig_sim,
}
for k, v in stats.items():
trainer.print(f"{tracker_context}/{k}: {v}")
tracker.log(stats, step=trainer.step.item() + 1)
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
split: str,
tracker: Tracker,
use_ema: bool,
report_cosine: bool,
report_loss: bool,
timesteps: List[int],
loss_type: str = None,
):
"""
Run evaluation on a model and track metrics
returns: loss if requested
"""
trainer.eval()
use_ema = "ema" if use_ema else "online"
tracker_folder = f"metrics/{use_ema}-{split}"
# detemine if valid timesteps are passed
min_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).sample_timesteps
max_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).noise_scheduler.num_timesteps
assert all_between(
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
# measure cosine metrics across various eta and timesteps
if report_cosine:
for timestep in timesteps:
report_cosine_sims(
trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
tracker=tracker,
split=split,
timesteps=timestep,
tracker_folder=tracker_folder,
)
# measure loss on a seperate split of data
if report_loss:
loss = report_validation_loss(
trainer=trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
use_ema=use_ema,
tracker=tracker,
split=split,
tracker_folder=tracker_folder,
loss_type=loss_type,
)
return loss
# training script
@@ -216,182 +428,327 @@ def report_cosine_sims(
def train(
trainer: DiffusionPriorTrainer,
tracker: Tracker,
train_loader: DataLoader,
eval_loader: DataLoader,
test_loader: DataLoader,
config: DiffusionPriorTrainConfig,
):
# distributed tracking with wandb
if trainer.accelerator.num_processes > 1:
os.environ["WANDB_START_METHOD"] = "thread"
# init timers
save_timer = Timer() # when to save
samples_timer = Timer() # samples/sec
validation_profiler = Timer() # how long is validation taking
validation_countdown = Timer() # when to perform evalutation
tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
# keep track of best validation loss
# sync after tracker init
trainer.wait_for_everyone()
# init a timer
timer = Timer()
best_validation_loss = config.train.best_validation_loss
samples_seen = config.train.num_samples_seen
# do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
# place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
start_epoch = config.train.current_epoch
# pass to model
loss = trainer(text=txt, image_embed=img)
for epoch in range(start_epoch, config.train.epochs):
# if we finished out an old epoch, reset the distribution to be a full epoch
tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
# display & log loss (will only print from main process)
trainer.print(f"Step {current_step}: Loss {loss}")
if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
if trainer.accelerator.is_main_process:
click.secho(f"Finished resumed epoch...resetting dataloader.")
train_loader.dataset.set_start(0)
# perform backprop & apply EMA updates
trainer.update()
for img, txt in train_loader:
# setup things every step
# track samples/sec/rank
samples_per_sec = img.shape[0] / timer.elapsed()
trainer.train()
current_step = trainer.step.item()
samples_timer.reset()
# samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
# place data on device
# ema decay
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
img = img.to(trainer.device)
txt = txt.to(trainer.device)
# Log on all processes for debugging
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
# pass to model
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
tracker=tracker,
use_ema=False,
loss = trainer(text=txt, image_embed=img)
# perform backprop & apply EMA updates
trainer.update()
# gather info about training step
all_loss = pad_gather_reduce(trainer, loss, method="mean")
num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
samples_per_sec = num_samples / samples_timer.elapsed()
samples_seen += num_samples
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# log
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
f"tracking/training-{config.prior.loss_type}": all_loss,
},
step=current_step,
)
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
# Metric Tracking @ Timed Intervals
eval_delta = pad_gather_reduce(
trainer, validation_countdown.elapsed(), method="min"
)
report_cosine_sims(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
tracker_context="metrics",
)
if eval_delta != None and eval_delta > config.data.eval_every_seconds:
# begin timing how long this takes
if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
validation_profiler.reset()
# reset timer for next round
timer.reset()
# package kwargs for evaluation
eval_kwargs = {
"trainer": trainer,
"tracker": tracker,
"text_conditioned": config.prior.condition_on_text_encodings,
"timesteps": config.train.eval_timesteps,
}
# ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
eval_model(
dataloader=eval_loader,
loss_type=config.prior.loss_type,
split="validation",
use_ema=False,
report_cosine=False,
report_loss=True,
**eval_kwargs,
)
# EMA MODEL : COSINE : LOSS : VALIDATION DATA
ema_val_loss = eval_model(
dataloader=eval_loader,
loss_type=config.prior.loss_type,
split="validation",
use_ema=True,
report_cosine=True,
report_loss=True,
**eval_kwargs,
)
tracker.log(
{
"tracking/validation length (minutes)": validation_profiler.elapsed()
/ 60
}
)
# check if the ema validation is the lowest seen yet
if ema_val_loss < best_validation_loss:
best_validation_loss = ema_val_loss
# go save the model as best
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
# reset timer for validaiton
validation_countdown.reset()
elif eval_delta is None:
click.secho(
f"Error occured reading the eval time on rank: {trainer.device}",
fg="yellow",
)
# save as latest model on schedule
save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
if save_delta != None and save_delta >= config.train.save_every_seconds:
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
save_timer.reset()
elif save_delta is None:
click.secho(
f"Error occured reading the save time on rank: {trainer.device}",
fg="yellow",
)
# evaluate on test data
eval_model(
if trainer.accelerator.is_main_process:
click.secho(f"Starting Test", fg="red")
# save one last time as latest before beginning validation
save_trainer(
tracker=tracker,
trainer=trainer,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
test_loss = eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="test",
split="test",
tracker=tracker,
use_ema=True,
report_cosine=False,
report_loss=True,
timesteps=config.train.eval_timesteps,
loss_type=config.prior.loss_type,
)
report_cosine_sims(
trainer,
test_loader,
config.prior.condition_on_text_encodings,
tracker,
tracker_context="test",
)
if test_loss < best_validation_loss:
best_validation_loss = test_loss
# go save the model as best
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=test_loss,
)
def initialize_training(config, accelerator=None):
def initialize_training(config_file, accelerator):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# load the configuration file
if accelerator.is_main_process:
click.secho(f"Loading configuration from {config_file}", fg="green")
config = TrainDiffusionPriorConfig.from_json_path(config_file)
# seed
set_seed(config.train.random_seed)
# get a device
if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
device = accelerator.device
# make the trainer (will automatically distribute if possible & configured)
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
trainer: DiffusionPriorTrainer = make_model(
config.prior, config.train, device, accelerator
).to(device)
# create a tracker
tracker = create_tracker(
accelerator, config, config_file, dummy=accelerator.process_index != 0
)
# reload from chcekpoint
if config.load.resume == True:
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
trainer.load(config.load.source)
if tracker.can_recall:
current_epoch, best_validation_loss, samples_seen = recall_trainer(
tracker=tracker, trainer=trainer
)
# display best values
if trainer.accelerator.is_main_process:
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
# update config to reflect recalled values
config.train.num_samples_seen = samples_seen
config.train.current_epoch = current_epoch
config.train.best_validation_loss = best_validation_loss
# fetch and prepare data
if trainer.is_main_process():
click.secho("Grabbing data from source", fg="blue", blink=True)
if trainer.accelerator.is_main_process:
click.secho("Grabbing data...", fg="blue", blink=True)
trainer.accelerator.wait_for_everyone()
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
# calculate start point within epoch
trainer.accelerator.wait_for_everyone()
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=NUM_DATA_POINTS,
num_data_points=config.data.num_data_points,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=START,
rank=accelerator.state.process_index,
world_size=accelerator.state.num_processes,
start=0,
)
# wait for everyone to load data before continuing
trainer.wait_for_everyone()
# update the start point to finish out the epoch on a resumed run
if tracker.can_recall:
samples_seen = config.train.num_samples_seen
length = (
config.data.num_data_points
if samples_seen <= img_reader.count
else img_reader.count
)
scaled_samples = length * config.train.current_epoch
start_point = (
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
)
if trainer.accelerator.is_main_process:
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
train_loader.dataset.set_start(start_point)
# start training
if trainer.accelerator.is_main_process:
click.secho(
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
fg="yellow",
)
train(
trainer=trainer,
tracker=tracker,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
@@ -400,23 +757,13 @@ def initialize_training(config, accelerator=None):
@click.command()
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
# start HFA if requested
if hfa:
accelerator = Accelerator()
else:
accelerator = None
@click.option("--config_file", default="configs/train_prior_config.example.json")
def main(config_file):
# start HFA
accelerator = Accelerator()
# load the configuration file on main process
if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
config = TrainDiffusionPriorConfig.from_json_path(config_path)
# send config to get processed
initialize_training(config, accelerator)
# setup training
initialize_training(config_file, accelerator)
if __name__ == "__main__":