Compare commits


18 Commits

Author SHA1 Message Date
Phil Wang
6651eafa93 one more residual, after seeing good results on unconditional generation locally 2022-06-16 11:18:02 -07:00
Phil Wang
e6bb75e5ab fix missing residual for highest resolution of the unet 2022-06-15 20:09:43 -07:00
Giorgos Zachariadis
b4c3e5b854 changed str in order to avoid confusions and collisions with Python (#147) 2022-06-15 13:41:16 -07:00
Phil Wang
b7f9607258 make memory efficient unet design from imagen toggle-able 2022-06-15 13:40:26 -07:00
Phil Wang
2219348a6e adopt similar unet architecture as imagen 2022-06-15 12:18:21 -07:00
Phil Wang
9eea9b9862 add p2 loss reweighting for decoder training as an option 2022-06-14 10:58:57 -07:00
Phil Wang
5d958713c0 fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug 2022-06-13 21:01:50 -07:00
Phil Wang
0f31980362 cleanup 2022-06-07 17:31:38 -07:00
Phil Wang
bee5bf3815 fix for https://github.com/lucidrains/DALLE2-pytorch/issues/143 2022-06-07 09:03:48 -07:00
Phil Wang
350a3d6045 0.6.16 2022-06-06 08:45:46 -07:00
Kashif Rasul
1a81670718 fix quadratic_beta_schedule (#141) 2022-06-06 08:45:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
Phil Wang
ce4b0107c1 0.6.13 2022-06-04 13:26:57 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
Phil Wang
22cc613278 ema fix from @nousr 2022-06-03 19:44:36 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
708809ed6c lower beta2 for adam down to 0.99, based on https://openreview.net/forum?id=2LdBqxc1Yv 2022-06-03 10:26:28 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
7 changed files with 150 additions and 63 deletions

View File

@@ -1207,4 +1207,14 @@ This library would not have gotten to this working state without the help of
 }
 ```
 
+```bibtex
+@article{Choi2022PerceptionPT,
+    title = {Perception Prioritized Training of Diffusion Models},
+    author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2204.00227}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,7 +1,6 @@
 import math
 import random
 from tqdm import tqdm
-from inspect import isfunction
 from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
@@ -12,7 +11,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 import torchvision.transforms as T
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
@@ -57,7 +56,7 @@ def maybe(fn):
 def default(val, d):
     if exists(val):
         return val
-    return d() if isfunction(d) else d
+    return d() if callable(d) else d
 
 def cast_tuple(val, length = 1):
     if isinstance(val, list):
@@ -314,11 +313,6 @@ def extract(a, t, x_shape):
     out = a.gather(-1, t)
     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
 
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
-
 def meanflat(x):
     return x.mean(dim = tuple(range(1, len(x.shape))))
@@ -373,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
     scale = 1000 / timesteps
     beta_start = scale * 0.0001
     beta_end = scale * 0.02
-    return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
+    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
 
 def sigmoid_beta_schedule(timesteps):
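
For reference, the corrected schedule (#141) interpolates linearly between the square roots of the endpoints and then squares the result, so the betas run from beta_start to beta_end along a quadratic curve; the old line squared already-squared endpoints. A minimal standalone sketch of the fixed function:

```python
import torch

def quadratic_beta_schedule(timesteps):
    # interpolate in sqrt-beta space, then square: endpoints land exactly
    # on beta_start and beta_end, with quadratic growth in between
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps, dtype = torch.float64) ** 2

betas = quadratic_beta_schedule(1000)
assert abs(betas[0].item() - 0.0001) < 1e-12 and abs(betas[-1].item() - 0.02) < 1e-12
```
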
@@ -385,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
 class BaseGaussianDiffusion(nn.Module):
-    def __init__(self, *, beta_schedule, timesteps, loss_type):
+    def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
         super().__init__()
 
         if beta_schedule == "cosine":
@@ -450,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+        # p2 loss reweighting
+
+        self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
+        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -946,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
-        noise = noise_like(x.shape, device, repeat_noise)
+        noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
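
The `nonzero_mask` term is what keeps the final denoising step deterministic: noise is added at every timestep except t == 0, where the predicted mean is returned as-is. A toy illustration of the masking (shapes here are made up for the example):

```python
import torch

b = 4
t = torch.tensor([3, 2, 1, 0])   # per-sample timestep; the last sample is at t == 0
x = torch.randn(b, 3, 8, 8)      # illustrative latent shape only

noise = torch.randn_like(x)
# broadcast (1 - (t == 0)) to x's rank: 1 keeps the noise, 0 silences it
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

assert torch.all(nonzero_mask[-1] == 0)   # the t == 0 sample receives no noise
```
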
@@ -1085,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
 def Upsample(dim):
     return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
 
-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
+def Downsample(dim, *, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.Conv2d(dim, dim_out, 4, 2, 1)
 
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
@@ -1352,6 +1352,7 @@ class Unet(nn.Module):
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
+        memory_efficient = False,
         **kwargs
     ):
         super().__init__()
@@ -1371,7 +1372,7 @@ class Unet(nn.Module):
         self.channels_out = default(channels_out, channels)
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
-        init_dim = default(init_dim, dim // 3 * 2)
+        init_dim = default(init_dim, dim)
 
         self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
@@ -1428,6 +1429,7 @@ class Unet(nn.Module):
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
 
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
@@ -1461,10 +1463,11 @@ class Unet(nn.Module):
             layer_cond_dim = cond_dim if not is_first else None
 
             self.downs.append(nn.ModuleList([
-                ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
+                downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
+                ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                downsample_klass(dim_out) if not is_last else nn.Identity()
+                downsample_klass(dim_out) if not is_last and not memory_efficient else None
             ]))
 
         mid_dim = dims[-1]
@@ -1473,19 +1476,19 @@ class Unet(nn.Module):
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
-        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
-            is_last = ind >= (num_resolutions - 2)
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
+            is_last = ind >= (len(in_out) - 1)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                Upsample(dim_in)
+                Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
         self.final_conv = nn.Sequential(
-            ResnetBlock(dim, dim, groups = resnet_groups[0]),
+            ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
             nn.Conv2d(dim, self.channels_out, 1)
         )
@@ -1557,6 +1560,7 @@ class Unet(nn.Module):
         # initial convolution
 
         x = self.init_conv(x)
+        r = x.clone() # final residual
 
         # time conditioning
@@ -1565,19 +1569,28 @@ class Unet(nn.Module):
         time_tokens = self.to_time_tokens(time_hiddens)
         t = self.to_time_cond(time_hiddens)
 
-        # image embedding to be summed to time embedding
-        # discovered by @mhh0318 in the paper
-
-        if exists(image_embed) and exists(self.to_image_hiddens):
-            image_hiddens = self.to_image_hiddens(image_embed)
-            t = t + image_hiddens
-
         # conditional dropout
 
         image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
         text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
 
-        image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        # image embedding to be summed to time embedding
+        # discovered by @mhh0318 in the paper
+
+        if exists(image_embed) and exists(self.to_image_hiddens):
+            image_hiddens = self.to_image_hiddens(image_embed)
+            image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
+            null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
+
+            image_hiddens = torch.where(
+                image_keep_mask_hidden,
+                image_hiddens,
+                null_image_hiddens
+            )
+
+            t = t + image_hiddens
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
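
This is the classifier free guidance fix from commit 5d958713c0: the image hiddens summed into the time conditioning are now dropped out per sample by swapping in the learned `null_image_hiddens`, mirroring what was already done for the image tokens. A reduced sketch of the masking pattern, with hypothetical sizes standing in for the real batch and `time_cond_dim`:

```python
import torch
from einops import rearrange

batch, time_cond_dim = 4, 16                         # hypothetical sizes
image_hiddens = torch.randn(batch, time_cond_dim)
null_image_hiddens = torch.randn(1, time_cond_dim)   # stands in for the learned nn.Parameter

# equivalent of prob_mask_like: keep each sample's conditioning with probability 0.9
image_keep_mask = torch.zeros(batch).float().uniform_(0, 1) < 0.9

# broadcast the per-sample bool to (b, 1) so torch.where swaps whole rows
mask = rearrange(image_keep_mask, 'b -> b 1')
image_hiddens = torch.where(mask, image_hiddens, null_image_hiddens)
```
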
@@ -1585,11 +1598,12 @@ class Unet(nn.Module):
         image_tokens = None
 
         if self.cond_on_image_embeds:
+            image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
             image_tokens = self.image_to_tokens(image_embed)
             null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
 
             image_tokens = torch.where(
-                image_keep_mask,
+                image_keep_mask_embed,
                 image_tokens,
                 null_image_embed
             )
@@ -1644,7 +1658,10 @@ class Unet(nn.Module):
         hiddens = []
 
-        for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
+        for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
+            if exists(pre_downsample):
+                x = pre_downsample(x)
+
             x = init_block(x, c, t)
             x = sparse_attn(x)
@@ -1652,7 +1669,9 @@ class Unet(nn.Module):
                 x = resnet_block(x, c, t)
 
             hiddens.append(x)
-            x = downsample(x)
+
+            if exists(post_downsample):
+                x = post_downsample(x)
 
         x = self.mid_block1(x, mid_c, t)
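
Taken together with the `__init__` change above, each down stage now carries either a `pre_downsample` (when `memory_efficient = True`, so the resnet blocks run at the already-reduced resolution, as in Imagen's efficient unet) or a `post_downsample` (the classic ordering). A runnable schematic of the two orderings, where `blocks` and `downsample` are simple placeholders for the real conditioned modules:

```python
import torch
from torch import nn

def down_stage(x, blocks, downsample, memory_efficient):
    if memory_efficient:
        x = downsample(x)        # pre-downsample: blocks see the half-resolution map
    for block in blocks:
        x = block(x)
    if not memory_efficient:
        x = downsample(x)        # classic ordering: downsample after the blocks
    return x

x = torch.randn(1, 8, 32, 32)
blocks = [nn.Conv2d(8, 8, 3, padding = 1)]
downsample = nn.Conv2d(8, 8, 4, 2, 1)   # same stride-2 conv shape as Downsample above

print(down_stage(x, blocks, downsample, memory_efficient = True).shape)    # blocks ran at 16x16
print(down_stage(x, blocks, downsample, memory_efficient = False).shape)   # blocks ran at 32x32
```
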
@@ -1662,7 +1681,7 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c, t)
 
         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
-            x = torch.cat((x, hiddens.pop()), dim=1)
+            x = torch.cat((x, hiddens.pop()), dim = 1)
             x = init_block(x, c, t)
             x = sparse_attn(x)
@@ -1671,6 +1690,7 @@ class Unet(nn.Module):
             x = upsample(x)
 
+        x = torch.cat((x, r), dim = 1)
         return self.final_conv(x)
 
 class LowresConditioner(nn.Module):
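
Commits e6bb75e5ab and 6651eafa93 add this outermost skip connection: the output of the initial convolution is saved as `r` and concatenated back just before `final_conv`, which is why the final `ResnetBlock` now takes `dim * 2` input channels. A toy sketch of the pattern with placeholder modules:

```python
import torch
from torch import nn

dim, channels = 8, 3
init_conv = nn.Conv2d(channels, dim, 3, padding = 1)   # stand-in for the CrossEmbedLayer
unet_body = nn.Identity()                              # placeholder for the down / mid / up path
final_conv = nn.Sequential(                            # analogue of ResnetBlock(dim * 2, dim) + 1x1 conv
    nn.Conv2d(dim * 2, dim, 3, padding = 1),
    nn.Conv2d(dim, channels, 1)
)

x = init_conv(torch.randn(1, channels, 32, 32))
r = x.clone()                      # residual saved right after the initial convolution
x = unet_body(x)
x = torch.cat((x, r), dim = 1)     # rejoined before final_conv, hence the dim * 2 input
out = final_conv(x)                # (1, 3, 32, 32)
```
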
@@ -1750,12 +1770,16 @@ class Decoder(BaseGaussianDiffusion):
         unconditional = False,
         auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
         use_dynamic_thres = False, # from the Imagen paper
-        dynamic_thres_percentile = 0.9
+        dynamic_thres_percentile = 0.9,
+        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
+        p2_loss_weight_k = 1
     ):
         super().__init__(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
-            loss_type = loss_type
+            loss_type = loss_type,
+            p2_loss_weight_gamma = p2_loss_weight_gamma,
+            p2_loss_weight_k = p2_loss_weight_k
         )
 
         self.unconditional = unconditional
@@ -1956,10 +1980,10 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
-        noise = noise_like(x.shape, device, repeat_noise)
+        noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -2023,7 +2047,13 @@ class Decoder(BaseGaussianDiffusion):
         target = noise if not predict_x_start else x_start
 
-        loss = self.loss_fn(pred, target)
+        loss = self.loss_fn(pred, target, reduction = 'none')
+        loss = reduce(loss, 'b ... -> b (...)', 'mean')
+
+        if self.has_p2_loss_reweighting:
+            loss = loss * extract(self.p2_loss_weight, times, loss.shape)
+
+        loss = loss.mean()
 
         if not learned_variance:
             # return simple loss if not using learned variance
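
End to end, the p2 reweighting scales each sample's loss by `(k + SNR_t) ** -gamma`, where `SNR_t = alphas_cumprod / (1 - alphas_cumprod)`, down-weighting the easy low-noise timesteps. A minimal sketch of the weight registered in `BaseGaussianDiffusion.__init__`, using a linear schedule for illustration and the recommended `gamma = 1., k = 1`:

```python
import torch

timesteps = 1000
betas = torch.linspace(0.0001, 0.02, timesteps, dtype = torch.float64)   # linear schedule, for illustration
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

p2_loss_weight_gamma, p2_loss_weight_k = 1., 1.
snr = alphas_cumprod / (1. - alphas_cumprod)
p2_loss_weight = (p2_loss_weight_k + snr) ** -p2_loss_weight_gamma

# early (low-noise) timesteps are down-weighted, late ones approach weight 1
assert p2_loss_weight[0] < p2_loss_weight[-1]
```
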

View File

@@ -11,7 +11,7 @@ def get_optimizer(
     params,
     lr = 1e-4,
     wd = 1e-2,
-    betas = (0.9, 0.999),
+    betas = (0.9, 0.99),
     eps = 1e-8,
     filter_by_requires_grad = False,
     group_wd_params = True,

View File

@@ -58,8 +58,15 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr
 
-def get_pkg_version():
-    return __version__
+def clamp(value, min_value = None, max_value = None):
+    assert exists(min_value) or exists(max_value)
+
+    if exists(min_value):
+        value = max(value, min_value)
+
+    if exists(max_value):
+        value = min(value, max_value)
+
+    return value
 
 # decorators
@@ -175,12 +182,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
 # exponential moving average wrapper
 
 class EMA(nn.Module):
+    """
+    Implements exponential moving average shadowing for your model.
+    Utilizes an inverse decay schedule to manage longer term training runs.
+    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
+    @crowsonkb's notes on EMA Warmup:
+        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
+        good values for models you plan to train for a million or more steps (reaches decay
+        factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
+        you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+        215.4k steps).
+    Args:
+        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+        power (float): Exponential factor of EMA warmup. Default: 1.
+        min_value (float): The minimum EMA decay rate. Default: 0.
+    """
     def __init__(
         self,
         model,
-        beta = 0.99,
-        update_after_step = 1000,
+        beta = 0.9999,
+        update_after_step = 10000,
         update_every = 10,
+        inv_gamma = 1.0,
+        power = 2 / 3,
+        min_value = 0.0,
     ):
         super().__init__()
         self.beta = beta
@@ -188,7 +217,11 @@ class EMA(nn.Module):
         self.ema_model = copy.deepcopy(model)
         self.update_every = update_every
-        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
+        self.update_after_step = update_after_step
+
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
 
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0]))
@@ -198,37 +231,51 @@ class EMA(nn.Module):
         self.ema_model.to(device)
 
     def copy_params_from_model_to_ema(self):
-        self.ema_model.state_dict(self.online_model.state_dict())
+        for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
+            ma_param.data.copy_(current_param.data)
+
+        for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
+            ma_buffer.data.copy_(current_buffer.data)
+
+    def get_current_decay(self):
+        epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
+        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
+
+        if epoch <= 0:
+            return 0.
+
+        return clamp(value, min_value = self.min_value, max_value = self.beta)
 
     def update(self):
+        step = self.step.item()
         self.step += 1
 
-        if (self.step % self.update_every) != 0:
+        if (step % self.update_every) != 0:
             return
 
-        if self.step <= self.update_after_step:
+        if step <= self.update_after_step:
             self.copy_params_from_model_to_ema()
             return
 
-        if not self.initted:
+        if not self.initted.item():
             self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.Tensor([True]))
 
         self.update_moving_average(self.ema_model, self.online_model)
 
+    @torch.no_grad()
     def update_moving_average(self, ma_model, current_model):
-        def calculate_ema(beta, old, new):
-            if not exists(old):
-                return new
-            return old * beta + (1 - beta) * new
+        current_decay = self.get_current_decay()
 
-        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
-            old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
+        for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
+            difference = ma_params.data - current_params.data
+            difference.mul_(1.0 - current_decay)
+            ma_params.sub_(difference)
 
-        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
-            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
-            ma_buffer.copy_(new_buffer_value)
+        for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
+            difference = ma_buffer - current_buffer
+            difference.mul_(1.0 - current_decay)
+            ma_buffer.sub_(difference)
 
     def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)
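
The warmup schedule in `get_current_decay` is `1 - (1 + step / inv_gamma) ** -power`, clamped between `min_value` and `beta`. A quick standalone check against the numbers quoted in the docstring (gamma = 1, power = 2/3):

```python
def ema_decay(step, inv_gamma = 1.0, power = 2 / 3, min_value = 0.0, beta = 0.9999):
    if step <= 0:
        return 0.
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max(value, min_value), beta)

print(ema_decay(31_600))      # ~0.999, matching @crowsonkb's note
print(ema_decay(1_000_000))   # ~0.9999, capped at beta
```
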
@@ -488,7 +535,7 @@ class DecoderTrainer(nn.Module):
         loaded_obj = torch.load(str(path))
 
         if version.parse(__version__) != loaded_obj['version']:
-            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
 
         self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

View File

@@ -1 +1 @@
-__version__ = '0.6.9'
+__version__ = '0.9.2'

View File

@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
         return_val[ind][key] = d[key]
     return (*return_val,)
 
-def string_begins_with(prefix, str):
-    return str.startswith(prefix)
+def string_begins_with(prefix, string_input):
+    return string_input.startswith(prefix)
 
 def group_by_key_prefix(prefix, d):
     return group_dict_by_key(partial(string_begins_with, prefix), d)

View File

@@ -211,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
     Loads the model with an appropriate method depending on the tracker
     """
     print(print_ribbon(f"Loading model from {recall_source}"))
-    state_dict = tracker.recall_state_dict(recall_source, **load_config)
+    state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
     trainer.load_state_dict(state_dict["trainer"])
     print("Model loaded")
     return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]