mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
adopt similar unet architecture as imagen
This commit is contained in:
@@ -1084,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
def Upsample(dim):
|
def Upsample(dim):
|
||||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||||
|
|
||||||
def Downsample(dim):
|
def Downsample(dim, *, dim_out = None):
|
||||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
dim_out = default(dim_out, dim)
|
||||||
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
@@ -1370,7 +1371,7 @@ class Unet(nn.Module):
|
|||||||
self.channels_out = default(channels_out, channels)
|
self.channels_out = default(channels_out, channels)
|
||||||
|
|
||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
init_dim = default(init_dim, dim // 3 * 2)
|
init_dim = default(init_dim, dim)
|
||||||
|
|
||||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||||
|
|
||||||
@@ -1461,10 +1462,10 @@ class Unet(nn.Module):
|
|||||||
layer_cond_dim = cond_dim if not is_first else None
|
layer_cond_dim = cond_dim if not is_first else None
|
||||||
|
|
||||||
self.downs.append(nn.ModuleList([
|
self.downs.append(nn.ModuleList([
|
||||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
downsample_klass(dim_in, dim_out = dim_out),
|
||||||
|
ResnetBlock(dim_out, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
|
||||||
]))
|
]))
|
||||||
|
|
||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
@@ -1473,7 +1474,7 @@ class Unet(nn.Module):
|
|||||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||||
|
|
||||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||||
is_last = ind >= (num_resolutions - 2)
|
is_last = ind >= (num_resolutions - 2)
|
||||||
layer_cond_dim = cond_dim if not is_last else None
|
layer_cond_dim = cond_dim if not is_last else None
|
||||||
|
|
||||||
@@ -1654,7 +1655,8 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|
||||||
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
|
for downsample, init_block, sparse_attn, resnet_blocks in self.downs:
|
||||||
|
x = downsample(x)
|
||||||
x = init_block(x, c, t)
|
x = init_block(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
|
|
||||||
@@ -1662,7 +1664,6 @@ class Unet(nn.Module):
|
|||||||
x = resnet_block(x, c, t)
|
x = resnet_block(x, c, t)
|
||||||
|
|
||||||
hiddens.append(x)
|
hiddens.append(x)
|
||||||
x = downsample(x)
|
|
||||||
|
|
||||||
x = self.mid_block1(x, mid_c, t)
|
x = self.mid_block1(x, mid_c, t)
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.7.1'
|
__version__ = '0.8.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user