mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27f19ba7fa | ||
|
|
8f38339c2b | ||
|
|
6b9b4b9e5e | ||
|
|
44e09d5a4d | ||
|
|
34806663e3 | ||
|
|
dc816b1b6e | ||
|
|
05192ffac4 | ||
|
|
9440411954 | ||
|
|
981d407792 |
10
README.md
10
README.md
@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Qiao2019WeightS,
|
||||
title = {Weight Standardization},
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
||||
journal = {ArXiv},
|
||||
year = {2019},
|
||||
volume = {abs/1903.10520}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -38,6 +38,8 @@ from coca_pytorch import CoCa
|
||||
|
||||
NAT = 1. / math.log(2.)
|
||||
|
||||
UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -937,9 +939,12 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
num_image_embeds = 1,
|
||||
num_text_embeds = 1,
|
||||
max_text_len = 256,
|
||||
self_cond = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.num_time_embeds = num_time_embeds
|
||||
self.num_image_embeds = num_image_embeds
|
||||
self.num_text_embeds = num_text_embeds
|
||||
@@ -967,6 +972,10 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||
|
||||
# whether to use self conditioning, Hinton's group's new ddpm technique
|
||||
|
||||
self.self_cond = self_cond
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
*args,
|
||||
@@ -988,12 +997,19 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
*,
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
self_cond = None,
|
||||
cond_drop_prob = 0.
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
|
||||
|
||||
# setup self conditioning
|
||||
|
||||
if self.self_cond:
|
||||
self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
|
||||
self_cond = rearrange(self_cond, 'b d -> b 1 d')
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
@@ -1043,13 +1059,16 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
if self.self_cond:
|
||||
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
|
||||
|
||||
tokens = torch.cat((
|
||||
text_encodings,
|
||||
text_embed,
|
||||
@@ -1151,10 +1170,10 @@ class DiffusionPrior(nn.Module):
|
||||
def l2norm_clamp_embed(self, image_embed):
|
||||
return l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
@@ -1168,28 +1187,33 @@ class DiffusionPrior(nn.Module):
|
||||
x_start = l2norm(x_start) * self.image_embed_scale
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
return model_mean, posterior_variance, posterior_log_variance, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return pred, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
|
||||
batch, device = shape[0], self.device
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
|
||||
times = torch.full((batch,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
self_cond = x_start if self.net.self_cond else None
|
||||
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
|
||||
|
||||
if self.sampling_final_clamp_l2norm and self.predict_x_start:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
@@ -1207,6 +1231,8 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
@@ -1216,7 +1242,9 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
|
||||
self_cond = x_start if self.net.self_cond else None
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
@@ -1251,18 +1279,27 @@ class DiffusionPrior(nn.Module):
|
||||
is_ddim = timesteps < self.noise_scheduler.num_timesteps
|
||||
|
||||
if not is_ddim:
|
||||
return self.p_sample_loop_ddpm(*args, **kwargs)
|
||||
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
|
||||
else:
|
||||
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
|
||||
|
||||
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
|
||||
image_embed = normalized_image_embed / self.image_embed_scale
|
||||
return image_embed
|
||||
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
|
||||
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
|
||||
self_cond = None
|
||||
if self.net.self_cond and random.random() < 0.5:
|
||||
with torch.no_grad():
|
||||
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
||||
|
||||
pred = self.net(
|
||||
image_embed_noisy,
|
||||
times,
|
||||
self_cond = self_cond,
|
||||
cond_drop_prob = self.cond_drop_prob,
|
||||
**text_cond
|
||||
)
|
||||
@@ -1316,8 +1353,6 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
# retrieve original unscaled image embed
|
||||
|
||||
image_embeds /= self.image_embed_scale
|
||||
|
||||
text_embeds = text_cond['text_embed']
|
||||
|
||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
@@ -1416,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
class WeightStandardizedConv2d(nn.Conv2d):
|
||||
"""
|
||||
https://arxiv.org/abs/1903.10520
|
||||
weight standardization purportedly works synergistically with group normalization
|
||||
"""
|
||||
def forward(self, x):
|
||||
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
|
||||
|
||||
weight = self.weight
|
||||
flattened_weights = rearrange(weight, 'o ... -> o (...)')
|
||||
|
||||
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
|
||||
|
||||
var = torch.var(flattened_weights, dim = -1, unbiased = False)
|
||||
var = rearrange(var, 'o -> o 1 1 1')
|
||||
|
||||
weight = (weight - mean) * (var + eps).rsqrt()
|
||||
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -1434,10 +1489,13 @@ class Block(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
groups = 8
|
||||
groups = 8,
|
||||
weight_standardization = False
|
||||
):
|
||||
super().__init__()
|
||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
|
||||
|
||||
self.project = conv_klass(dim, dim_out, 3, padding = 1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
@@ -1461,6 +1519,7 @@ class ResnetBlock(nn.Module):
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
groups = 8,
|
||||
weight_standardization = False,
|
||||
cosine_sim_cross_attn = False
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1486,8 +1545,8 @@ class ResnetBlock(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups = groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, time_emb = None, cond = None):
|
||||
@@ -1700,7 +1759,7 @@ class Unet(nn.Module):
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
||||
self_cond = False,
|
||||
self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
|
||||
sparse_attn = False,
|
||||
cosine_sim_cross_attn = False,
|
||||
cosine_sim_self_attn = False,
|
||||
@@ -1712,6 +1771,7 @@ class Unet(nn.Module):
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
resnet_weight_standardization = False,
|
||||
num_resnet_blocks = 2,
|
||||
init_cross_embed = True,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
@@ -1859,7 +1919,7 @@ class Unet(nn.Module):
|
||||
|
||||
# prepare resnet klass
|
||||
|
||||
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
|
||||
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
|
||||
|
||||
# give memory efficient unet an initial resnet block
|
||||
|
||||
@@ -2552,6 +2612,14 @@ class Decoder(nn.Module):
|
||||
index = unet_number - 1
|
||||
return self.unets[index]
|
||||
|
||||
def parse_unet_output(self, learned_variance, output):
|
||||
var_interp_frac_unnormalized = None
|
||||
|
||||
if learned_variance:
|
||||
output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
|
||||
|
||||
return UnetOutput(output, var_interp_frac_unnormalized)
|
||||
|
||||
@contextmanager
|
||||
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
||||
assert exists(unet_number) ^ exists(unet)
|
||||
@@ -2593,10 +2661,9 @@ class Decoder(nn.Module):
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
||||
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
|
||||
|
||||
if predict_x_start:
|
||||
x_start = pred
|
||||
@@ -2779,10 +2846,9 @@ class Decoder(nn.Module):
|
||||
|
||||
self_cond = x_start if unet.self_cond else None
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
pred, _ = self.parse_unet_output(learned_variance, unet_output)
|
||||
|
||||
if predict_x_start:
|
||||
x_start = pred
|
||||
@@ -2854,16 +2920,13 @@ class Decoder(nn.Module):
|
||||
|
||||
if unet.self_cond and random.random() < 0.5:
|
||||
with torch.no_grad():
|
||||
self_cond = unet(x_noisy, times, **unet_kwargs)
|
||||
|
||||
if learned_variance:
|
||||
self_cond, _ = self_cond.chunk(2, dim = 1)
|
||||
|
||||
unet_output = unet(x_noisy, times, **unet_kwargs)
|
||||
self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
|
||||
self_cond = self_cond.detach()
|
||||
|
||||
# forward to get model prediction
|
||||
|
||||
model_output = unet(
|
||||
unet_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
**unet_kwargs,
|
||||
@@ -2872,10 +2935,7 @@ class Decoder(nn.Module):
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = model_output.chunk(2, dim = 1)
|
||||
else:
|
||||
pred = model_output
|
||||
pred, _ = self.parse_unet_output(learned_variance, unet_output)
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
@@ -2898,7 +2958,7 @@ class Decoder(nn.Module):
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from collections.abc import Iterable
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
group_wd_params = True,
|
||||
warmup_steps = 1,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
**self.optim_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
if exists(cosine_decay_max_steps):
|
||||
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
|
||||
else:
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||
|
||||
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||
save_obj = dict(
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
scheduler = self.scheduler.state_dict(),
|
||||
warmup_scheduler = self.warmup_scheduler,
|
||||
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||
version = version.parse(__version__),
|
||||
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
self.scheduler.load_state_dict(loaded_obj['scheduler'])
|
||||
|
||||
# set warmupstep
|
||||
if exists(self.warmup_scheduler):
|
||||
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||
if not self.accelerator.optimizer_step_was_skipped:
|
||||
with self.warmup_scheduler.dampening():
|
||||
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
|
||||
with sched_context():
|
||||
self.scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
|
||||
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
if exists(unet_cosine_decay_max_steps):
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
scheduler_key = f'sched{ind}'
|
||||
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
state_dict = optimizer.state_dict() if optimizer is not None else None
|
||||
save_obj = {**save_obj, optimizer_key: state_dict}
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
|
||||
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
|
||||
|
||||
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
if optimizer is not None:
|
||||
|
||||
if exists(optimizer):
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.6.0'
|
||||
__version__ = '1.8.2'
|
||||
|
||||
Reference in New Issue
Block a user