|
|
|
|
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|
|
|
|
from torch import nn, einsum
|
|
|
|
|
import torchvision.transforms as T
|
|
|
|
|
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
|
from einops import rearrange, repeat, reduce
|
|
|
|
|
from einops.layers.torch import Rearrange
|
|
|
|
|
from einops_exts import rearrange_many, repeat_many, check_shape
|
|
|
|
|
from einops_exts.torch import EinopsToAndFrom
|
|
|
|
|
@@ -379,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseGaussianDiffusion(nn.Module):
|
|
|
|
|
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
|
|
|
|
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
if beta_schedule == "cosine":
|
|
|
|
|
@@ -444,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
|
|
|
|
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
|
|
|
|
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
|
|
|
|
|
|
|
|
|
# p2 loss reweighting
|
|
|
|
|
|
|
|
|
|
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
|
|
|
|
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
|
|
|
|
|
|
|
|
|
def q_posterior(self, x_start, x_t, t):
|
|
|
|
|
posterior_mean = (
|
|
|
|
|
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
|
|
|
|
@@ -1079,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|
|
|
|
def Upsample(dim):
|
|
|
|
|
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
|
|
|
|
|
|
|
|
|
def Downsample(dim):
|
|
|
|
|
return nn.Conv2d(dim, dim, 4, 2, 1)
|
|
|
|
|
def Downsample(dim, *, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
|
|
|
|
|
|
|
|
|
class SinusoidalPosEmb(nn.Module):
|
|
|
|
|
def __init__(self, dim):
|
|
|
|
|
@@ -1346,6 +1352,7 @@ class Unet(nn.Module):
|
|
|
|
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
|
|
|
|
cross_embed_downsample = False,
|
|
|
|
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
|
|
|
|
memory_efficient = False,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
@@ -1365,7 +1372,7 @@ class Unet(nn.Module):
|
|
|
|
|
self.channels_out = default(channels_out, channels)
|
|
|
|
|
|
|
|
|
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
|
|
|
|
init_dim = default(init_dim, dim // 3 * 2)
|
|
|
|
|
init_dim = default(init_dim, dim)
|
|
|
|
|
|
|
|
|
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
|
|
|
|
|
|
|
|
|
@@ -1456,10 +1463,11 @@ class Unet(nn.Module):
|
|
|
|
|
layer_cond_dim = cond_dim if not is_first else None
|
|
|
|
|
|
|
|
|
|
self.downs.append(nn.ModuleList([
|
|
|
|
|
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
|
|
|
|
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
downsample_klass(dim_out) if not is_last else nn.Identity()
|
|
|
|
|
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
mid_dim = dims[-1]
|
|
|
|
|
@@ -1468,19 +1476,19 @@ class Unet(nn.Module):
|
|
|
|
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
|
|
|
|
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
|
|
|
|
|
|
|
|
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
|
|
|
|
is_last = ind >= (num_resolutions - 2)
|
|
|
|
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
|
|
|
|
is_last = ind >= (len(in_out) - 1)
|
|
|
|
|
layer_cond_dim = cond_dim if not is_last else None
|
|
|
|
|
|
|
|
|
|
self.ups.append(nn.ModuleList([
|
|
|
|
|
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
Upsample(dim_in)
|
|
|
|
|
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
self.final_conv = nn.Sequential(
|
|
|
|
|
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
|
|
|
|
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
|
|
|
|
nn.Conv2d(dim, self.channels_out, 1)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -1552,6 +1560,7 @@ class Unet(nn.Module):
|
|
|
|
|
# initial convolution
|
|
|
|
|
|
|
|
|
|
x = self.init_conv(x)
|
|
|
|
|
r = x.clone() # final residual
|
|
|
|
|
|
|
|
|
|
# time conditioning
|
|
|
|
|
|
|
|
|
|
@@ -1649,7 +1658,10 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
hiddens = []
|
|
|
|
|
|
|
|
|
|
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
|
|
|
|
|
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
|
|
|
|
if exists(pre_downsample):
|
|
|
|
|
x = pre_downsample(x)
|
|
|
|
|
|
|
|
|
|
x = init_block(x, c, t)
|
|
|
|
|
x = sparse_attn(x)
|
|
|
|
|
|
|
|
|
|
@@ -1657,7 +1669,9 @@ class Unet(nn.Module):
|
|
|
|
|
x = resnet_block(x, c, t)
|
|
|
|
|
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
x = downsample(x)
|
|
|
|
|
|
|
|
|
|
if exists(post_downsample):
|
|
|
|
|
x = post_downsample(x)
|
|
|
|
|
|
|
|
|
|
x = self.mid_block1(x, mid_c, t)
|
|
|
|
|
|
|
|
|
|
@@ -1667,7 +1681,7 @@ class Unet(nn.Module):
|
|
|
|
|
x = self.mid_block2(x, mid_c, t)
|
|
|
|
|
|
|
|
|
|
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
|
|
|
|
x = torch.cat((x, hiddens.pop()), dim=1)
|
|
|
|
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
|
|
|
|
x = init_block(x, c, t)
|
|
|
|
|
x = sparse_attn(x)
|
|
|
|
|
|
|
|
|
|
@@ -1676,6 +1690,7 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
x = upsample(x)
|
|
|
|
|
|
|
|
|
|
x = torch.cat((x, r), dim = 1)
|
|
|
|
|
return self.final_conv(x)
|
|
|
|
|
|
|
|
|
|
class LowresConditioner(nn.Module):
|
|
|
|
|
@@ -1755,12 +1770,16 @@ class Decoder(BaseGaussianDiffusion):
|
|
|
|
|
unconditional = False,
|
|
|
|
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
|
|
|
|
use_dynamic_thres = False, # from the Imagen paper
|
|
|
|
|
dynamic_thres_percentile = 0.9
|
|
|
|
|
dynamic_thres_percentile = 0.9,
|
|
|
|
|
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
|
|
|
|
p2_loss_weight_k = 1
|
|
|
|
|
):
|
|
|
|
|
super().__init__(
|
|
|
|
|
beta_schedule = beta_schedule,
|
|
|
|
|
timesteps = timesteps,
|
|
|
|
|
loss_type = loss_type
|
|
|
|
|
loss_type = loss_type,
|
|
|
|
|
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
|
|
|
|
p2_loss_weight_k = p2_loss_weight_k
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.unconditional = unconditional
|
|
|
|
|
@@ -2028,7 +2047,13 @@ class Decoder(BaseGaussianDiffusion):
|
|
|
|
|
|
|
|
|
|
target = noise if not predict_x_start else x_start
|
|
|
|
|
|
|
|
|
|
loss = self.loss_fn(pred, target)
|
|
|
|
|
loss = self.loss_fn(pred, target, reduction = 'none')
|
|
|
|
|
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
|
|
|
|
|
|
|
|
|
if self.has_p2_loss_reweighting:
|
|
|
|
|
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
|
|
|
|
|
|
|
|
|
|
loss = loss.mean()
|
|
|
|
|
|
|
|
|
|
if not learned_variance:
|
|
|
|
|
# return simple loss if not using learned variance
|
|
|
|
|
|