Compare commits

...

9 Commits
1.6.0 ... 1.8.2

4 changed files with 145 additions and 48 deletions

View File

@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```

+```bibtex
+@article{Qiao2019WeightS,
+    title = {Weight Standardization},
+    author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
+    journal = {ArXiv},
+    year = {2019},
+    volume = {abs/1903.10520}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -38,6 +38,8 @@ from coca_pytorch import CoCa
 NAT = 1. / math.log(2.)

+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
 # helper functions

 def exists(val):
@@ -937,9 +939,12 @@ class DiffusionPriorNetwork(nn.Module):
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
@@ -967,6 +972,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -988,12 +997,19 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds

+        # setup self conditioning
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+            self_cond = rearrange(self_cond, 'b d -> b 1 d')
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -1043,13 +1059,16 @@ class DiffusionPriorNetwork(nn.Module):
             # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
             # but let's just do it right
-            attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+            attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
             mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

+        if self.self_cond:
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
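
To make the new token layout concrete, here is a standalone sketch (plain `torch` + `einops`; the sizes are illustrative, not the library's defaults) of how the forward pass above assembles its sequence once `self_cond` is enabled — the previous prediction occupies one extra slot ahead of the learned query, which is why `attend_padding` grows by `int(self.self_cond)`:

```python
import torch
from einops import rearrange, repeat

batch, dim, text_len = 2, 512, 77   # illustrative sizes

text_encodings = torch.randn(batch, text_len, dim)
text_embed     = torch.randn(batch, 1, dim)   # num_text_embeds = 1
time_embed     = torch.randn(batch, 1, dim)   # num_time_embeds = 1
image_embed    = torch.randn(batch, 1, dim)   # noised CLIP image embedding

learned_queries = repeat(torch.randn(dim), 'd -> b 1 d', b = batch)

# on the first call there is no previous prediction, so zeros stand in
self_cond = rearrange(torch.zeros(batch, dim), 'b d -> b 1 d')
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)

tokens = torch.cat((text_encodings, text_embed, time_embed, image_embed, learned_queries), dim = -2)
print(tokens.shape)  # torch.Size([2, 82, 512]) - one token longer than without self_cond
```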
@@ -1151,10 +1170,10 @@ class DiffusionPrior(nn.Module):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale

-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)

         if self.predict_x_start:
             x_start = pred
@@ -1168,28 +1187,33 @@ class DiffusionPrior(nn.Module):
             x_start = l2norm(x_start) * self.image_embed_scale

         model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start

     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start

     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device

         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning

         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale

         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)

         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
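
Stripped of the class machinery, the self-conditioned ancestral sampling loop above reduces to the following hedged sketch, where `denoise` stands in for `DiffusionPrior.p_sample` and is assumed to return both the denoised input and the model's `x_start` prediction:

```python
import torch

def sample(denoise, shape, num_timesteps, use_self_cond = True, device = 'cpu'):
    embed = torch.randn(shape, device = device)
    x_start = None  # no previous x0 prediction on the first step

    for t in reversed(range(num_timesteps)):
        times = torch.full((shape[0],), t, device = device, dtype = torch.long)
        # feed the previous step's x0 estimate back in as conditioning
        self_cond = x_start if use_self_cond else None
        embed, x_start = denoise(embed, times, self_cond = self_cond)

    return embed

# toy stand-in for p_sample, just to show the plumbing
fake_denoise = lambda x, t, self_cond = None: (x * 0.99, x)
out = sample(fake_denoise, (2, 512), num_timesteps = 10)
```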
@@ -1207,6 +1231,8 @@ class DiffusionPrior(nn.Module):
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning

         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
@@ -1216,7 +1242,9 @@ class DiffusionPrior(nn.Module):
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+
+            self_cond = x_start if self.net.self_cond else None
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)

             if self.predict_x_start:
                 x_start = pred
@@ -1251,18 +1279,27 @@ class DiffusionPrior(nn.Module):
         is_ddim = timesteps < self.noise_scheduler.num_timesteps

         if not is_ddim:
-            return self.p_sample_loop_ddpm(*args, **kwargs)
-        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+            normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
+        else:
+            normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
+        image_embed = normalized_image_embed / self.image_embed_scale
+        return image_embed

     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))

         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)

+        self_cond = None
+
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
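
On the training side, `p_losses` applies the trick from the self-conditioning paper: half the time, run a gradient-free extra forward pass and feed its detached prediction back in. A minimal sketch, with `net(x, t, self_cond)` as a stand-in for the prior network:

```python
import random
import torch
import torch.nn.functional as F

def self_cond_loss(net, x_noisy, times, target):
    self_cond = None

    if random.random() < 0.5:
        with torch.no_grad():
            # extra pass without self-conditioning; detached so it
            # provides conditioning but contributes no gradients
            self_cond = net(x_noisy, times, None).detach()

    pred = net(x_noisy, times, self_cond)
    return F.mse_loss(pred, target)
```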
@@ -1316,8 +1353,6 @@ class DiffusionPrior(nn.Module):
         # retrieve original unscaled image embed
-        image_embeds /= self.image_embed_scale

         text_embeds = text_cond['text_embed']
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -1416,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Conv2d(dim, dim_out, 4, 2, 1)

+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+        var = torch.var(flattened_weights, dim = -1, unbiased = False)
+        var = rearrange(var, 'o -> o 1 1 1')
+
+        weight = (weight - mean) * (var + eps).rsqrt()
+
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
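
As a sanity check on the math in `WeightStandardizedConv2d`, this standalone snippet (plain `torch` + `einops`, independent of the library) applies the same reparameterization to an ordinary conv weight and verifies that each output filter ends up with approximately zero mean and unit variance:

```python
import torch
from einops import rearrange

conv = torch.nn.Conv2d(16, 32, 3, padding = 1)
flat = rearrange(conv.weight, 'o ... -> o (...)')

mean = flat.mean(dim = -1, keepdim = True)
var  = flat.var(dim = -1, unbiased = False, keepdim = True)
standardized = (flat - mean) * (var + 1e-5).rsqrt()

print(standardized.mean(dim = -1).abs().max())              # ~0 per filter
print(standardized.var(dim = -1, unbiased = False).mean())  # ~1 per filter
```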
@@ -1434,10 +1489,13 @@ class Block(nn.Module):
         self,
         dim,
         dim_out,
-        groups = 8
+        groups = 8,
+        weight_standardization = False
     ):
         super().__init__()
-        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+        self.project = conv_klass(dim, dim_out, 3, padding = 1)
         self.norm = nn.GroupNorm(groups, dim_out)
         self.act = nn.SiLU()
@@ -1461,6 +1519,7 @@ class ResnetBlock(nn.Module):
         cond_dim = None,
         time_cond_dim = None,
         groups = 8,
+        weight_standardization = False,
         cosine_sim_cross_attn = False
     ):
         super().__init__()
@@ -1486,8 +1545,8 @@ class ResnetBlock(nn.Module):
             )
         )

-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

     def forward(self, x, time_emb = None, cond = None):
@@ -1700,7 +1759,7 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
-        self_cond = False,
+        self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
         sparse_attn = False,
         cosine_sim_cross_attn = False,
         cosine_sim_self_attn = False,
@@ -1712,6 +1771,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        resnet_weight_standardization = False,
         num_resnet_blocks = 2,
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1859,7 +1919,7 @@ class Unet(nn.Module):
         # prepare resnet klass

-        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)

         # give memory efficient unet an initial resnet block
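
Usage-wise, both new features are plain constructor flags on `Unet`. A hedged sketch — the other hyperparameters are illustrative values in the style of the README, not requirements:

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    self_cond = True,                      # https://arxiv.org/abs/2208.04202
    resnet_weight_standardization = True   # https://arxiv.org/abs/1903.10520
)
```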
@@ -2552,6 +2612,14 @@ class Decoder(nn.Module):
         index = unet_number - 1
         return self.unets[index]

+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
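
The refactor funnels three copies of the variance-chunking logic through this one helper. Its behavior in isolation, using the `UnetOutput` namedtuple added at the top of the file:

```python
import torch
from collections import namedtuple

UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

def parse_unet_output(learned_variance, output):
    # a variance-learning unet predicts 2x the channels; the second half is
    # the unnormalized interpolation fraction for the posterior variance
    var_interp_frac_unnormalized = None

    if learned_variance:
        output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)

    return UnetOutput(output, var_interp_frac_unnormalized)

out = parse_unet_output(True, torch.randn(2, 6, 8, 8))
print(out.pred.shape, out.var_interp_frac_unnormalized.shape)  # both (2, 3, 8, 8)
```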
@@ -2593,10 +2661,9 @@ class Decoder(nn.Module):
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

         if predict_x_start:
             x_start = pred
@@ -2779,10 +2846,9 @@ class Decoder(nn.Module):
             self_cond = x_start if unet.self_cond else None

-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)

             if predict_x_start:
                 x_start = pred
@@ -2854,16 +2920,13 @@ class Decoder(nn.Module):
         if unet.self_cond and random.random() < 0.5:
             with torch.no_grad():
-                self_cond = unet(x_noisy, times, **unet_kwargs)
-
-                if learned_variance:
-                    self_cond, _ = self_cond.chunk(2, dim = 1)
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                 self_cond = self_cond.detach()

         # forward to get model prediction

-        model_output = unet(
+        unet_output = unet(
             x_noisy,
             times,
             **unet_kwargs,
@@ -2872,10 +2935,7 @@ class Decoder(nn.Module):
             text_cond_drop_prob = self.text_cond_drop_prob,
         )

-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)

         target = noise if not predict_x_start else x_start
@@ -2898,7 +2958,7 @@ class Decoder(nn.Module):
             # if learning the variance, also include the extra weight kl loss

             true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)

             # kl loss with detached model predicted mean, for stability reasons as in paper

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
             **self.optim_kwargs,
             **kwargs
         )

-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))

         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

         # set warmupstep
         if exists(self.warmup_scheduler):
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
         # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
         if not self.accelerator.optimizer_step_was_skipped:
-            with self.warmup_scheduler.dampening():
+            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+            with sched_context():
                 self.scheduler.step()

         if self.use_ema:
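
Outside the trainer class, the warmup-plus-decay logic above reduces to the following hedged sketch. It assumes the `pytorch_warmup` package the trainer already imports as `warmup`; `dampening()` scales whatever learning rate the inner scheduler sets for as long as the warmup period lasts:

```python
import torch
from contextlib import nullcontext
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
import pytorch_warmup as warmup

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)

cosine_decay_max_steps = 1000   # illustrative
warmup_steps = 100              # illustrative; None keeps the old no-warmup behavior

if cosine_decay_max_steps is not None:
    scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
    scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)

warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if warmup_steps is not None else None

for step in range(cosine_decay_max_steps):
    optimizer.step()
    # step the decay scheduler inside the warmup context, as the trainer does
    sched_context = warmup_scheduler.dampening if warmup_scheduler is not None else nullcontext
    with sched_context():
        scheduler.step()
```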
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))

         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
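
The per-unet broadcast relies on `cast_tuple`: a scalar fans out to every unet, while a tuple is taken as already per-unet. A simplified stand-in for the library's helper shows the effect:

```python
from functools import partial

def cast_tuple(val, length = 1):
    # simplified stand-in for the library's cast_tuple helper
    return val if isinstance(val, tuple) else ((val,) * length)

num_unets = 2
lr, cosine_decay_max_steps = map(partial(cast_tuple, length = num_unets), (3e-4, (10000, 50000)))

print(lr)                      # (0.0003, 0.0003) - scalar broadcast to both unets
print(cosine_decay_max_steps)  # (10000, 50000)   - tuple kept as-is, one value per unet
```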
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
         schedulers = []
         warmup_schedulers = []

-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
                 )
                 optimizers.append(optimizer)

-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'

             optimizer = getattr(self, optimizer_key)
+            scheduler = getattr(self, scheduler_key)

-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}

         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)

             warmup_scheduler = self.warmup_schedulers[ind]

-            if optimizer is not None:
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])
+
+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])

             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step

View File

@@ -1 +1 @@
-__version__ = '1.6.0'
+__version__ = '1.8.2'