@@ -5,7 +5,6 @@ from functools import partial
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
-import time
 
 import torch
 import torch.nn.functional as F
@@ -304,7 +303,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
     """
     steps = timesteps + 1
-    x = torch.linspace(0, timesteps, steps)
+    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
@@ -315,21 +314,21 @@ def linear_beta_schedule(timesteps):
     scale = 1000 / timesteps
     beta_start = scale * 0.0001
     beta_end = scale * 0.02
-    return torch.linspace(beta_start, beta_end, timesteps)
+    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
 
 def quadratic_beta_schedule(timesteps):
     scale = 1000 / timesteps
     beta_start = scale * 0.0001
     beta_end = scale * 0.02
-    return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
+    return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
 
 def sigmoid_beta_schedule(timesteps):
     scale = 1000 / timesteps
     beta_start = scale * 0.0001
     beta_end = scale * 0.02
-    betas = torch.linspace(-6, 6, timesteps)
+    betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
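For context, a minimal sketch (not part of the patch) of why the schedules are now built in float64: over a thousand-step chain the cumulative product of (1 - beta) accumulates rounding error in float32, so the derived quantities are computed in double precision and only cast down when registered as buffers below.

import torch

# same linear schedule as in the patch, built at two precisions for comparison
def linear_beta_schedule(timesteps, dtype):
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype = dtype)

for dtype in (torch.float32, torch.float64):
    betas = linear_beta_schedule(1000, dtype)
    alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
    # the float32 tail drifts slightly from the float64 reference after ~1000 multiplications
    print(dtype, alphas_cumprod[-1].item())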
@@ -369,17 +368,21 @@ class BaseGaussianDiffusion(nn.Module):
         self.loss_type = loss_type
         self.loss_fn = loss_fn
 
-        self.register_buffer('betas', betas)
-        self.register_buffer('alphas_cumprod', alphas_cumprod)
-        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+        # register buffer helper function to cast double back to float
+
+        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+        register_buffer('betas', betas)
+        register_buffer('alphas_cumprod', alphas_cumprod)
+        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
 
         # calculations for diffusion q(x_t | x_{t-1}) and others
 
-        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
-        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
-        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
 
         # calculations for posterior q(x_{t-1} | x_t, x_0)
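A standalone sketch of the buffer-registration pattern introduced above (the module name here is illustrative): constants are computed in float64, then cast to float32 as they are registered, so they follow .to(device) and land in the state dict without keeping double-precision tensors around.

import torch
from torch import nn

class SchedulePrecomp(nn.Module):
    def __init__(self, betas):
        super().__init__()
        # cast each precomputed constant down to float32 at registration time
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        alphas_cumprod = torch.cumprod(1. - betas, dim = 0)   # still float64 here
        register_buffer('betas', betas)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))

m = SchedulePrecomp(torch.linspace(1e-4, 2e-2, 1000, dtype = torch.float64))
print(m.sqrt_alphas_cumprod.dtype)      # torch.float32
print('betas' in m.state_dict())        # True: buffers are saved and moved with the module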
@@ -387,13 +390,13 @@ class BaseGaussianDiffusion(nn.Module):
         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
 
-        self.register_buffer('posterior_variance', posterior_variance)
+        register_buffer('posterior_variance', posterior_variance)
 
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
 
-        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
-        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
-        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
     def q_mean_variance(self, x_start, t):
         mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
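For reference, a small self-contained sketch of how the posterior coefficients registered above are consumed when computing q(x_{t-1} | x_t, x_0); the extract helper mirrors the per-timestep gather used elsewhere in this file, and the function names below are illustrative.

import torch
import torch.nn.functional as F

def extract(a, t, x_shape):
    # gather a[t] per batch element and reshape so it broadcasts over the image dims
    out = a.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

def q_posterior_mean(coef1, coef2, x_start, x_t, t):
    # mean of q(x_{t-1} | x_t, x_0) = coef1[t] * x_0 + coef2[t] * x_t
    return extract(coef1, t, x_t.shape) * x_start + extract(coef2, t, x_t.shape) * x_t

T = 1000
betas = torch.linspace(1e-4, 2e-2, T, dtype = torch.float64)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

coef1 = (betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)).float()
coef2 = ((1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)).float()

x_start, x_t = torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8)
t = torch.tensor([10, 500])
print(q_posterior_mean(coef1, coef2, x_start, x_t, t).shape)   # torch.Size([2, 3, 8, 8])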
@@ -1229,6 +1232,33 @@ class LinearAttention(nn.Module):
         out = self.nonlin(out)
         return self.to_out(out)
 
+class CrossEmbedLayer(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        kernel_sizes,
+        dim_out = None,
+        stride = 2
+    ):
+        super().__init__()
+        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
+        dim_out = default(dim_out, dim_in)
+
+        kernel_sizes = sorted(kernel_sizes)
+        num_scales = len(kernel_sizes)
+
+        # calculate the dimension at each scale
+        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
+        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
+
+        self.convs = nn.ModuleList([])
+        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
+            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
+
+    def forward(self, x):
+        fmaps = tuple(map(lambda conv: conv(x), self.convs))
+        return torch.cat(fmaps, dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
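A quick sanity check of the CrossEmbedLayer geometry, written standalone so it does not depend on the class above: the per-scale channel splits sum back to dim_out, and the (kernel - stride) // 2 padding (valid because every kernel shares the stride's parity, per the assert) keeps all branches at the same spatial size so the concatenation lines up.

import torch
from torch import nn

dim_in, dim_out, stride = 64, 128, 2
kernel_sizes = sorted((2, 4, 8))   # all even, matching the even stride

# halve the channel budget at each successive scale, give the remainder to the last
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, len(kernel_sizes))]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
assert sum(dim_scales) == dim_out   # [64, 32, 32]

convs = [nn.Conv2d(dim_in, d, k, stride = stride, padding = (k - stride) // 2)
         for k, d in zip(kernel_sizes, dim_scales)]

x = torch.randn(1, dim_in, 32, 32)
out = torch.cat([conv(x) for conv in convs], dim = 1)
print(out.shape)   # torch.Size([1, 128, 16, 16]): every branch lands on the same grid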
@@ -1252,8 +1282,10 @@ class Unet(nn.Module):
         cond_on_image_embeds = False,
         init_dim = None,
-        init_conv_kernel_size = 7,
         block_type = 'resnet',
-        block_resnet_groups = 8,
+        resnet_groups = 8,
+        init_cross_embed_kernel_sizes = (3, 7, 15),
+        cross_embed_downsample = False,
+        cross_embed_downsample_kernel_sizes = (2, 4),
         **kwargs
     ):
         super().__init__()
@@ -1272,10 +1304,9 @@ class Unet(nn.Module):
         self.channels = channels
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
-        init_dim = default(init_dim, dim // 2)
+        init_dim = default(init_dim, dim // 3 * 2)
 
-        assert (init_conv_kernel_size % 2) == 1
-        self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
+        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
 
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
@@ -1331,7 +1362,15 @@ class Unet(nn.Module):
 
         # resnet block klass
 
-        block_klass = partial(ResnetBlock, groups = block_resnet_groups)
+        resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        assert len(resnet_groups) == len(in_out)
+
+        # downsample klass
+
+        downsample_klass = Downsample
+
+        if cross_embed_downsample:
+            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
 
         # layers
@@ -1339,38 +1378,39 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        for ind, (dim_in, dim_out) in enumerate(in_out):
+        for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
 
             self.downs.append(nn.ModuleList([
-                block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
+                ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
-                Downsample(dim_out) if not is_last else nn.Identity()
+                ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
-        self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+        for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
-                block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Upsample(dim_in)
             ]))
 
         out_dim = default(out_dim, channels)
 
         self.final_conv = nn.Sequential(
-            block_klass(dim, dim),
+            ResnetBlock(dim, dim, groups = resnet_groups[0]),
             nn.Conv2d(dim, out_dim, 1)
         )
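The per-resolution groups rely on the cast_tuple helper defined earlier in the file; its behaviour is roughly the sketch below, so resnet_groups = 8 keeps the old uniform behaviour while a tuple such as (8, 8, 16, 32) assigns a group count to each resolution.

def cast_tuple(val, length = 1):
    # broadcast a scalar to a tuple of the requested length; tuples pass through untouched
    return val if isinstance(val, tuple) else ((val,) * length)

in_out = [(64, 64), (64, 128), (128, 256), (256, 512)]     # illustrative dims per resolution
print(cast_tuple(8, len(in_out)))                          # (8, 8, 8, 8)
print(cast_tuple((8, 8, 16, 32), len(in_out)))             # (8, 8, 16, 32)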
@@ -1381,12 +1421,13 @@ class Unet(nn.Module):
         *,
         lowres_cond,
         channels,
-        cond_on_image_embeds
+        cond_on_image_embeds,
+        cond_on_text_encodings
     ):
-        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
+        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
             return self
 
-        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
+        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
         return self.__class__(**{**self._locals, **updated_kwargs})
 
     def forward_with_cond_scale(
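cast_model_parameters rebuilds the unet from the constructor arguments captured in self._locals, overriding only the fields the Decoder needs to change; a generic sketch of that pattern, with illustrative names:

from torch import nn

class ConfigurableNet(nn.Module):
    # illustrative stand-in: the real Unet snapshots its constructor kwargs into self._locals
    def __init__(self, dim = 64, lowres_cond = False, cond_on_text_encodings = False):
        super().__init__()
        self._locals = dict(dim = dim, lowres_cond = lowres_cond, cond_on_text_encodings = cond_on_text_encodings)
        self.lowres_cond = lowres_cond
        self.cond_on_text_encodings = cond_on_text_encodings

    def cast_model_parameters(self, *, lowres_cond, cond_on_text_encodings):
        if lowres_cond == self.lowres_cond and cond_on_text_encodings == self.cond_on_text_encodings:
            return self     # nothing changed, reuse the existing module
        updated = {**self._locals, 'lowres_cond': lowres_cond, 'cond_on_text_encodings': cond_on_text_encodings}
        return self.__class__(**updated)

net = ConfigurableNet(cond_on_text_encodings = True)
net2 = net.cast_model_parameters(lowres_cond = True, cond_on_text_encodings = False)
print(net2 is net, net2.cond_on_text_encodings)   # False False: a fresh copy with the overrides applied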
@@ -1582,7 +1623,8 @@ class Decoder(BaseGaussianDiffusion):
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
         clip_x_start = True,
-        clip_adapter_overrides = dict()
+        clip_adapter_overrides = dict(),
+        unconditional = False
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1590,6 +1632,9 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
+        self.unconditional = unconditional
+        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+
         assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
@@ -1631,7 +1676,8 @@ class Decoder(BaseGaussianDiffusion):
 
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
-                cond_on_image_embeds = is_first,
+                cond_on_image_embeds = is_first and not unconditional,
+                cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
                 channels = unet_channels
             )
@@ -1766,12 +1812,16 @@ class Decoder(BaseGaussianDiffusion):
     @eval_decorator
     def sample(
         self,
-        image_embed,
+        image_embed = None,
         text = None,
+        batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
     ):
-        batch_size = image_embed.shape[0]
+        assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
+
+        if not self.unconditional:
+            batch_size = image_embed.shape[0]
 
         text_encodings = text_mask = None
         if exists(text):
@@ -1781,10 +1831,11 @@ class Decoder(BaseGaussianDiffusion):
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
 
         img = None
+        is_cuda = next(self.parameters()).is_cuda
 
         for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
-            context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
             with context:
                 lowres_cond_img = None
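The last change keys the "move one unet to the GPU at a time" context off the model's own parameters instead of image_embed, which may now be None for an unconditional decoder; the device check itself is just the following.

import torch
from torch import nn

net = nn.Linear(4, 4)
is_cuda = next(net.parameters()).is_cuda   # device of the module, independent of any conditioning input
print(is_cuda)                             # False on CPU; True after net.cuda() on a GPU machine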