|
|
|
|
@@ -45,6 +45,11 @@ def exists(val):
|
|
|
|
|
def identity(t, *args, **kwargs):
|
|
|
|
|
return t
|
|
|
|
|
|
|
|
|
|
def first(arr, d = None):
|
|
|
|
|
if len(arr) == 0:
|
|
|
|
|
return d
|
|
|
|
|
return arr[0]
|
|
|
|
|
|
|
|
|
|
def maybe(fn):
|
|
|
|
|
@wraps(fn)
|
|
|
|
|
def inner(x):
|
|
|
|
|
@@ -351,7 +356,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
|
|
|
|
steps = timesteps + 1
|
|
|
|
|
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
|
|
|
|
|
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
|
|
|
|
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
|
|
|
|
alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
|
|
|
|
|
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
|
|
|
|
return torch.clip(betas, 0, 0.999)
|
|
|
|
|
|
|
|
|
|
@@ -1088,8 +1093,16 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
|
|
|
|
|
|
def Upsample(dim):
|
|
|
|
|
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
|
|
|
|
def ConvTransposeUpsample(dim, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
|
|
|
|
|
|
|
|
|
def NearestUpsample(dim, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
return nn.Sequential(
|
|
|
|
|
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
|
|
|
|
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def Downsample(dim, *, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
@@ -1166,7 +1179,7 @@ class ResnetBlock(nn.Module):
|
|
|
|
|
self.block2 = Block(dim_out, dim_out, groups = groups)
|
|
|
|
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward(self, x, cond = None, time_emb = None):
|
|
|
|
|
def forward(self, x, time_emb = None, cond = None):
|
|
|
|
|
|
|
|
|
|
scale_shift = None
|
|
|
|
|
if exists(self.time_mlp) and exists(time_emb):
|
|
|
|
|
@@ -1247,20 +1260,6 @@ class CrossAttention(nn.Module):
|
|
|
|
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
|
|
|
|
return self.to_out(out)
|
|
|
|
|
|
|
|
|
|
class GridAttention(nn.Module):
|
|
|
|
|
def __init__(self, *args, window_size = 8, **kwargs):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
self.attn = Attention(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
h, w = x.shape[-2:]
|
|
|
|
|
wsz = self.window_size
|
|
|
|
|
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
|
|
|
|
|
out = self.attn(x)
|
|
|
|
|
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
class LinearAttention(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
@@ -1360,6 +1359,8 @@ class Unet(nn.Module):
|
|
|
|
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
|
|
|
|
memory_efficient = False,
|
|
|
|
|
scale_skip_connection = False,
|
|
|
|
|
nearest_upsample = False,
|
|
|
|
|
final_conv_kernel_size = 1,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
@@ -1452,6 +1453,8 @@ class Unet(nn.Module):
|
|
|
|
|
# resnet block klass
|
|
|
|
|
|
|
|
|
|
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
|
|
|
|
top_level_resnet_group = first(resnet_groups)
|
|
|
|
|
|
|
|
|
|
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
|
|
|
|
|
|
|
|
|
assert len(resnet_groups) == len(in_out)
|
|
|
|
|
@@ -1462,23 +1465,36 @@ class Unet(nn.Module):
|
|
|
|
|
if cross_embed_downsample:
|
|
|
|
|
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
|
|
|
|
|
|
|
|
|
# upsample klass
|
|
|
|
|
|
|
|
|
|
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
|
|
|
|
|
|
|
|
|
# give memory efficient unet an initial resnet block
|
|
|
|
|
|
|
|
|
|
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
|
|
|
|
|
|
|
|
|
|
# layers
|
|
|
|
|
|
|
|
|
|
self.downs = nn.ModuleList([])
|
|
|
|
|
self.ups = nn.ModuleList([])
|
|
|
|
|
num_resolutions = len(in_out)
|
|
|
|
|
|
|
|
|
|
skip_connect_dims = [] # keeping track of skip connection dimensions
|
|
|
|
|
|
|
|
|
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
|
|
|
|
is_first = ind == 0
|
|
|
|
|
is_last = ind >= (num_resolutions - 1)
|
|
|
|
|
layer_cond_dim = cond_dim if not is_first else None
|
|
|
|
|
|
|
|
|
|
dim_layer = dim_out if memory_efficient else dim_in
|
|
|
|
|
skip_connect_dims.append(dim_layer)
|
|
|
|
|
|
|
|
|
|
self.downs.append(nn.ModuleList([
|
|
|
|
|
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
|
|
|
|
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
|
|
|
|
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
mid_dim = dims[-1]
|
|
|
|
|
@@ -1491,17 +1507,17 @@ class Unet(nn.Module):
|
|
|
|
|
is_last = ind >= (len(in_out) - 1)
|
|
|
|
|
layer_cond_dim = cond_dim if not is_last else None
|
|
|
|
|
|
|
|
|
|
skip_connect_dim = skip_connect_dims.pop()
|
|
|
|
|
|
|
|
|
|
self.ups.append(nn.ModuleList([
|
|
|
|
|
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
|
|
|
|
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
|
|
|
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
|
|
|
|
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
|
|
|
|
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
self.final_conv = nn.Sequential(
|
|
|
|
|
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
|
|
|
|
nn.Conv2d(dim, self.channels_out, 1)
|
|
|
|
|
)
|
|
|
|
|
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
|
|
|
|
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
|
|
|
|
|
|
|
|
|
# if the current settings for the unet are not correct
|
|
|
|
|
# for cascading DDPM, then reinit the unet with the right settings
|
|
|
|
|
@@ -1665,6 +1681,11 @@ class Unet(nn.Module):
|
|
|
|
|
c = self.norm_cond(c)
|
|
|
|
|
mid_c = self.norm_mid_cond(mid_c)
|
|
|
|
|
|
|
|
|
|
# initial resnet block
|
|
|
|
|
|
|
|
|
|
if exists(self.init_resnet_block):
|
|
|
|
|
x = self.init_resnet_block(x, t)
|
|
|
|
|
|
|
|
|
|
# go through the layers of the unet, down and up
|
|
|
|
|
|
|
|
|
|
hiddens = []
|
|
|
|
|
@@ -1673,38 +1694,41 @@ class Unet(nn.Module):
|
|
|
|
|
if exists(pre_downsample):
|
|
|
|
|
x = pre_downsample(x)
|
|
|
|
|
|
|
|
|
|
x = init_block(x, c, t)
|
|
|
|
|
x = init_block(x, t, c)
|
|
|
|
|
x = sparse_attn(x)
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
|
|
|
|
|
for resnet_block in resnet_blocks:
|
|
|
|
|
x = resnet_block(x, c, t)
|
|
|
|
|
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
x = resnet_block(x, t, c)
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
|
|
|
|
|
if exists(post_downsample):
|
|
|
|
|
x = post_downsample(x)
|
|
|
|
|
|
|
|
|
|
x = self.mid_block1(x, mid_c, t)
|
|
|
|
|
x = self.mid_block1(x, t, mid_c)
|
|
|
|
|
|
|
|
|
|
if exists(self.mid_attn):
|
|
|
|
|
x = self.mid_attn(x)
|
|
|
|
|
|
|
|
|
|
x = self.mid_block2(x, mid_c, t)
|
|
|
|
|
x = self.mid_block2(x, t, mid_c)
|
|
|
|
|
|
|
|
|
|
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
|
|
|
|
|
|
|
|
|
|
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
|
|
|
|
skip_connect = hiddens.pop() * self.skip_connect_scale
|
|
|
|
|
|
|
|
|
|
x = torch.cat((x, skip_connect), dim = 1)
|
|
|
|
|
x = init_block(x, c, t)
|
|
|
|
|
x = connect_skip(x)
|
|
|
|
|
x = init_block(x, t, c)
|
|
|
|
|
x = sparse_attn(x)
|
|
|
|
|
|
|
|
|
|
for resnet_block in resnet_blocks:
|
|
|
|
|
x = resnet_block(x, c, t)
|
|
|
|
|
x = connect_skip(x)
|
|
|
|
|
x = resnet_block(x, t, c)
|
|
|
|
|
|
|
|
|
|
x = upsample(x)
|
|
|
|
|
|
|
|
|
|
x = torch.cat((x, r), dim = 1)
|
|
|
|
|
return self.final_conv(x)
|
|
|
|
|
|
|
|
|
|
x = self.final_resnet_block(x, t)
|
|
|
|
|
return self.to_out(x)
|
|
|
|
|
|
|
|
|
|
class LowresConditioner(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
@@ -1771,7 +1795,7 @@ class Decoder(nn.Module):
|
|
|
|
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
|
|
|
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
|
|
|
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
|
|
|
|
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
|
|
|
|
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
|
|
|
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
|
|
|
|
clip_denoised = True,
|
|
|
|
|
clip_x_start = True,
|
|
|
|
|
@@ -2299,6 +2323,6 @@ class DALLE2(nn.Module):
|
|
|
|
|
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
|
|
|
|
|
|
|
|
|
if one_text:
|
|
|
|
|
return images[0]
|
|
|
|
|
return first(images)
|
|
|
|
|
|
|
|
|
|
return images
|
|
|
|
|
|