@@ -1550,6 +1550,7 @@ class Unet(nn.Module):
         init_conv_kernel_size = 7,
         resnet_groups = 8,
         num_resnet_blocks = 2,
+        init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1578,7 +1579,7 @@ class Unet(nn.Module):
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
         init_dim = default(init_dim, dim)
 
-        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
+        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
 
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
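
These two hunks make the initial cross-scale embedding optional: when `init_cross_embed` is turned off, the unet falls back to a single plain conv of `init_conv_kernel_size`, and both branches map `init_channels` to `init_dim` at unchanged spatial resolution. A minimal sketch of the cross-embed idea (a simplified stand-in for the repo's `CrossEmbedLayer`; the exact channel split per kernel size is an assumption of this sketch) shows why the two branches are shape-compatible:

```python
import torch
import torch.nn as nn

# sketch of a cross-scale embedding: several kernel sizes applied in parallel,
# each producing a slice of the output channels, concatenated at the end
class CrossEmbedSketch(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_sizes = (3, 7, 15), stride = 1):
        super().__init__()
        kernel_sizes = sorted(kernel_sizes)
        # halve the channel budget per extra scale, give the remainder to the last conv
        dim_scales = [dim_out // (2 ** i) for i in range(1, len(kernel_sizes))]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, dim_scale, k, stride = stride, padding = (k - stride) // 2)
            for k, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        return torch.cat([conv(x) for conv in self.convs], dim = 1)

x = torch.randn(1, 3, 64, 64)
print(CrossEmbedSketch(3, 128)(x).shape)           # torch.Size([1, 128, 64, 64])
print(nn.Conv2d(3, 128, 7, padding = 3)(x).shape)  # same shape, single scale
```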
@@ -1731,7 +1732,10 @@ class Unet(nn.Module):
             ]))
 
         self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
-        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
+
+        out_dim_in = dim + (channels if lowres_cond else 0)
+
+        self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
 
         zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
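
`out_dim_in` anticipates the forward-pass change further down: when `lowres_cond` is set, the low resolution conditioning image (`channels` wide) is concatenated back in just before `to_out`, so the final conv has to accept `dim + channels` input channels. The `zero_init_` call is the common trick of zeroing the last conv so the untrained network outputs zeros; a sketch of such a helper (assuming the repo's `exists` predicate) would be:

```python
import torch.nn as nn

def exists(val):
    return val is not None

# zero the final conv's parameters so training starts from a null prediction,
# mirroring what zero_init_ does per the comment citing OpenAI and @crowsonkb
def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

to_out = nn.Conv2d(131, 3, kernel_size = 1)
zero_init_(to_out)
assert to_out.weight.abs().sum() == 0
```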
@@ -1923,7 +1927,7 @@ class Unet(nn.Module):
                 hiddens.append(x)
 
             x = attn(x)
-            hiddens.append(x)
+            hiddens.append(x.contiguous())
 
             if exists(post_downsample):
                 x = post_downsample(x)
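
The only behavioral change here is storing a contiguous copy of the attention output as the skip connection. Attention blocks commonly end in reshapes or transposes that leave a strided view; `.contiguous()` materializes it into its own densely laid out storage before it is kept around in `hiddens`. A quick illustration of what the call does:

```python
import torch

x = torch.randn(2, 8, 16, 16)
v = x.transpose(1, 2)                 # a strided view into x's storage, not contiguous
print(v.is_contiguous())              # False

c = v.contiguous()                    # copies into fresh, densely laid out memory
print(c.is_contiguous())              # True
print(c.data_ptr() == x.data_ptr())   # False - no longer aliases x
```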
@@ -1951,6 +1955,10 @@ class Unet(nn.Module):
         x = torch.cat((x, r), dim = 1)
 
         x = self.final_resnet_block(x, t)
+
+        if exists(lowres_cond_img):
+            x = torch.cat((x, lowres_cond_img), dim = 1)
+
         return self.to_out(x)
 
 class LowresConditioner(nn.Module):
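
This is the forward-pass half of the `out_dim_in` change above: for super-resolution unets, the low resolution conditioning image is concatenated onto the final resnet block's output, so `to_out` sees the raw conditioning pixels directly. A shape walk-through under assumed sizes (`dim = 128`, `channels = 3`, `size = 64` are hypothetical):

```python
import torch
import torch.nn as nn

batch, dim, channels, size = 2, 128, 3, 64

# built with out_dim_in = dim + channels, matching the constructor change above
to_out = nn.Conv2d(dim + channels, channels, kernel_size = 3, padding = 1)

x = torch.randn(batch, dim, size, size)                      # final_resnet_block output
lowres_cond_img = torch.randn(batch, channels, size, size)   # low-res image, resized to target

x = torch.cat((x, lowres_cond_img), dim = 1)                 # (2, 131, 64, 64)
print(to_out(x).shape)                                       # (2, 3, 64, 64)
```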
@@ -2165,6 +2173,7 @@ class Decoder(nn.Module):
         # random crop sizes (for super-resoluting unets at the end of cascade?)
 
         self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
+        assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
 
         # predict x0 config
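
`cast_tuple` broadcasts a scalar (or `None`) across every unet in the cascade, which is why the new assert only needs to check index 0: random cropping is an augmentation meant for the super-resolution stages, while the base unet should always train on the full image. A sketch of the broadcasting, assuming a `cast_tuple` helper along the lines of the repo's:

```python
# assumed helper in the style of the repo's cast_tuple
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

image_sizes = (64, 256, 1024)                 # hypothetical three-unet cascade

print(cast_tuple(None, len(image_sizes)))     # (None, None, None) -> passes the assert
print(cast_tuple((None, 64, 256), 3))         # per-unet crop sizes, base unet left uncropped
```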