Compare commits

...

24 Commits

Author SHA1 Message Date
Phil Wang
e37072a48c 0.10.0 2022-06-19 08:50:53 -07:00
Phil Wang
41ca896413 depend on huggingface accelerate, move appreciation thread up for visibility 2022-06-19 08:50:35 -07:00
zion
fe19b508ca Distributed Training of the Prior (#112)
* distributed prior trainer

better EMA support

update load and save methods of prior

* update prior training script

add test evalution & ema validation

add more tracking metrics

small cleanup
2022-06-19 08:46:14 -07:00
Phil Wang
6651eafa93 one more residual, after seeing good results on unconditional generation locally 2022-06-16 11:18:02 -07:00
Phil Wang
e6bb75e5ab fix missing residual for highest resolution of the unet 2022-06-15 20:09:43 -07:00
Giorgos Zachariadis
b4c3e5b854 changed str in order to avoid confusions and collisions with Python (#147) 2022-06-15 13:41:16 -07:00
Phil Wang
b7f9607258 make memory efficient unet design from imagen toggle-able 2022-06-15 13:40:26 -07:00
Phil Wang
2219348a6e adopt similar unet architecture as imagen 2022-06-15 12:18:21 -07:00
Phil Wang
9eea9b9862 add p2 loss reweighting for decoder training as an option 2022-06-14 10:58:57 -07:00
Phil Wang
5d958713c0 fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug 2022-06-13 21:01:50 -07:00
Phil Wang
0f31980362 cleanup 2022-06-07 17:31:38 -07:00
Phil Wang
bee5bf3815 fix for https://github.com/lucidrains/DALLE2-pytorch/issues/143 2022-06-07 09:03:48 -07:00
Phil Wang
350a3d6045 0.6.16 2022-06-06 08:45:46 -07:00
Kashif Rasul
1a81670718 fix quadratic_beta_schedule (#141) 2022-06-06 08:45:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
Phil Wang
ce4b0107c1 0.6.13 2022-06-04 13:26:57 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
Phil Wang
22cc613278 ema fix from @nousr 2022-06-03 19:44:36 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
708809ed6c lower beta2 for adam down to 0.99, based on https://openreview.net/forum?id=2LdBqxc1Yv 2022-06-03 10:26:28 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
Phil Wang
ffd342e9d0 allow for an option to constrain the variance interpolation fraction coming out from the unet for learned variance, if it is turned on 2022-06-03 09:34:57 -07:00
Phil Wang
f8bfd3493a make destructuring datum length agnostic when validating in training decoder script, for @YUHANG-Ma 2022-06-02 13:54:57 -07:00
Phil Wang
9025345e29 take a stab at fixing generate_grid_samples when real images have a greater image size than generated 2022-06-02 11:33:15 -07:00
9 changed files with 684 additions and 423 deletions

View File

@@ -31,6 +31,21 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- DALL-E 2 🚧
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
... and many others. Thank you! 🙏
## Install
```bash
@@ -1041,19 +1056,6 @@ Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
... and many others. Thank you! 🙏
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1207,4 +1209,14 @@ This library would not have gotten to this working state without the help of
}
```
```bibtex
@article{Choi2022PerceptionPT,
title = {Perception Prioritized Training of Diffusion Models},
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.00227}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,7 +1,6 @@
import math
import random
from tqdm import tqdm
from inspect import isfunction
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
@@ -12,7 +11,7 @@ import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T
from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
@@ -57,7 +56,7 @@ def maybe(fn):
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
return d() if callable(d) else d
def cast_tuple(val, length = 1):
if isinstance(val, list):
@@ -314,11 +313,6 @@ def extract(a, t, x_shape):
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape))))
@@ -373,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps):
@@ -385,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
class BaseGaussianDiffusion(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type):
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
super().__init__()
if beta_schedule == "cosine":
@@ -450,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# p2 loss reweighting
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -946,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -1085,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
@@ -1352,6 +1352,7 @@ class Unet(nn.Module):
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
**kwargs
):
super().__init__()
@@ -1371,7 +1372,7 @@ class Unet(nn.Module):
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 3 * 2)
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
@@ -1428,6 +1429,7 @@ class Unet(nn.Module):
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
@@ -1461,10 +1463,11 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
downsample_klass(dim_out) if not is_last else nn.Identity()
downsample_klass(dim_out) if not is_last and not memory_efficient else None
]))
mid_dim = dims[-1]
@@ -1473,19 +1476,19 @@ class Unet(nn.Module):
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (num_resolutions - 2)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in)
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1557,6 +1560,7 @@ class Unet(nn.Module):
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
@@ -1565,19 +1569,28 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
# conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
image_hiddens = torch.where(
image_keep_mask_hidden,
image_hiddens,
null_image_hiddens
)
t = t + image_hiddens
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1585,11 +1598,12 @@ class Unet(nn.Module):
image_tokens = None
if self.cond_on_image_embeds:
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where(
image_keep_mask,
image_keep_mask_embed,
image_tokens,
null_image_embed
)
@@ -1644,7 +1658,10 @@ class Unet(nn.Module):
hiddens = []
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, c, t)
x = sparse_attn(x)
@@ -1652,7 +1669,9 @@ class Unet(nn.Module):
x = resnet_block(x, c, t)
hiddens.append(x)
x = downsample(x)
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, mid_c, t)
@@ -1662,7 +1681,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = torch.cat((x, hiddens.pop()), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
@@ -1671,6 +1690,7 @@ class Unet(nn.Module):
x = upsample(x)
x = torch.cat((x, r), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):
@@ -1745,16 +1765,21 @@ class Decoder(BaseGaussianDiffusion):
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9
dynamic_thres_percentile = 0.9,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
):
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
self.unconditional = unconditional
@@ -1805,6 +1830,7 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
@@ -1945,16 +1971,19 @@ class Decoder(BaseGaussianDiffusion):
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -2018,7 +2047,13 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target)
loss = self.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
if self.has_p2_loss_reweighting:
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
loss = loss.mean()
if not learned_variance:
# return simple loss if not using learned variance

View File

@@ -11,7 +11,7 @@ def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.999),
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,

View File

@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
from accelerate import Accelerator
import numpy as np
# helper functions
@@ -58,8 +60,15 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr
def get_pkg_version():
return __version__
def clamp(value, min_value = None, max_value = None):
assert exists(min_value) or exists(max_value)
if exists(min_value):
value = max(value, min_value)
if exists(max_value):
value = min(value, max_value)
return value
# decorators
@@ -175,12 +184,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
# exponential moving average wrapper
class EMA(nn.Module):
"""
Implements exponential moving average shadowing for your model.
Utilizes an inverse decay schedule to manage longer term training runs.
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
def __init__(
self,
model,
beta = 0.99,
update_after_step = 1000,
beta = 0.9999,
update_after_step = 100,
update_every = 10,
inv_gamma = 1.0,
power = 2/3,
min_value = 0.0,
):
super().__init__()
self.beta = beta
@@ -188,7 +219,11 @@ class EMA(nn.Module):
self.ema_model = copy.deepcopy(model)
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
@@ -198,41 +233,56 @@ class EMA(nn.Module):
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data)
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
ma_buffer.data.copy_(current_buffer.data)
def get_current_decay(self):
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
if epoch <= 0:
return 0.
return clamp(value, min_value = self.min_value, max_value = self.beta)
def update(self):
step = self.step.item()
self.step += 1
if (self.step % self.update_every) != 0:
if (step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
if step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted:
if not self.initted.item():
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@torch.no_grad()
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
current_decay = self.get_current_decay()
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
difference = ma_params.data - current_params.data
difference.mul_(1.0 - current_decay)
ma_params.sub_(difference)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
difference = ma_buffer - current_buffer
difference.mul_(1.0 - current_decay)
ma_buffer.sub_(difference)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@@ -256,88 +306,189 @@ class DiffusionPriorTrainer(nn.Module):
max_grad_norm = None,
amp = False,
group_wd_params = True,
device = None,
accelerator = None,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# assign some helpful member vars
self.accelerator = accelerator
self.device = accelerator.device if exists(accelerator) else device
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# save model
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
group_wd_params = group_wd_params,
self.diffusion_prior.parameters(),
**self.optim_kwargs,
**kwargs
)
# distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
# track steps internally
self.register_buffer('step', torch.tensor([0]))
# accelerator wrappers
def print(self, msg):
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# ensure we sync gradients before continuing
self.wait_for_everyone()
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = __version__,
step = self.step.item(),
**kwargs
)
# only save on the main process
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
version = version.parse(__version__),
step = self.step.item(),
**kwargs
)
torch.save(save_obj, str(path))
if self.use_ema:
save_obj = {
**save_obj,
'ema': self.ema_diffusion_prior.state_dict(),
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
}
def load(self, path, only_model = False, strict = True):
torch.save(save_obj, str(path))
def load(self, path, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
Returns:
loaded_obj (dict): The loaded checkpoint dictionary
"""
# all processes need to load checkpoint. no restriction here
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
loaded_obj = torch.load(str(path), map_location=self.device)
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
# unwrap the model when loading from checkpoint
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -352,17 +503,32 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
if self.use_ema:
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
else:
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
if self.use_ema:
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
else:
return self.diffusion_prior.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
if self.use_ema:
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
else:
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -380,8 +546,10 @@ class DiffusionPriorTrainer(nn.Module):
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.scaler.scale(loss).backward()
self.backprop(self.scaler.scale(loss))
return total_loss
@@ -488,7 +656,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path))
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

View File

@@ -1 +1 @@
__version__ = '0.6.8'
__version__ = '0.10.0'

View File

@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def string_begins_with(prefix, string_input):
return string_input.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)

View File

@@ -24,6 +24,7 @@ setup(
'text to image'
],
install_requires=[
'accelerate',
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',

View File

@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
import torchvision
import torch
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -202,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
Loads the model with an appropriate method depending on the tracker
"""
print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config)
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
trainer.load_state_dict(state_dict["trainer"])
print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
@@ -322,7 +331,7 @@ def train(
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))

View File

@@ -1,77 +1,136 @@
from pathlib import Path
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
import click
import math
import numpy as np
import wandb
import torch
import clip
from torch import nn
from torch.nn.functional import normalize
from torch.utils.data import DataLoader
from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
import numpy as np
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon
from accelerate import Accelerator
from tqdm import tqdm
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import (
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
# constants
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
# helpers
tracker = WandbTracker()
# helpers functions
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
def exists(val):
val is not None
return val is not None
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
model.eval()
def make_model(
prior_config, train_config, device: str = None, accelerator: Accelerator = None
):
# create model from config
diffusion_prior = prior_config.create()
# instantiate the trainer
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=train_config.lr,
wd=train_config.wd,
max_grad_norm=train_config.max_grad_norm,
amp=train_config.amp,
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
)
return trainer
# eval functions
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
loss_type: str,
phase: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
with torch.no_grad():
total_loss = 0.
total_samples = 0.
total_loss = 0.0
total_samples = 0.0
for image_embeddings, text_data in tqdm(dataloader):
image_embeddings = image_embeddings.to(device)
text_data = text_data.to(device)
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
batches = image_embeddings.shape[0]
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text = text_data)
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
loss = model(**input_args)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
total_loss += loss * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
avg_loss = total_loss / total_samples
tracker.log({f'{phase} {loss_type}': avg_loss})
stats = {f"{phase}/{loss_type}": avg_loss}
trainer.print(stats)
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
diffusion_prior.eval()
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader):
test_image_embeddings = test_image_embeddings.to(device)
text_data = text_data.to(device)
def report_cosine_sims(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
tracker: BaseTracker,
phase: str = "validation",
):
trainer.eval()
if trainer.is_main_process():
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
text_data)
text_cond = dict(text_embed=text_embedding,
text_encodings=text_encodings, mask=text_mask)
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
@@ -82,8 +141,7 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
@@ -92,294 +150,272 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
# prepare the text embedding
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
text_embed = normalize(text_embedding / text_embedding)
# prepare image embeddings
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond
)
predicted_image_embeddings = predicted_image_embeddings / normalize(
predicted_image_embeddings
)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled
)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
predicted_unrelated_embeddings
)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
)
predicted_img_similarity = (
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
)
stats = {
f"{phase}/baseline similarity": np.mean(original_similarity),
f"{phase}/similarity with text": np.mean(predicted_similarity),
f"{phase}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{phase}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
}
for k, v in stats.items():
trainer.print(f"{phase}/{k}: {v}")
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
# training script
def train(
trainer: DiffusionPriorTrainer,
train_loader: DataLoader,
eval_loader: DataLoader,
test_loader: DataLoader,
config: DiffusionPriorTrainConfig,
):
# distributed tracking with wandb
if trainer.accelerator.num_processes > 1:
os.environ["WANDB_START_METHOD"] = "thread"
tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
# sync after tracker init
trainer.wait_for_everyone()
# init a timer
timer = Timer()
# do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
# place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
# pass to model
loss = trainer(text=txt, image_embed=img)
# display & log loss (will only print from main process)
trainer.print(f"Step {current_step}: Loss {loss}")
# perform backprop & apply EMA updates
trainer.update()
# track samples/sec/rank
samples_per_sec = img.shape[0] / timer.elapsed()
# samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
# ema decay
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# Log on all processes for debugging
tracker.log(
{
"training/loss": loss,
"samples/sec/rank": samples_per_sec,
"samples/seen": samples_seen,
"ema/decay": ema_decay,
},
step=current_step,
)
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer,
eval_loader,
config.prior.condition_on_text_encodings,
config.prior.loss_type,
"training/validation",
tracker,
use_ema=False,
)
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss=config.prior.loss_type,
phase="ema/validation",
tracker=tracker,
use_ema=True,
)
report_cosine_sims(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
phase="ema/validation",
)
if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
# reset timer for next round
timer.reset()
# evaluate on test data
eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
phase="test",
tracker=tracker,
)
report_cosine_sims(
trainer,
test_loader,
config.prior.condition_on_text_encodings,
tracker,
phase="test",
)
def initialize_training(config, accelerator=None):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# get a device
if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
# make the trainer (will automatically distribute if possible & configured)
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
# reload from chcekpoint
if config.load.resume == True:
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
trainer.load(config.load.source)
# fetch and prepare data
if trainer.is_main_process():
click.secho("Grabbing data from source", fg="blue", blink=True)
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=NUM_DATA_POINTS,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=START,
)
# wait for everyone to load data before continuing
trainer.wait_for_everyone()
# start training
train(
trainer=trainer,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
config=config,
)
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path,
gpu_device
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
# start HFA if requested
if hfa:
accelerator = Accelerator()
else:
clip_adapter = None
accelerator = None
# diffusion prior with text embeddings and image embeddings pre-computed
# load the configuration file on main process
if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
config = TrainDiffusionPriorConfig.from_json_path(config_path)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
Path(save_path).mkdir(exist_ok = True, parents = True)
# Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings")
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
if dp_condition_on_text_encodings:
reader_args = dict(**reader_args, meta_url=meta_url)
img_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader
)
else:
reader_args = dict(**reader_args, txt_url=text_embed_url)
img_reader, txt_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader,
text_reader=txt_reader
)
### Training code ###
step = 1
timer = Timer()
epochs = num_epochs
for _ in range(epochs):
for image, text in tqdm(train_loader):
diffusion_prior.train()
image = image.to(device)
text = text.to(device)
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
else:
input_args = dict(**input_args, text_embed=text)
loss = trainer(**input_args)
# Samples per second
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if(int(timer.elapsed()) >= 60 * save_interval):
timer.reset()
save_diffusion_model(
save_path,
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
step += 1
trainer.update()
### Test run ###
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
# send config to get processed
initialize_training(config, accelerator)
if __name__ == "__main__":
train()
main()