wrap up cross embed layer feature

Phil Wang
2022-05-10 12:19:34 -07:00
parent 8dc8a3de0d
commit 908088cfea
3 changed files with 15 additions and 9 deletions


@@ -41,9 +41,6 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
-def is_odd(n):
-    return (n % 2) == 1
-
 def default(val, d):
     if exists(val):
         return val
@@ -1235,12 +1232,13 @@ class CrossEmbedLayer(nn.Module):
     def __init__(
         self,
         dim_in,
-        dim_out,
         kernel_sizes,
+        dim_out = None,
         stride = 2
     ):
         super().__init__()
-        assert all([*map(is_odd, kernel_sizes)])
+        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
+        dim_out = default(dim_out, dim_in)
 
         kernel_sizes = sorted(kernel_sizes)
         num_scales = len(kernel_sizes)
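
Note on the two changes above: dim_out is now optional and falls back to dim_in via default, and the old is_odd check becomes a parity check, so the kernel sizes only need to share the stride's parity (odd kernels for stride 1, even kernels for stride 2). A minimal usage sketch follows, assuming this file is dalle2_pytorch/dalle2_pytorch.py; the import path and the expected output shapes are assumptions based on the rest of the layer (which splits dim_out across the kernel scales and concatenates the conv outputs), not something shown in this hunk.

import torch
from dalle2_pytorch.dalle2_pytorch import CrossEmbedLayer  # assumed module path

x = torch.randn(1, 3, 64, 64)

# stride 1: odd kernel sizes share the stride's parity; spatial size is preserved
init_conv = CrossEmbedLayer(3, kernel_sizes = (3, 7, 15), dim_out = 64, stride = 1)
print(init_conv(x).shape)            # expected: (1, 64, 64, 64)

# stride 2 (the default): even kernel sizes required; dim_out omitted, so it falls back to dim_in
down = CrossEmbedLayer(64, kernel_sizes = (2, 4))
print(down(init_conv(x)).shape)      # expected: (1, 64, 32, 32)

# mismatched parity now trips the new assert
# CrossEmbedLayer(64, kernel_sizes = (3, 7), stride = 2)  # AssertionError
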
@@ -1282,6 +1280,8 @@ class Unet(nn.Module):
         init_conv_kernel_size = 7,
         resnet_groups = 8,
         init_cross_embed_kernel_sizes = (3, 7, 15),
+        cross_embed_downsample = False,
+        cross_embed_downsample_kernel_sizes = (2, 4),
         **kwargs
     ):
         super().__init__()
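
The two new keyword arguments are opt-in: by default the Unet keeps its plain Downsample, and only when cross_embed_downsample = True does it swap in CrossEmbedLayer with the given (even) kernel sizes, as wired up in the hunks below. A hedged construction sketch, assuming the package exports Unet at the top level and that the other constructor values (taken from README-style examples) are merely illustrative:

from dalle2_pytorch import Unet  # assumed import path

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    cross_embed_downsample = True,                  # new: use CrossEmbedLayer instead of the plain Downsample
    cross_embed_downsample_kernel_sizes = (2, 4)    # new: even kernel sizes, matching the downsample stride of 2
)
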
@@ -1302,7 +1302,7 @@ class Unet(nn.Module):
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
 
         init_dim = default(init_dim, dim // 3 * 2)
-        self.init_conv = CrossEmbedLayer(init_channels, init_dim, init_cross_embed_kernel_sizes, stride = 1)
+        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
 
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
@@ -1362,6 +1362,12 @@ class Unet(nn.Module):
 
         assert len(resnet_groups) == len(in_out)
 
+        # downsample klass
+
+        downsample_klass = Downsample
+        if cross_embed_downsample:
+            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
+
         # layers
 
         self.downs = nn.ModuleList([])
@@ -1377,7 +1383,7 @@ class Unet(nn.Module):
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Downsample(dim_out) if not is_last else nn.Identity()
+                downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
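
Taken together, the last two hunks are why dim_out and stride needed defaults: partial(CrossEmbedLayer, kernel_sizes = ...) has to be callable exactly like Downsample, i.e. with a single dim argument, and produce a stride-2 module mapping dim_out to dim_out. A small equivalence sketch, with the import path and Downsample's single strided-conv implementation assumed from the rest of the file:

import torch
from functools import partial
from dalle2_pytorch.dalle2_pytorch import Downsample, CrossEmbedLayer  # assumed module path

dim_out = 64
x = torch.randn(1, dim_out, 32, 32)

plain = Downsample(dim_out)                                       # default path: single strided conv
cross = partial(CrossEmbedLayer, kernel_sizes = (2, 4))(dim_out)  # cross embed path: dim_out defaults to dim_in, stride to 2

# both act as drop-in downsamples inside the down blocks: channels preserved, spatial size halved
print(plain(x).shape)   # expected: (1, 64, 16, 16)
print(cross(x).shape)   # expected: (1, 64, 16, 16)
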