mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
bring in the cross embed layer from Crossformer paper for initial convolution in unet
This commit is contained in:
@@ -41,6 +41,9 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def is_odd(n):
|
||||
return (n % 2) == 1
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -1228,6 +1231,32 @@ class LinearAttention(nn.Module):
|
||||
out = self.nonlin(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CrossEmbedLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
kernel_sizes,
|
||||
stride = 2
|
||||
):
|
||||
super().__init__()
|
||||
assert all([*map(is_odd, kernel_sizes)])
|
||||
|
||||
kernel_sizes = sorted(kernel_sizes)
|
||||
num_scales = len(kernel_sizes)
|
||||
|
||||
# calculate the dimension at each scale
|
||||
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
|
||||
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
|
||||
|
||||
self.convs = nn.ModuleList([])
|
||||
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
|
||||
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
|
||||
|
||||
def forward(self, x):
|
||||
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
||||
return torch.cat(fmaps, dim = 1)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1252,6 +1281,7 @@ class Unet(nn.Module):
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1270,10 +1300,9 @@ class Unet(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
self.init_conv = CrossEmbedLayer(init_channels, init_dim, init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
Reference in New Issue
Block a user