mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add MLP based time conditioning to all convnexts, in addition to cross attention. also add an initial convolution, given convnext first depthwise conv
This commit is contained in:
@@ -922,6 +922,7 @@ class ConvNextBlock(nn.Module):
|
|||||||
dim_out,
|
dim_out,
|
||||||
*,
|
*,
|
||||||
cond_dim = None,
|
cond_dim = None,
|
||||||
|
time_cond_dim = None,
|
||||||
mult = 2,
|
mult = 2,
|
||||||
norm = True
|
norm = True
|
||||||
):
|
):
|
||||||
@@ -940,6 +941,14 @@ class ConvNextBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.time_mlp = None
|
||||||
|
|
||||||
|
if exists(time_cond_dim):
|
||||||
|
self.time_mlp = nn.Sequential(
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(time_cond_dim, dim)
|
||||||
|
)
|
||||||
|
|
||||||
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
||||||
|
|
||||||
inner_dim = int(dim_out * mult)
|
inner_dim = int(dim_out * mult)
|
||||||
@@ -952,9 +961,13 @@ class ConvNextBlock(nn.Module):
|
|||||||
|
|
||||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, cond = None):
|
def forward(self, x, cond = None, time = None):
|
||||||
h = self.ds_conv(x)
|
h = self.ds_conv(x)
|
||||||
|
|
||||||
|
if exists(time) and exists(self.time_mlp):
|
||||||
|
t = self.time_mlp(time)
|
||||||
|
h = rearrange(t, 'b c -> b c 1 1') + h
|
||||||
|
|
||||||
if exists(self.cross_attn):
|
if exists(self.cross_attn):
|
||||||
assert exists(cond)
|
assert exists(cond)
|
||||||
h = self.cross_attn(h, context = cond) + h
|
h = self.cross_attn(h, context = cond) + h
|
||||||
@@ -1076,22 +1089,33 @@ class Unet(nn.Module):
|
|||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
|
||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
|
init_dim = dim // 2
|
||||||
|
|
||||||
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
|
self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3)
|
||||||
|
|
||||||
|
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||||
in_out = list(zip(dims[:-1], dims[1:]))
|
in_out = list(zip(dims[:-1], dims[1:]))
|
||||||
|
|
||||||
# time, image embeddings, and optional text encoding
|
# time, image embeddings, and optional text encoding
|
||||||
|
|
||||||
cond_dim = default(cond_dim, dim)
|
cond_dim = default(cond_dim, dim)
|
||||||
|
time_cond_dim = dim * 4
|
||||||
|
|
||||||
self.time_mlp = nn.Sequential(
|
self.to_time_hiddens = nn.Sequential(
|
||||||
SinusoidalPosEmb(dim),
|
SinusoidalPosEmb(dim),
|
||||||
nn.Linear(dim, dim * 4),
|
nn.Linear(dim, time_cond_dim),
|
||||||
nn.GELU(),
|
nn.GELU()
|
||||||
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
)
|
||||||
|
|
||||||
|
self.to_time_tokens = nn.Sequential(
|
||||||
|
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
||||||
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.to_time_cond = nn.Sequential(
|
||||||
|
nn.Linear(time_cond_dim, time_cond_dim)
|
||||||
|
)
|
||||||
|
|
||||||
self.image_to_cond = nn.Sequential(
|
self.image_to_cond = nn.Sequential(
|
||||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
@@ -1133,26 +1157,26 @@ class Unet(nn.Module):
|
|||||||
layer_cond_dim = cond_dim if not is_first else None
|
layer_cond_dim = cond_dim if not is_first else None
|
||||||
|
|
||||||
self.downs.append(nn.ModuleList([
|
self.downs.append(nn.ModuleList([
|
||||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
|
||||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||||
Downsample(dim_out) if not is_last else nn.Identity()
|
Downsample(dim_out) if not is_last else nn.Identity()
|
||||||
]))
|
]))
|
||||||
|
|
||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
|
|
||||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||||
|
|
||||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
is_last = ind >= (num_resolutions - 2)
|
is_last = ind >= (num_resolutions - 2)
|
||||||
layer_cond_dim = cond_dim if not is_last else None
|
layer_cond_dim = cond_dim if not is_last else None
|
||||||
|
|
||||||
self.ups.append(nn.ModuleList([
|
self.ups.append(nn.ModuleList([
|
||||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||||
Upsample(dim_in)
|
Upsample(dim_in)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
@@ -1214,9 +1238,16 @@ class Unet(nn.Module):
|
|||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||||
|
|
||||||
|
# initial convolution
|
||||||
|
|
||||||
|
x = self.init_conv(x)
|
||||||
|
|
||||||
# time conditioning
|
# time conditioning
|
||||||
|
|
||||||
time_tokens = self.time_mlp(time)
|
time_hiddens = self.to_time_hiddens(time)
|
||||||
|
|
||||||
|
time_tokens = self.to_time_tokens(time_hiddens)
|
||||||
|
t = self.to_time_cond(time_hiddens)
|
||||||
|
|
||||||
# conditional dropout
|
# conditional dropout
|
||||||
|
|
||||||
@@ -1283,24 +1314,24 @@ class Unet(nn.Module):
|
|||||||
hiddens = []
|
hiddens = []
|
||||||
|
|
||||||
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
||||||
x = convnext(x, c)
|
x = convnext(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
x = convnext2(x, c)
|
x = convnext2(x, c, t)
|
||||||
hiddens.append(x)
|
hiddens.append(x)
|
||||||
x = downsample(x)
|
x = downsample(x)
|
||||||
|
|
||||||
x = self.mid_block1(x, mid_c)
|
x = self.mid_block1(x, mid_c, t)
|
||||||
|
|
||||||
if exists(self.mid_attn):
|
if exists(self.mid_attn):
|
||||||
x = self.mid_attn(x)
|
x = self.mid_attn(x)
|
||||||
|
|
||||||
x = self.mid_block2(x, mid_c)
|
x = self.mid_block2(x, mid_c, t)
|
||||||
|
|
||||||
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
||||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||||
x = convnext(x, c)
|
x = convnext(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
x = convnext2(x, c)
|
x = convnext2(x, c, t)
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|||||||
Reference in New Issue
Block a user