Compare commits

...

13 Commits

Author SHA1 Message Date
Phil Wang
138079ca83 allow for setting beta schedules of unets differently in the decoder, as what was used in the paper was cosine, cosine, linear 2022-06-20 08:56:37 -07:00
zion
f5a906f5d3 prior train script bug fixes (#153) 2022-06-19 15:55:15 -07:00
Phil Wang
0215237fc6 update status 2022-06-19 09:42:24 -07:00
Phil Wang
461b91c5c1 also merge distributed training code for decoder, thanks to @Veldrovive 2022-06-19 09:26:44 -07:00
Aidan Dempster
58892135d9 Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate

* Fixed issue where metric evaluation would hang on distributed mode

* Implemented functional saving
Loading still fails due to some issue with the optimizer

* Fixed issue with loading decoders

* Fixed issue with tracker config

* Fixed issue with amp
Updated logging to be more logical

* Saving checkpoint now saves position in training as well
Fixed an issue with running out of gpu space due to loading weights into the gpu twice

* Fixed ema for distributed training

* Fixed issue where get_pkg_version was reintroduced

* Changed decoder trainer to upload config as a file

Fixed issue where loading best would error
2022-06-19 09:25:54 -07:00
Phil Wang
e37072a48c 0.10.0 2022-06-19 08:50:53 -07:00
Phil Wang
41ca896413 depend on huggingface accelerate, move appreciation thread up for visibility 2022-06-19 08:50:35 -07:00
zion
fe19b508ca Distributed Training of the Prior (#112)
* distributed prior trainer

better EMA support

update load and save methods of prior

* update prior training script

add test evaluation & ema validation

add more tracking metrics

small cleanup
2022-06-19 08:46:14 -07:00
Phil Wang
6651eafa93 one more residual, after seeing good results on unconditional generation locally 2022-06-16 11:18:02 -07:00
Phil Wang
e6bb75e5ab fix missing residual for highest resolution of the unet 2022-06-15 20:09:43 -07:00
Giorgos Zachariadis
b4c3e5b854 changed str in order to avoid confusions and collisions with Python (#147) 2022-06-15 13:41:16 -07:00
Phil Wang
b7f9607258 make memory efficient unet design from imagen toggle-able 2022-06-15 13:40:26 -07:00
Phil Wang
2219348a6e adopt similar unet architecture as imagen 2022-06-15 12:18:21 -07:00
12 changed files with 968 additions and 635 deletions

View File

@@ -27,10 +27,27 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> for the distributed training code for the diffusion prior
- <a href="https://github.com/Veldrovive">Aidan</a> for the distributed training code for the decoder as well as the dataloaders
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
... and many others. Thank you! 🙏
## Install
```bash
@@ -1041,19 +1058,6 @@ Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Appreciation
This library would not have gotten to this working state without the help of
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
... and many others. Thank you! 🙏
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon

View File

@@ -378,7 +378,7 @@ def sigmoid_beta_schedule(timesteps):
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
class BaseGaussianDiffusion(nn.Module):
class NoiseScheduler(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
super().__init__()
@@ -472,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def sample(self, *args, **kwargs):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)
# diffusion prior
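For reference, the `p2_loss_weight` buffer applied above should, assuming it follows the cited paper (https://arxiv.org/abs/2204.00227), be a per-timestep weight of roughly (k + SNR(t))^(-gamma) with SNR(t) = ᾱ_t / (1 - ᾱ_t), so gamma = 0 reduces to uniform weighting. A minimal sketch under that assumption:

```python
import torch

# minimal sketch of a p2-style loss weight (https://arxiv.org/abs/2204.00227); the exact
# construction inside NoiseScheduler is assumed here, not copied from the diff
def p2_loss_weight(alphas_cumprod, gamma = 1., k = 1.):
    snr = alphas_cumprod / (1. - alphas_cumprod)  # signal-to-noise ratio per timestep
    return (k + snr) ** -gamma                    # gamma = 0 gives an all-ones weight

betas = torch.linspace(1e-4, 2e-2, 1000)          # toy linear schedule
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
weights = p2_loss_weight(alphas_cumprod)          # small at low-noise timesteps, close to 1 at high-noise ones
```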
@@ -687,8 +686,7 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
# aggregate values
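The `dtype = torch.float32` argument above makes the softmax accumulate in full precision even when the surrounding matmuls run in half precision under amp, presumably to guard against overflow and precision loss in the attention weights. A standalone illustration (variable names here are illustrative only):

```python
import torch

# attention logits as they might come out of a half-precision matmul under amp
sim = (torch.randn(2, 8, 64, 64) * 20).half()

# upcast to fp32 for the normalization, then cast back to the value dtype if desired
attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)
print(attn.dtype)  # torch.float16
```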
@@ -862,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
return pred_image_embed
class DiffusionPrior(BaseGaussianDiffusion):
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
@@ -883,7 +881,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
super().__init__()
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
@@ -923,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
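The `_dummy` buffer registered above is a small device-tracking trick: buffers follow `.to()` / `.cuda()` moves with the module, and `persistent = False` keeps the dummy out of the state dict. A self-contained sketch of the same pattern:

```python
import torch
from torch import nn

class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent buffer: moves with the module, but is not saved in state_dict()
        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

    @property
    def device(self):
        return self._dummy.device

module = DeviceAware()
print(module.device)                     # cpu
# module.cuda(); print(module.device)    # cuda:0, if a GPU is available
```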
@@ -933,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
@@ -941,7 +948,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
@@ -955,7 +962,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
@@ -963,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
@@ -972,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
pred = self.net(
image_embed_noisy,
@@ -986,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -997,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img
@@ -1069,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
# scale image embed (Katherine)
@@ -1084,8 +1091,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
@@ -1233,8 +1241,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1351,6 +1358,7 @@ class Unet(nn.Module):
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
**kwargs
):
super().__init__()
@@ -1370,7 +1378,7 @@ class Unet(nn.Module):
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 3 * 2)
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
@@ -1461,10 +1469,11 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
downsample_klass(dim_out) if not is_last else nn.Identity()
downsample_klass(dim_out) if not is_last and not memory_efficient else None
]))
mid_dim = dims[-1]
@@ -1473,19 +1482,19 @@ class Unet(nn.Module):
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (num_resolutions - 2)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in)
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1557,6 +1566,7 @@ class Unet(nn.Module):
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
@@ -1654,7 +1664,10 @@ class Unet(nn.Module):
hiddens = []
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, c, t)
x = sparse_attn(x)
@@ -1662,7 +1675,9 @@ class Unet(nn.Module):
x = resnet_block(x, c, t)
hiddens.append(x)
x = downsample(x)
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, mid_c, t)
@@ -1672,7 +1687,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = torch.cat((x, hiddens.pop()), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
@@ -1681,6 +1696,7 @@ class Unet(nn.Module):
x = upsample(x)
x = torch.cat((x, r), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):
@@ -1729,7 +1745,7 @@ class LowresConditioner(nn.Module):
return cond_fmap
class Decoder(BaseGaussianDiffusion):
class Decoder(nn.Module):
def __init__(
self,
unet,
@@ -1742,7 +1758,7 @@ class Decoder(BaseGaussianDiffusion):
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
beta_schedule = 'cosine',
beta_schedule = None,
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
@@ -1764,13 +1780,7 @@ class Decoder(BaseGaussianDiffusion):
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
):
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
super().__init__()
self.unconditional = unconditional
@@ -1814,6 +1824,8 @@ class Decoder(BaseGaussianDiffusion):
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
num_unets = len(unets)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1849,6 +1861,24 @@ class Decoder(BaseGaussianDiffusion):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# create noise schedulers per unet
if not exists(beta_schedule):
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule in beta_schedule:
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
self.noise_schedulers.append(noise_scheduler)
# unet image sizes
image_sizes = default(image_sizes, (image_size,))
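For reference, the default constructed above yields one schedule per unet, matching the cosine, cosine, linear setup mentioned in the top commit. A standalone check of that expression, plus the shape an explicit override would take:

```python
def default_beta_schedules(num_unets):
    # same expression as in the diff above, evaluated outside the Decoder for clarity
    return ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))

assert default_beta_schedules(1) == ('cosine',)
assert default_beta_schedules(2) == ('cosine', 'linear')
assert default_beta_schedules(3) == ('cosine', 'cosine', 'linear')

# passing an explicit tuple, e.g. Decoder(..., beta_schedule = ('cosine', 'cosine', 'linear')),
# should behave the same way, building one NoiseScheduler per unet
```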
@@ -1898,6 +1928,14 @@ class Decoder(BaseGaussianDiffusion):
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# device tracker
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1921,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
for unet, device in zip(self.unets, devices):
unet.to(device)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
@@ -1932,7 +1970,7 @@ class Decoder(BaseGaussianDiffusion):
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
@@ -1951,14 +1989,14 @@ class Decoder(BaseGaussianDiffusion):
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
@@ -1970,17 +2008,17 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.betas.device
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
img = torch.randn(shape, device = device)
@@ -1988,7 +2026,7 @@ class Decoder(BaseGaussianDiffusion):
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
img = self.p_sample(
unet,
img,
@@ -1999,6 +2037,7 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
@@ -2006,7 +2045,7 @@ class Decoder(BaseGaussianDiffusion):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2017,7 +2056,7 @@ class Decoder(BaseGaussianDiffusion):
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet(
x_noisy,
@@ -2037,11 +2076,10 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target, reduction = 'none')
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
if self.has_p2_loss_reweighting:
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
loss = noise_scheduler.p2_reweigh_loss(loss, times)
loss = loss.mean()
@@ -2056,8 +2094,8 @@ class Decoder(BaseGaussianDiffusion):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
@@ -2089,7 +2127,8 @@ class Decoder(BaseGaussianDiffusion):
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
stop_at_unet_number = None,
distributed = False,
):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2106,9 +2145,9 @@ class Decoder(BaseGaussianDiffusion):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
with context:
lowres_cond_img = None
@@ -2134,7 +2173,8 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
)
img = vae.decode(img)
@@ -2160,6 +2200,7 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
vae = self.vaes[unet_index]
noise_scheduler = self.noise_schedulers[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
@@ -2169,7 +2210,7 @@ class Decoder(BaseGaussianDiffusion):
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
@@ -2200,7 +2241,7 @@ class Decoder(BaseGaussianDiffusion):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
# main class

View File

@@ -164,9 +164,6 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
# There may be webdataset shards that do not have an embedding shard associated with them. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:

View File

@@ -17,15 +17,15 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load state dict functions
# load file functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
return file_reference.name
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
def load_local_file(file_path, **kwargs):
return file_path
# base class
@@ -55,12 +55,43 @@ class BaseTracker(nn.Module):
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
return torch.load(load_wandb_file(*args, **kwargs))
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
return torch.load(load_local_file(*args, **kwargs))
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def save_file(self, file_path, **kwargs):
raise NotImplementedError
def recall_file(self, recall_source, *args, **kwargs):
if recall_source == 'wandb':
return load_wandb_file(*args, **kwargs)
elif recall_source == 'local':
return load_local_file(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
# Tracker that no-ops all calls except for recall
class DummyTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def init(self, config, **kwargs):
pass
def log(self, log, **kwargs):
pass
def log_images(self, images, **kwargs):
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
pass
def save_file(self, file_path, **kwargs):
pass
# basic stdout class
@@ -76,6 +107,10 @@ class ConsoleTracker(BaseTracker):
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
def save_file(self, file_path, **kwargs):
# This is a no-op for local file systems since it is already saved locally
pass
# basic wandb class
@@ -107,3 +142,11 @@ class WandbTracker(BaseTracker):
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
def save_file(self, file_path, base_path=None, **kwargs):
"""
Uploads a file from disk to wandb
"""
if base_path is None:
base_path = self.data_path
self.wandb.save(str(file_path), base_path = str(base_path))

View File

@@ -173,7 +173,7 @@ class DecoderConfig(BaseModel):
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
@@ -261,6 +261,7 @@ class TrainDecoderConfig(BaseModel):
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
seed: int = 0
@classmethod
def from_json_path(cls, json_path):

View File

@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
from accelerate import Accelerator
import numpy as np
# helper functions
@@ -22,7 +24,9 @@ def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
if exists(val):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -189,13 +193,13 @@ class EMA(nn.Module):
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
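The reference points in the docstring follow a warmup curve of the form 1 - (1 + step / inv_gamma)^(-power), clamped into [0, beta]; the sketch below assumes that formula (the repo's handling of the `update_after_step` offset is not reproduced here):

```python
def ema_decay(step, beta = 0.9999, inv_gamma = 1.0, power = 2 / 3):
    # warmup curve assumed from the notes quoted above, clamped into [0, beta]
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(beta, max(0.0, value))

print(round(ema_decay(31_600), 4))                  # ~0.999, the 31.6K-step figure for power = 2/3
print(round(ema_decay(10_000, power = 3 / 4), 4))   # ~0.999, the 10K-step figure for power = 3/4
```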
@@ -205,7 +209,7 @@ class EMA(nn.Module):
self,
model,
beta = 0.9999,
update_after_step = 10000,
update_after_step = 100,
update_every = 10,
inv_gamma = 1.0,
power = 2/3,
@@ -280,6 +284,7 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@@ -303,88 +308,189 @@ class DiffusionPriorTrainer(nn.Module):
max_grad_norm = None,
amp = False,
group_wd_params = True,
device = None,
accelerator = None,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# assign some helpful member vars
self.accelerator = accelerator
self.device = accelerator.device if exists(accelerator) else device
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# save model
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
group_wd_params = group_wd_params,
self.diffusion_prior.parameters(),
**self.optim_kwargs,
**kwargs
)
# distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
# track steps internally
self.register_buffer('step', torch.tensor([0]))
# accelerator wrappers
def print(self, msg):
if exists(self.accelerator):
self.accelerator.print(msg)
else:
print(msg)
def unwrap_model(self, model):
if exists(self.accelerator):
return self.accelerator.unwrap_model(model)
else:
return model
def wait_for_everyone(self):
if exists(self.accelerator):
self.accelerator.wait_for_everyone()
def is_main_process(self):
if exists(self.accelerator):
return self.accelerator.is_main_process
else:
return True
def clip_grad_norm_(self, *args):
if exists(self.accelerator):
return self.accelerator.clip_grad_norm_(*args)
else:
return torch.nn.utils.clip_grad_norm_(*args)
def backprop(self, x):
if exists(self.accelerator):
self.accelerator.backward(x)
else:
try:
x.backward()
except Exception as e:
self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# ensure we sync gradients before continuing
self.wait_for_everyone()
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = __version__,
step = self.step.item(),
**kwargs
)
# only save on the main process
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
version = version.parse(__version__),
step = self.step.item(),
**kwargs
)
torch.save(save_obj, str(path))
if self.use_ema:
save_obj = {
**save_obj,
'ema': self.ema_diffusion_prior.state_dict(),
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
}
def load(self, path, only_model = False, strict = True):
torch.save(save_obj, str(path))
def load(self, path, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
Returns:
loaded_obj (dict): The loaded checkpoint dictionary
"""
# all processes need to load checkpoint. no restriction here
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
loaded_obj = torch.load(str(path), map_location=self.device)
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
# unwrap the model when loading from checkpoint
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below may not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -399,17 +505,32 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
if self.use_ema:
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
else:
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
if self.use_ema:
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
else:
return self.diffusion_prior.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
if self.use_ema:
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
else:
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -427,8 +548,10 @@ class DiffusionPriorTrainer(nn.Module):
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.scaler.scale(loss).backward()
self.backprop(self.scaler.scale(loss))
return total_loss
@@ -454,6 +577,7 @@ class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
accelerator = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
@@ -467,8 +591,9 @@ class DecoderTrainer(nn.Module):
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.accelerator = default(accelerator, Accelerator)
self.num_unets = len(decoder.unets)
self.use_ema = use_ema
self.ema_unets = nn.ModuleList([])
@@ -480,7 +605,9 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
@@ -490,19 +617,20 @@ class DecoderTrainer(nn.Module):
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
optimizers.append(optimizer)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
results = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = results.pop(0)
for opt_ind in range(len(optimizers)):
setattr(self, f'optim{opt_ind}', results.pop(0))
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
@@ -510,47 +638,42 @@ class DecoderTrainer(nn.Module):
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
model = self.decoder.state_dict(),
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
torch.save(save_obj, str(path))
self.accelerator.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
loaded_obj = torch.load(str(path), map_location = 'cpu')
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scaler.load_state_dict(loaded_obj[scaler_key])
optimizer.load_state_dict(loaded_obj[optimizer_key])
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
@@ -562,29 +685,18 @@ class DecoderTrainer(nn.Module):
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
optimizer.step()
optimizer.zero_grad()
if self.use_ema:
@@ -597,15 +709,17 @@ class DecoderTrainer(nn.Module):
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return self.decoder.sample(*args, **kwargs)
return base_decoder.sample(*args, **kwargs, distributed = distributed)
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
output = base_decoder.sample(*args, **kwargs, distributed = distributed)
self.decoder.unets = trainable_unets # restore original training unets
base_decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
@@ -627,13 +741,14 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
# with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.scale(loss, unet_number = unet_number).backward()
self.accelerator.backward(loss)
return total_loss

View File

@@ -1,4 +1,5 @@
import time
import importlib
# time helpers

View File

@@ -1 +1 @@
__version__ = '0.7.1'
__version__ = '0.11.0'

View File

@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def string_begins_with(prefix, string_input):
return string_input.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
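This looks like the rename from the "changed str" commit above, avoiding the shadowed built-in `str`. For context, this predicate feeds the kwarg-prefix grouping the trainers use to split off `ema_*` kwargs; a hypothetical, self-contained re-implementation of that grouping:

```python
from functools import partial

def string_begins_with(prefix, string_input):
    return string_input.startswith(prefix)

# hypothetical stand-in for the repo's group_dict_by_key helper, shown only to
# illustrate how the predicate above is used to split kwargs by prefix
def group_dict_by_key(cond, d):
    matched, unmatched = {}, {}
    for key, value in d.items():
        (matched if cond(key) else unmatched)[key] = value
    return matched, unmatched

kwargs = {'ema_beta': 0.9999, 'ema_update_every': 10, 'lr': 1e-4}
ema_kwargs, rest = group_dict_by_key(partial(string_begins_with, 'ema_'), kwargs)
print(ema_kwargs)  # {'ema_beta': 0.9999, 'ema_update_every': 10}
print(rest)        # {'lr': 1e-4}
```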

View File

@@ -24,6 +24,7 @@ setup(
'text to image'
],
install_requires=[
'accelerate',
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',

View File

@@ -1,7 +1,8 @@
from dalle2_pytorch import Unet, Decoder
from pathlib import Path
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
@@ -12,6 +13,8 @@ from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
@@ -42,6 +45,7 @@ def create_dataloaders(
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
seed = 0,
**kwargs
):
"""
@@ -52,7 +56,7 @@ def create_dataloaders(
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
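The `seed` threaded into the split keeps the shard partition reproducible, and in distributed runs every process that uses the same seed lands on the identical train/test/val split. A standalone sketch of the same call with toy shard indices:

```python
import torch

available_shards = list(range(100))                   # toy stand-in for the webdataset shard numbers
num_train, num_test = 75, 10
num_val = len(available_shards) - num_train - num_test

# identical seed on every rank -> identical split on every rank
train_split, test_split, val_split = torch.utils.data.random_split(
    available_shards,
    [num_train, num_test, num_val],
    generator = torch.Generator().manual_seed(0),
)
print(len(train_split), len(test_split), len(val_split))  # 75 10 15
```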
@@ -117,7 +121,6 @@ def get_example_data(dataloader, device, n=5):
captions.extend(list(txt))
if len(images) >= n:
break
print("Generated {} examples".format(len(images)))
return list(zip(images[:n], embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
@@ -155,27 +158,34 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evaluation_samples)
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
def null_sync(t, *args, **kwargs):
return [t]
if exists(FID):
fid = FrechetInceptionDistance(**FID)
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if exists(IS):
inception = InceptionScore(**IS)
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
@@ -186,39 +196,47 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
if trainer.accelerator.num_processes > 1:
# Then we should sync the metrics
metrics_order = sorted(metrics.keys())
metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
for i, metric_name in enumerate(metrics_order):
metrics_tensor[0, i] = metrics[metric_name]
metrics_tensor = trainer.accelerator.gather(metrics_tensor)
metrics_tensor = metrics_tensor.mean(dim=0)
for i, metric_name in enumerate(metrics_order):
metrics[metric_name] = metrics_tensor[i].item()
return metrics
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
trainer_state_dict = {}
trainer_state_dict["trainer"] = trainer.state_dict()
trainer_state_dict['epoch'] = epoch
trainer_state_dict['step'] = step
trainer_state_dict['validation_losses'] = validation_losses
for relative_path in relative_paths:
tracker.save_state_dict(trainer_state_dict, relative_path)
local_path = str(tracker.data_path / relative_path)
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
tracker.save_file(local_path)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
"""
Loads the model with an appropriate method depending on the tracker
"""
print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
trainer.load_state_dict(state_dict["trainer"])
print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
local_filepath = tracker.recall_file(recall_source, **load_config)
state_dict = trainer.load(local_filepath)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
def train(
dataloaders,
decoder,
accelerator,
tracker,
inference_device,
load_config=None,
@@ -237,17 +255,30 @@ def train(
"""
Trains a decoder on a dataset.
"""
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
is_master = accelerator.process_index == 0
trainer = DecoderTrainer(
accelerator,
decoder,
**kwargs
)
# Set up starting model and parameters based on a recalled state dict
start_step = 0
start_epoch = 0
validation_losses = []
next_task = 'train'
sample = 0
val_sample = 0
step = lambda: int(trainer.step.item())
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
if not exists(unet_training_mask):
@@ -255,139 +286,185 @@ def train(
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
print(print_ribbon("Generating Example Data", repeat=40))
print("This can take a while to load the shard lists...")
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
accelerator.print("Generated training examples")
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
accelerator.print("Generated testing examples")
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step
sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
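# Both tensors live on the inference device so accelerator.gather() can collect
# them from every process; unet_losses_tensor keeps the last
# TRAIN_CALC_LOSS_EVERY_ITERS losses per unet so they can be averaged below.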
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
timer = Timer()
last_sample = sample
last_snapshot = sample
sample = 0
last_sample = 0
last_snapshot = 0
if next_task == 'train':
for i, (img, emb) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
img, emb = send_to_device((img, emb))
losses = []
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
sample += img.shape[0]
img, emb = send_to_device((img, emb))
trainer.train()
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
# We want to average losses across all processes
unet_all_losses = accelerator.gather(unet_losses_tensor)
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
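# Slots that were never written (masked-out unets or not-yet-filled iterations)
# stay zero, so the masked mean above only averages losses that were actually computed.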
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
log_data = {
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec,
**loss_map
}
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
if is_master:
tracker.log(log_data, step=step(), verbose=True)
samples_per_sec = (sample - last_sample) / timer.elapsed()
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
break
next_task = 'val'
sample = 0
timer.reset()
last_sample = sample
all_average_val_losses = None
if next_task == 'val':
trainer.eval()
accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
last_val_sample = val_sample
val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
val_sample += total_samples
img, emb = send_to_device((img, emb))
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec
}
tracker.log(log_data, step=step, verbose=True)
losses = []
for unet in range(1, len(decoder.unets)+1):
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
# No need to evaluate an unchanging unet
continue
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
last_snapshot = sample
# We need to know where the model should be saved
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
timer.reset()
last_val_sample = val_sample
accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
accelerator.print("")
if validation_samples is not None and val_sample >= validation_samples:
break
print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
accelerator.wait_for_everyone()
average_val_loss_tensor /= i+1
# Gather all the average loss tensors
all_average_val_losses = accelerator.gather(average_val_loss_tensor)
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step(), verbose=True)
next_task = 'eval'
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
next_task = 'sample'
val_sample = 0
if next_task == 'sample':
if is_master:
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
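# "best.pth" is only added when the new average validation loss beats every value
# already in validation_losses -- the same list recall_trainer restores, so the
# comparison survives resuming from a checkpoint.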
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
next_task = 'train'
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if exists(epoch_samples) and sample >= epoch_samples:
break
trainer.eval()
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
with torch.no_grad():
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))
for unet in range(1, len(decoder.unets)+1):
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if exists(validation_samples) and sample >= validation_samples:
break
average_loss /= i+1
log_data = {
"Validation loss": average_loss
}
tracker.log(log_data, step=step, verbose=True)
# Compute evaluation metrics
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
tracker.log(evaluation, step=step, verbose=True)
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config.tracker
init_config = {}
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
init_config = { "config": {**config.dict(), **accelerator_config} }
data_path = data_path or tracker_config.data_path
tracker_type = tracker_type or tracker_config.tracker_type
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
if tracker_type == "dummy":
tracker = DummyTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "console":
tracker = ConsoleTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
@@ -401,51 +478,63 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
return tracker
def initialize_training(config):
# Create the save path
if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config.train.device)
torch.cuda.set_device(device)
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
def initialize_training(config, config_path):
# Make sure if we are not loading, distributed models are initialized to the same values
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
world_size = accelerator.num_processes
rank = accelerator.process_index
shards_per_process = len(all_shards) // world_size
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
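# e.g. with 8 shards and 4 processes, shards_per_process is 2 and rank 2 reads
# all_shards[4:6]; when the split is uneven, the leftover shards are simply dropped.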
dataloaders = create_dataloaders (
available_shards=all_shards,
available_shards=my_shards,
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.dict()
**config.data.dict(),
rank = rank,
seed = config.seed,
)
decoder = config.decoder.create().to(device = device)
# Create the decoder model and print basic info
decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config.tracker.dict())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
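# Only rank 0 gets a real tracker; every other rank receives a DummyTracker so
# metrics, sample grids and checkpoints are only ever written once.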
train(dataloaders, decoder,
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=device,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
print("Recalling config from {}".format(config_file))
config = TrainDecoderConfig.from_json_path(config_file)
initialize_training(config)
config_file_path = Path(config_file)
config = TrainDecoderConfig.from_json_path(str(config_file_path))
initialize_training(config, config_path=config_file_path)
if __name__ == "__main__":
main()

View File

@@ -1,77 +1,135 @@
from pathlib import Path
# TODO: add start, num_data_points, eval_every and group to config
# TODO: switch back to repo's wandb
START = 0
NUM_DATA_POINTS = 250e6
EVAL_EVERY = 1000
GROUP = "distributed"
import os
import click
import math
import numpy as np
import wandb
import torch
import clip
from torch import nn
from torch.utils.data import DataLoader
from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
import numpy as np
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon
from accelerate import Accelerator
from tqdm import tqdm
from dalle2_pytorch.dataloaders import get_reader, make_splits
from dalle2_pytorch.utils import Timer
from dalle2_pytorch.train_configs import (
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
from dalle2_pytorch import DiffusionPriorTrainer
# constants
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
# helpers
tracker = WandbTracker()
# helper functions
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
def exists(val):
val is not None
return val is not None
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
model.eval()
def make_model(
prior_config, train_config, device: str = None, accelerator: Accelerator = None
):
# create model from config
diffusion_prior = prior_config.create()
# instantiate the trainer
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=train_config.lr,
wd=train_config.wd,
max_grad_norm=train_config.max_grad_norm,
amp=train_config.amp,
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
)
return trainer
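# The returned trainer wraps the prior together with its optimizer, EMA copy and
# (optional) mixed precision, and is what initialize_training() below moves onto
# the accelerator's device.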
# eval functions
def eval_model(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
loss_type: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
with torch.no_grad():
total_loss = 0.
total_samples = 0.
total_loss = 0.0
total_samples = 0.0
for image_embeddings, text_data in tqdm(dataloader):
image_embeddings = image_embeddings.to(device)
text_data = text_data.to(device)
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
batches = image_embeddings.shape[0]
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text = text_data)
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
loss = model(**input_args)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
total_loss += loss * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
avg_loss = total_loss / total_samples
tracker.log({f'{phase} {loss_type}': avg_loss})
stats = {f"{tracker_context}-{loss_type}": avg_loss}
trainer.print(stats)
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
diffusion_prior.eval()
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader):
test_image_embeddings = test_image_embeddings.to(device)
text_data = text_data.to(device)
def report_cosine_sims(
trainer: DiffusionPriorTrainer,
dataloader: DataLoader,
text_conditioned: bool,
tracker: BaseTracker,
tracker_context: str = "validation",
):
trainer.eval()
if trainer.is_main_process():
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
for test_image_embeddings, text_data in dataloader:
test_image_embeddings = test_image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
text_data)
text_cond = dict(text_embed=text_embedding,
text_encodings=text_encodings, mask=text_mask)
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
@@ -82,8 +140,9 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
dim=1, keepdim=True
)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
@@ -92,294 +151,276 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
# prepare the text embedding
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
dim=1, keepdim=True
)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond
)
predicted_image_embeddings = (
predicted_image_embeddings
/ predicted_image_embeddings.norm(dim=1, keepdim=True)
)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond_shuffled
)
predicted_unrelated_embeddings = (
predicted_unrelated_embeddings
/ predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = (
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
)
predicted_img_similarity = (
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
)
stats = {
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_context}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
}
for k, v in stats.items():
trainer.print(f"{tracker_context}/{k}: {v}")
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
# training script
def train(
trainer: DiffusionPriorTrainer,
train_loader: DataLoader,
eval_loader: DataLoader,
test_loader: DataLoader,
config: DiffusionPriorTrainConfig,
):
# distributed tracking with wandb
if trainer.accelerator.num_processes > 1:
os.environ["WANDB_START_METHOD"] = "thread"
tracker = wandb.init(
name=f"RANK:{trainer.device}",
entity=config.tracker.wandb_entity,
project=config.tracker.wandb_project,
config=config.dict(),
group=GROUP,
)
# sync after tracker init
trainer.wait_for_everyone()
# init a timer
timer = Timer()
# do training
for img, txt in train_loader:
trainer.train()
current_step = trainer.step.item() + 1
# place data on device
img = img.to(trainer.device)
txt = txt.to(trainer.device)
# pass to model
loss = trainer(text=txt, image_embed=img)
# display & log loss (will only print from main process)
trainer.print(f"Step {current_step}: Loss {loss}")
# perform backprop & apply EMA updates
trainer.update()
# track samples/sec/rank
samples_per_sec = img.shape[0] / timer.elapsed()
# samples seen
samples_seen = (
config.data.batch_size * trainer.accelerator.num_processes * current_step
)
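# Illustrative numbers: with a per-process batch size of 320 on 8 processes, step
# 1,000 corresponds to 320 * 8 * 1000 = 2,560,000 samples seen across all ranks.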
# ema decay
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
# Log on all processes for debugging
tracker.log(
{
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
tracker=tracker,
use_ema=False,
)
eval_model(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
)
report_cosine_sims(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
tracker_context="metrics",
)
if current_step % config.train.save_every == 0:
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
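# Illustrative values: with data_path "./prior_checkpoints" and save_every=10000,
# the first checkpoint would land at ./prior_checkpoints/chkpt_step_10000.pth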
# reset timer for next round
timer.reset()
# evaluate on test data
eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="test",
tracker=tracker,
)
report_cosine_sims(
trainer,
test_loader,
config.prior.condition_on_text_encodings,
tracker,
tracker_context="test",
)
def initialize_training(config, accelerator=None):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# get a device
if accelerator:
device = accelerator.device
click.secho(f"Accelerating on: {device}", fg="yellow")
else:
if torch.cuda.is_available():
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
device = "cuda:0"
else:
click.secho("No GPU detected...using cpu", fg="yellow")
device = "cpu"
# make the trainer (will automatically distribute if possible & configured)
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
# reload from checkpoint
if config.load.resume == True:
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
trainer.load(config.load.source)
# fetch and prepare data
if trainer.is_main_process():
click.secho("Grabbing data from source", fg="blue", blink=True)
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=NUM_DATA_POINTS,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index if exists(accelerator) else 0,
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
start=START,
)
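# rank and world_size are forwarded so each process can be handed its own slice of
# the reader; START and NUM_DATA_POINTS are the module-level constants above that
# the TODO plans to move into the config.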
# wait for everyone to load data before continuing
trainer.wait_for_everyone()
# start training
train(
trainer=trainer,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
config=config,
)
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--num-data-points", default=250e6)
@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.9)
@click.option("--val-percent", default=1e-7)
@click.option("--test-percent", default=0.0999999)
@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=12)
@click.option("--dp-condition-on-text-encodings", default=True)
@click.option("--dp-timesteps", default=1000)
@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path,
gpu_device
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH (saved model path) exists
DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer
)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
@click.option("--hfa", default=True)
@click.option("--config_path", default="configs/prior.json")
def main(hfa, config_path):
# start HFA if requested
if hfa:
accelerator = Accelerator()
else:
clip_adapter = None
accelerator = None
# diffusion prior with text embeddings and image embeddings pre-computed
# load the configuration file on main process
if not exists(accelerator) or accelerator.is_main_process:
click.secho(f"Loading configuration from {config_path}", fg="green")
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings
)
config = TrainDiffusionPriorConfig.from_json_path(config_path)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer
trainer = DiffusionPriorTrainer(
diffusion_prior = diffusion_prior,
lr = learning_rate,
wd = weight_decay,
max_grad_norm = max_grad_norm,
amp = amp,
).to(device)
# load optimizer and scaler
if RESUME:
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
trainer.scaler.load_state_dict(loaded_obj['scaler'])
# Create save_path if it doesn't exist
Path(save_path).mkdir(exist_ok = True, parents = True)
# Utilize wrapper to abstract away loader logic
print(print_ribbon("Downloading Embeddings"))
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
if dp_condition_on_text_encodings:
reader_args = dict(**reader_args, meta_url=meta_url)
img_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader
)
else:
reader_args = dict(**reader_args, txt_url=text_embed_url)
img_reader, txt_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader,
text_reader=txt_reader
)
### Training code ###
step = 1
timer = Timer()
epochs = num_epochs
for _ in range(epochs):
for image, text in tqdm(train_loader):
diffusion_prior.train()
image = image.to(device)
text = text.to(device)
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
else:
input_args = dict(**input_args, text_embed=text)
loss = trainer(**input_args)
# Samples per second
samples_per_sec = batch_size * step / timer.elapsed()
# Save checkpoint every save_interval minutes
if int(timer.elapsed()) >= 60 * save_interval:
timer.reset()
save_diffusion_model(
save_path,
diffusion_prior,
trainer.optimizer,
trainer.scaler,
config,
image_embed_dim)
# Log to wandb
tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if step % REPORT_METRICS_EVERY == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
step += 1
trainer.update()
### Test run ###
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
# send config to get processed
initialize_training(config, accelerator)
if __name__ == "__main__":
train()
main()