add skip connections for all intermediate resnet blocks; also add an initial resnet block for the memory-efficient version of the unet, and time conditioning for both that initial resnet block and the final one before the output

Phil Wang
2022-06-29 08:16:58 -07:00
parent 46a2558d53
commit 908ab83799
2 changed files with 58 additions and 30 deletions
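The gist of the change: the unet now saves a hidden after the init block and after every intermediate resnet block on the way down, and concatenates one back (scaled by `skip_connect_scale`) before every block on the way up. A toy sketch of that bookkeeping, with a 1x1 conv standing in for the real resnet blocks (not the actual classes):

import torch
import torch.nn as nn

dim = 8
block = nn.Conv2d(dim * 2, dim, 1)  # stand-in for a resnet block that eats a concatenated skip

hiddens = []
x = torch.randn(1, dim, 16, 16)

for _ in range(3):                  # down path: push one hidden per block
    hiddens.append(x)

for _ in range(3):                  # up path: pop in LIFO order, concat on channel dim
    x = block(torch.cat((x, hiddens.pop()), dim = 1))

assert not hiddens                  # pushes and pops must balance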


@@ -45,6 +45,11 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def first(arr, d = None):
+    if len(arr) == 0:
+        return d
+    return arr[0]
+
 def maybe(fn):
     @wraps(fn)
     def inner(x):
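`first` is a tiny helper returning the head of a sequence, with a default when it is empty; it replaces bare `[0]` indexing in the rest of this commit:

first([3, 4, 5])   # 3
first([], d = 0)   # 0, instead of raising IndexError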
@@ -351,7 +356,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     steps = timesteps + 1
     x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
-    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
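This is the cosine noise schedule from Nichol & Dhariwal (2021); the `first(...)` call is a pure refactor of `alphas_cumprod[0]`, which normalizes the curve to start at 1. A standalone run at a toy T = 4:

import torch

T, s = 4, 0.008
x = torch.linspace(0, T, T + 1, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]     # first(alphas_cumprod) in the library
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
print(torch.clip(betas, 0, 0.999))                      # betas increase monotonically, last one clipped to 0.999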
@@ -1088,8 +1093,12 @@ class DiffusionPrior(nn.Module):
 # decoder
 
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+def Upsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
+        nn.Conv2d(dim, dim_out, 3, padding = 1)
+    )
 
 def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
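`Upsample` switches from a strided `ConvTranspose2d` to nearest-neighbor upsampling followed by a 3x3 conv, the usual remedy for transposed-convolution checkerboard artifacts; both map (B, dim, H, W) to (B, dim_out, 2H, 2W). A quick shape check with hypothetical sizes:

import torch
import torch.nn as nn

up = nn.Sequential(
    nn.Upsample(scale_factor = 2, mode = 'nearest'),
    nn.Conv2d(64, 32, 3, padding = 1)
)
print(up(torch.randn(1, 64, 16, 16)).shape)   # torch.Size([1, 32, 32, 32])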
@@ -1166,7 +1175,7 @@ class ResnetBlock(nn.Module):
         self.block2 = Block(dim_out, dim_out, groups = groups)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
-    def forward(self, x, cond = None, time_emb = None):
+    def forward(self, x, time_emb = None, cond = None):
         scale_shift = None
 
         if exists(self.time_mlp) and exists(time_emb):
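Reordering `forward` to `(x, time_emb = None, cond = None)` lets the time embedding be passed positionally even when a block has no cross-attention conditioning, which the new time-conditioned init and final blocks rely on. Every call site below moves from `block(x, c, t)` to `block(x, t, c)`, i.e.:

x = block(x, t)      # time conditioning only (init / final resnet blocks)
x = block(x, t, c)   # time embedding plus conditioning tokens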
@@ -1452,6 +1461,8 @@ class Unet(nn.Module):
         # resnet block klass
 
         resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        top_level_resnet_group = first(resnet_groups)
+
         num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
 
         assert len(resnet_groups) == len(in_out)
@@ -1462,23 +1473,32 @@ class Unet(nn.Module):
         if cross_embed_downsample:
             downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
 
+        # give memory efficient unet an initial resnet block
+
+        self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
+
         # layers
 
         self.downs = nn.ModuleList([])
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
+        skip_connect_dims = [] # keeping track of skip connection dimensions
+
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
 
+            dim_layer = dim_out if memory_efficient else dim_in
+            skip_connect_dims.append(dim_layer)
+
             self.downs.append(nn.ModuleList([
                 downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
-                ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                downsample_klass(dim_out) if not is_last and not memory_efficient else None
+                ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
+                Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
             ]))
 
         mid_dim = dims[-1]
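`dim_layer` is the width the blocks at a given resolution actually run at: with `memory_efficient = True` the pre-downsample projects to `dim_out` first, otherwise everything stays at `dim_in` and the trailing downsample (or the new 1x1 conv on the last stage) widens to `dim_out`. `skip_connect_dims` records that width so the up path knows how many channels each saved hidden carries. With hypothetical `dims = (64, 128, 256)`:

in_out = [(64, 128), (128, 256)]

print([d_out for d_in, d_out in in_out])   # memory_efficient = True  -> skips saved at [128, 256]
print([d_in for d_in, d_out in in_out])    # memory_efficient = False -> skips saved at [64, 128]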
@@ -1491,17 +1511,17 @@ class Unet(nn.Module):
             is_last = ind >= (len(in_out) - 1)
             layer_cond_dim = cond_dim if not is_last else None
 
+            skip_connect_dim = skip_connect_dims.pop()
+
             self.ups.append(nn.ModuleList([
-                ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
+                ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                Upsample(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
-        self.final_conv = nn.Sequential(
-            ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
-            nn.Conv2d(dim, self.channels_out, 1)
-        )
+        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+        self.to_out = nn.Conv2d(dim, self.channels_out, 3, padding = 1)
 
         # if the current settings for the unet are not correct
         # for cascading DDPM, then reinit the unet with the right settings
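Going up, every resnet block (not just the first per stage) now takes `dim_out + skip_connect_dim` input channels, since each one consumes a popped hidden; `Upsample(dim_out, dim_in)` then narrows the channels while doubling resolution. Splitting the old `final_conv` into `final_resnet_block` plus a 3x3 `to_out` lets the last block receive the time embedding too. A minimal check of the per-block channel math, with made-up widths:

import torch
import torch.nn as nn

dim_out, skip_dim = 128, 128                      # hypothetical stage widths
blk = nn.Conv2d(dim_out + skip_dim, dim_out, 1)   # stand-in for ResnetBlock
x = torch.randn(1, dim_out, 8, 8)
skip = torch.randn(1, skip_dim, 8, 8)
print(blk(torch.cat((x, skip), dim = 1)).shape)   # torch.Size([1, 128, 8, 8])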
@@ -1665,6 +1685,11 @@ class Unet(nn.Module):
         c = self.norm_cond(c)
         mid_c = self.norm_mid_cond(mid_c)
 
+        # initial resnet block
+
+        if exists(self.init_resnet_block):
+            x = self.init_resnet_block(x, t)
+
         # go through the layers of the unet, down and up
 
         hiddens = []
@@ -1673,38 +1698,41 @@ class Unet(nn.Module):
             if exists(pre_downsample):
                 x = pre_downsample(x)
 
-            x = init_block(x, c, t)
+            x = init_block(x, t, c)
             x = sparse_attn(x)
+            hiddens.append(x)
 
             for resnet_block in resnet_blocks:
-                x = resnet_block(x, c, t)
-
-            hiddens.append(x)
+                x = resnet_block(x, t, c)
+                hiddens.append(x)
 
             if exists(post_downsample):
                 x = post_downsample(x)
 
-        x = self.mid_block1(x, mid_c, t)
+        x = self.mid_block1(x, t, mid_c)
 
         if exists(self.mid_attn):
             x = self.mid_attn(x)
 
-        x = self.mid_block2(x, mid_c, t)
+        x = self.mid_block2(x, t, mid_c)
 
+        connect_skip = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
+
         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
-            skip_connect = hiddens.pop() * self.skip_connect_scale
-            x = torch.cat((x, skip_connect), dim = 1)
-            x = init_block(x, c, t)
+            x = connect_skip(x)
+            x = init_block(x, t, c)
             x = sparse_attn(x)
 
             for resnet_block in resnet_blocks:
-                x = resnet_block(x, c, t)
+                x = connect_skip(x)
+                x = resnet_block(x, t, c)
 
             x = upsample(x)
 
         x = torch.cat((x, r), dim = 1)
 
-        return self.final_conv(x)
+        x = self.final_resnet_block(x, t)
+        return self.to_out(x)
 
 class LowresConditioner(nn.Module):
     def __init__(
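The stack discipline works out because the down path pushes `1 + layer_num_resnet_blocks` hiddens per resolution and the up path pops exactly as many, via `connect_skip` once before the init block and once before each resnet block. A quick sanity check of that invariant with hypothetical per-stage counts:

num_resnet_blocks = (2, 4, 8)   # hypothetical blocks per resolution

pushes = sum(1 + n for n in num_resnet_blocks)           # down path appends
pops = sum(1 + n for n in reversed(num_resnet_blocks))   # up path connect_skip calls
assert pushes == pops == 17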
@@ -2299,6 +2327,6 @@ class DALLE2(nn.Module):
         images = list(map(self.to_pil, images.unbind(dim = 0)))
 
         if one_text:
-            return images[0]
+            return first(images)
 
         return images